diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 69959ca8c056465260afc987b1100d0cfb316285..ab730a0a8d461fc72739fe19bd4ec7c2ad1b992e 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -37,8 +37,8 @@ env = RailEnv(width=50, min_node_dist=15, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center seed=15, # Random seed - grid_mode=True, - connection_points_per_side=3, + grid_mode=False, + max_connection_points_per_side=2, max_nr_connection_directions=4, ), schedule_generator=sparse_schedule_generator(), diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 178da0c71cbe06656e2002f7ff7dfba36eba4c58..f7920c29729b30153d153ce55d4ce45ea0894082 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -533,7 +533,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, node_radius=2, - grid_mode=False, connection_points_per_side=4, + grid_mode=False, max_connection_points_per_side=4, max_nr_connection_directions=2, seed=0) -> RailGenerator: """ @@ -598,7 +598,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n # Set up connection points for all cities connection_points, connection_info = _generate_node_connection_points(node_positions, node_radius, - connection_points_per_side, + max_connection_points_per_side, max_nr_connection_directions) # Connect the cities through the connection points @@ -821,7 +821,11 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n # Connect train station to random nodes - rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, replace=False) + if len(connection_points[trainstation_node]) > 1: + rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, + replace=False) + else: + rand_corner_nodes = [0] for corner_node_idx in rand_corner_nodes: connection = connect_nodes(rail_trans, grid_map,