diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 8880d27b852f7d2d54d93ffd23e42156bced0b74..ba12736f6d3ad241f6b323bf1ef7204b036e3668 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -744,25 +744,16 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, all_paths = [] for current_node in np.arange(len(node_positions)): - direction = 0 - connected_to_city = [] neighbours = _closest_neigh_in_direction(current_node, node_positions) - for nbr_connection_points in connection_info[current_node]: - if nbr_connection_points > 0: - neighb_idx = neighbours[direction] - else: - direction += 1 - continue - - # If no closest neighbour was found look at the neighbouring connections - tmp_direction = (direction - 1) % 4 - while neighb_idx is None: - neighb_idx = neighbours[tmp_direction] - tmp_direction = (direction + 1) % 4 - - connected_to_city.append(neighb_idx) - for tmp_out_connection_point in connection_points[current_node][direction]: - # Find closest connection point + for out_direction in range(4): + for tmp_out_connection_point in connection_points[current_node][out_direction]: + # This only needs to be checked when entering this loop + neighb_idx = neighbours[out_direction] + if neighb_idx is None: + tmp_direction = (out_direction - 1) % 4 + while neighb_idx is None: + neighb_idx = neighbours[tmp_direction] + tmp_direction = (out_direction + 1) % 4 min_connection_dist = np.inf for dir in range(4): current_points = connection_points[neighb_idx][dir] @@ -776,13 +767,11 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, city_cells) - G.add_edge(current_node, neighb_idx, direction=direction, length=len(new_line)) + G.add_edge(current_node, neighb_idx, direction=out_direction, length=len(new_line)) G.add_edge(neighb_idx, current_node, direction=neighbour_direction, length=len(new_line)) all_paths.extend(new_line) - direction += 1 - return all_paths def _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, node_radius, rail_trans,