diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 6bc6cdc4d41a9951b9fecfdd8871300d87466147..161927a49c2dcb71aef20c369e9293ca2b24eb57 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -38,7 +38,7 @@ env = RailEnv(width=50, seed=0, # Random seed grid_mode=False, max_inter_city_rails=2, - tracks_in_city=50, + tracks_in_city=5, ), schedule_generator=sparse_schedule_generator(), number_of_agents=50, diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index a4d34a6bfe207f36077bedf3dec578e499cb703b..7bf8a978cf1f55e256433cbd38f588376b43c530 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -558,7 +558,9 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, rail_array = grid_map.grid rail_array.fill(0) np.random.seed(seed + num_resets) - + max_inter_city_rails_allowed = max_inter_city_rails + if max_inter_city_rails_allowed > tracks_in_city: + max_inter_city_rails_allowed = tracks_in_city # Generate a set of nodes for the sparse network # Try to connect cities to nodes first city_positions = [] @@ -581,7 +583,7 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, # Connect the cities through the connection points outer_connection_points = _connect_cities(node_positions, connection_points, connection_info, city_cells, - max_inter_city_rails, + max_inter_city_rails_allowed, rail_trans, grid_map) # Build inner cities @@ -705,7 +707,7 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, connection_info.append(connections_per_direction) return connection_points, connection_info - def _connect_cities(node_positions, connection_points, connection_info, city_cells, max_inter_city_rails, + def _connect_cities(node_positions, connection_points, connection_info, city_cells, max_inter_city_rails_allowed, rail_trans, grid_map): """ Function to connect the different cities through their connection points @@ -739,7 +741,7 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, neighb_idx = neighbours[i] connected_to_city.append(neighb_idx) - number_of_out_rails = np.random.randint(1, max_inter_city_rails + 1) + number_of_out_rails = np.random.randint(1, max_inter_city_rails_allowed + 1) for tmp_out_connection_point in connection_points[current_node][direction][:number_of_out_rails]: # Find closest connection point