diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 96fb87f8d7043e17b8b2f3291f76b0440af463eb..15b19f8cce63792a537d6b140dc16c8c0f9bbe90 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -12,24 +12,25 @@ np.random.seed(1) # Training on simple small tasks is the best way to get familiar with the environment # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 0.5, - 'malfunction_rate': 30, - 'min_duration': 3, - 'max_duration': 10} +stochastic_data = {'prop_malfunction': 0.5, # Percentage of defective agents + 'malfunction_rate': 30, # Rate of malfunction occurence + 'min_duration': 3, # Minimal duration of malfunction + 'max_duration': 10 # Max duration of malfunction + } TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) -env = RailEnv(width=50, - height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map - num_intersections=3, # Number of interesections in map - num_trainstations=40, # Number of possible start/targets on map - min_node_dist=10, # Minimal distance of nodes +env = RailEnv(width=10, + height=10, + rail_generator=sparse_rail_generator(num_cities=3, # Number of cities in map + num_intersections=1, # Number of interesections in map + num_trainstations=8, # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes node_radius=2, # Proximity of stations to city center - num_neighb=4, # Number of connections to other cities + num_neighb=2, # Number of connections to other cities seed=15, # Random seed ), - number_of_agents=10, + number_of_agents=5, stochastic_data=stochastic_data, # Malfunction generator data obs_builder_object=TreeObservation) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 0b7ea708c27bb4cccef07cc71ef17d5759315952..2aa148999d2398ac6ab0696388f06bf4e93b8fab 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -875,7 +875,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation # Set number of neighboring nodes if len(available_nodes) >= num_neighb: connected_neighb_idx = available_nodes[ - 0:2] # np.random.choice(available_nodes, num_neighb, replace=False) + 0:num_neighb] # np.random.choice(available_nodes, num_neighb, replace=False) else: connected_neighb_idx = available_nodes @@ -885,10 +885,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation node_stack.append(neighb) connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) node_stack.pop(0) - # Place train stations close to the node # We currently place them uniformly distirbuted among all cities - train_stations = [[] for i in range(num_cities)] for station in range(num_trainstations): @@ -940,6 +938,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation tries += 1 # Test again with new start node if no pair is found (This code needs to be improved) if tries > 10: + break start_node = np.random.choice(avail_start_nodes) node_available_start[start_node] -= 1