From 80d4a184354d5668695c7e01c8886d67058f67c0 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Thu, 26 Sep 2019 18:00:30 -0400 Subject: [PATCH] fixed errors in random node distribution --- examples/flatland_2_0_example.py | 4 ++-- flatland/envs/rail_generators.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 69959ca8..ab730a0a 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -37,8 +37,8 @@ env = RailEnv(width=50, min_node_dist=15, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center seed=15, # Random seed - grid_mode=True, - connection_points_per_side=3, + grid_mode=False, + max_connection_points_per_side=2, max_nr_connection_directions=4, ), schedule_generator=sparse_schedule_generator(), diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 178da0c7..f7920c29 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -533,7 +533,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, node_radius=2, - grid_mode=False, connection_points_per_side=4, + grid_mode=False, max_connection_points_per_side=4, max_nr_connection_directions=2, seed=0) -> RailGenerator: """ @@ -598,7 +598,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n # Set up connection points for all cities connection_points, connection_info = _generate_node_connection_points(node_positions, node_radius, - connection_points_per_side, + max_connection_points_per_side, max_nr_connection_directions) # Connect the cities through the connection points @@ -821,7 +821,11 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n # Connect train station to random nodes - rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, replace=False) + if len(connection_points[trainstation_node]) > 1: + rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, + replace=False) + else: + rand_corner_nodes = [0] for corner_node_idx in rand_corner_nodes: connection = connect_nodes(rail_trans, grid_map, -- GitLab