From 555badf440fa401587b25227760b9178de85e7eb Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Thu, 26 Sep 2019 17:24:13 -0400 Subject: [PATCH] fixed index error in city connection algorithm --- examples/flatland_2_0_example.py | 2 +- flatland/envs/rail_generators.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 9cab89e8..12f84422 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -39,7 +39,7 @@ env = RailEnv(width=50, num_neighb=3, # Number of connections to other cities/intersections seed=15, # Random seed grid_mode=True, - nr_parallel_tracks=10, + nr_parallel_tracks=1, connection_points_per_side=2, max_nr_connection_directions=4, ), diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index e72271ad..ee218dc6 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -742,14 +742,15 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n :return: """ boarder_connections = set() - # Start at some node for current_node in np.arange(len(node_positions)): direction = 0 for nbr_connection_points in connection_info[current_node]: if nbr_connection_points > 0: neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions) - print(current_node, direction, neighb_idx, connection_info[current_node]) + print(current_node, node_positions[current_node], direction, neighb_idx, + connection_info[current_node]) else: + direction += 1 continue if neighb_idx is not None: @@ -771,6 +772,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n boarder_connections.add((tmp_out_connection_point, current_node)) boarder_connections.add((neighb_connection_point, neighb_idx)) direction += 1 + return boarder_connections def _build_cities(node_positions, connection_points, rail_trans, grid_map): @@ -923,7 +925,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n return neighb if direction == 3: - if node_positions[neighb][0] < node_positions[current_node][0] and distance_0 <= distance_1: + if node_positions[neighb][1] < node_positions[current_node][1] and distance_0 <= distance_1: return neighb return None -- GitLab