From c9207ec0e93f3371e8c45c838a87f1862f5b8aba Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Sat, 28 Sep 2019 10:18:49 -0400 Subject: [PATCH] updated how closest neighbours are found. Now always looking at directions similar to initial try --- examples/flatland_2_0_example.py | 2 +- flatland/envs/rail_generators.py | 58 ++++++++++++++------------------ 2 files changed, 26 insertions(+), 34 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 0099b1cd..7ed5d27c 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -33,7 +33,7 @@ env = RailEnv(width=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) min_node_dist=12, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center - seed=0, # Random seed + seed=12, # Random seed grid_mode=False, max_inter_city_rails=2, tracks_in_city=5, diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 226b7f89..d1911d4b 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -728,23 +728,19 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, for current_node in np.arange(len(node_positions)): direction = 0 connected_to_city = [] + neighbours = _closest_neigh_in_direction(current_node, node_positions) for nbr_connection_points in connection_info[current_node]: if nbr_connection_points > 0: - neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions) + neighb_idx = neighbours[direction] else: direction += 1 continue - if neighb_idx is None or neighb_idx in connected_to_city: - node_dist = [] - for av_node in node_positions: - node_dist.append(distance_on_rail(node_positions[current_node], av_node)) - i = 1 - neighbours = np.argsort(node_dist) - neighb_idx = neighbours[i] - while neighb_idx in connected_to_city: - i += 1 - neighb_idx = neighbours[i] + # If no closest neighbour was found look for the next one clock wise to avoid connecting to previous node + tmp_direction = (direction + 1) % 4 + while neighb_idx is None: + neighb_idx = neighbours[tmp_direction] + tmp_direction = (tmp_direction - 1) % 4 connected_to_city.append(neighb_idx) for tmp_out_connection_point in connection_points[current_node][direction]: @@ -882,33 +878,29 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, for cell in rails_to_fix: grid_map.fix_transitions(cell) - def _closest_neigh_in_direction(current_node, direction, node_positions): - # Sort available neighbors according to their distance. - + def _closest_neigh_in_direction(current_node, node_positions): + """ + Returns indices of closest neighbours in every direction NESW + :param current_node: Index of node in node_positions list + :param node_positions: list of all points being considered + :return: list of index of closest neighbours in all directions + """ node_dist = [] + closest_neighb = [None for i in range(4)] for av_node in range(len(node_positions)): node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node])) sorted_neighbours = np.argsort(node_dist) - + direction_set = 0 for neighb in sorted_neighbours[1:]: - distance_0 = np.abs(node_positions[current_node][0] - node_positions[neighb][0]) - distance_1 = np.abs(node_positions[current_node][1] - node_positions[neighb][1]) - if direction == 0: - if node_positions[neighb][0] < node_positions[current_node][0] and distance_1 <= distance_0: - return neighb - - if direction == 1: - if node_positions[neighb][1] > node_positions[current_node][1] and distance_0 <= distance_1: - return neighb - - if direction == 2: - if node_positions[neighb][0] > node_positions[current_node][0] and distance_1 <= distance_0: - return neighb - - if direction == 3: - if node_positions[neighb][1] < node_positions[current_node][1] and distance_0 <= distance_1: - return neighb - return None + direction_to_neighb = direction_to_point(node_positions[current_node], node_positions[neighb]) + if closest_neighb[direction_to_neighb] == None: + closest_neighb[direction_to_neighb] = neighb + direction_set += 1 + + if direction_set == 4: + return closest_neighb + + return closest_neighb def argsort(seq): # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python -- GitLab