From f33ca15732b40ca23ef4f2ea8b372e5af3c1bf9c Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Fri, 27 Sep 2019 18:13:12 -0400 Subject: [PATCH] limiting number of connections between cities not to be more then number of tracks internally --- examples/flatland_2_0_example.py | 2 +- flatland/envs/rail_generators.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 6bc6cdc4..161927a4 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -38,7 +38,7 @@ env = RailEnv(width=50, seed=0, # Random seed grid_mode=False, max_inter_city_rails=2, - tracks_in_city=50, + tracks_in_city=5, ), schedule_generator=sparse_schedule_generator(), number_of_agents=50, diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index a4d34a6b..7bf8a978 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -558,7 +558,9 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, rail_array = grid_map.grid rail_array.fill(0) np.random.seed(seed + num_resets) - + max_inter_city_rails_allowed = max_inter_city_rails + if max_inter_city_rails_allowed > tracks_in_city: + max_inter_city_rails_allowed = tracks_in_city # Generate a set of nodes for the sparse network # Try to connect cities to nodes first city_positions = [] @@ -581,7 +583,7 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, # Connect the cities through the connection points outer_connection_points = _connect_cities(node_positions, connection_points, connection_info, city_cells, - max_inter_city_rails, + max_inter_city_rails_allowed, rail_trans, grid_map) # Build inner cities @@ -705,7 +707,7 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, connection_info.append(connections_per_direction) return connection_points, connection_info - def _connect_cities(node_positions, connection_points, connection_info, city_cells, max_inter_city_rails, + def _connect_cities(node_positions, connection_points, connection_info, city_cells, max_inter_city_rails_allowed, rail_trans, grid_map): """ Function to connect the different cities through their connection points @@ -739,7 +741,7 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, neighb_idx = neighbours[i] connected_to_city.append(neighb_idx) - number_of_out_rails = np.random.randint(1, max_inter_city_rails + 1) + number_of_out_rails = np.random.randint(1, max_inter_city_rails_allowed + 1) for tmp_out_connection_point in connection_points[current_node][direction][:number_of_out_rails]: # Find closest connection point -- GitLab