From 094af7fe1a62c4f5f5b0d4e5de7bda8237fc9748 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Tue, 1 Oct 2019 08:59:14 -0400 Subject: [PATCH] minor refactoring of city connection code --- flatland/envs/rail_generators.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 8880d27b..ba12736f 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -744,25 +744,16 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, all_paths = [] for current_node in np.arange(len(node_positions)): - direction = 0 - connected_to_city = [] neighbours = _closest_neigh_in_direction(current_node, node_positions) - for nbr_connection_points in connection_info[current_node]: - if nbr_connection_points > 0: - neighb_idx = neighbours[direction] - else: - direction += 1 - continue - - # If no closest neighbour was found look at the neighbouring connections - tmp_direction = (direction - 1) % 4 - while neighb_idx is None: - neighb_idx = neighbours[tmp_direction] - tmp_direction = (direction + 1) % 4 - - connected_to_city.append(neighb_idx) - for tmp_out_connection_point in connection_points[current_node][direction]: - # Find closest connection point + for out_direction in range(4): + for tmp_out_connection_point in connection_points[current_node][out_direction]: + # This only needs to be checked when entering this loop + neighb_idx = neighbours[out_direction] + if neighb_idx is None: + tmp_direction = (out_direction - 1) % 4 + while neighb_idx is None: + neighb_idx = neighbours[tmp_direction] + tmp_direction = (out_direction + 1) % 4 min_connection_dist = np.inf for dir in range(4): current_points = connection_points[neighb_idx][dir] @@ -776,13 +767,11 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, city_cells) - G.add_edge(current_node, neighb_idx, direction=direction, length=len(new_line)) + G.add_edge(current_node, neighb_idx, direction=out_direction, length=len(new_line)) G.add_edge(neighb_idx, current_node, direction=neighbour_direction, length=len(new_line)) all_paths.extend(new_line) - direction += 1 - return all_paths def _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, node_radius, rail_trans, -- GitLab