diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index b3c24da79a3f16480405dfe5b7bf94902d513a0e..ceedc90a95f8a434be67a7533873d3eb00154537 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -21,14 +21,15 @@ stochastic_data = {'prop_malfunction': 0.5, # Percentage of defective agents TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) env = RailEnv(width=20, height=20, - rail_generator=sparse_rail_generator(num_cities=5, # Number of cities in map (where train stations are) - num_intersections=4, # Number of interesections (no start / target) + rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map (where train stations are) + num_intersections=1, # Number of interesections (no start / target) num_trainstations=15, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes node_radius=3, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities/intersections + num_neighb=2, # Number of connections to other cities/intersections seed=15, # Random seed - realistic_mode=True + realistic_mode=True, + enhance_intersection=True ), number_of_agents=5, stochastic_data=stochastic_data, # Malfunction generator data diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 087c6299966f5d1e04e6d21809afbd73999be6c1..e3b59e4fccba235b008eb8c576f5bdb334061248 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -975,7 +975,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if num_agents > num_trainstations: num_agents = num_trainstations - warnings.warn("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") + warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) @@ -998,6 +998,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) + for node_idx in range(num_cities + num_intersections): to_close = True tries = 0 @@ -1043,18 +1044,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 available_cities = np.arange(num_cities) available_intersections = np.arange(num_cities, num_cities + num_intersections) - # Keep track of number of incoming connection - incoming_connections = np.zeros(num_intersections + num_cities) - # Start at some node current_node = np.random.randint(len(available_nodes_full)) node_stack = [current_node] allowed_connections = num_neighb + first_node = True while len(node_stack) > 0: current_node = node_stack[0] delete_idx = np.where(available_nodes_full == current_node) available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) - filled_nodes = np.where(incoming_connections >= num_neighb) # Priority city to intersection connections if current_node < num_cities and len(available_intersections) > 0: available_nodes = available_intersections @@ -1083,16 +1081,16 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 else: connected_neighb_idx = available_nodes - if current_node == 0: + # Less connections for subsequent nodes + if first_node: allowed_connections -= 1 + first_node = False # Connect to the neighboring nodes for neighb in connected_neighb_idx: if neighb not in node_stack: node_stack.append(neighb) connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) - incoming_connections[neighb] += 1 - incoming_connections[current_node] += 1 node_stack.pop(0) # Place train stations close to the node @@ -1133,6 +1131,12 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if len(connection) == 0: train_stations[trainstation_node].pop(-1) + # Adjust the number of agents if you could not build enough trainstations + built_num_trainstation = len(train_stations) + if num_agents > built_num_trainstation: + num_agents = built_num_trainstation + warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") + # Place passing lanes at intersections # We currently place them uniformly distirbuted among all cities if enhance_intersection: