From c9207ec0e93f3371e8c45c838a87f1862f5b8aba Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Sat, 28 Sep 2019 10:18:49 -0400
Subject: [PATCH] updated how closest neighbours are found. Now always looking
 at directions similar to initial try

---
 examples/flatland_2_0_example.py |  2 +-
 flatland/envs/rail_generators.py | 58 ++++++++++++++------------------
 2 files changed, 26 insertions(+), 34 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 0099b1cd..7ed5d27c 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -33,7 +33,7 @@ env = RailEnv(width=50,
               rail_generator=sparse_rail_generator(num_cities=9,  # Number of cities in map (where train stations are)
                                                    min_node_dist=12,  # Minimal distance of nodes
                                                    node_radius=4,  # Proximity of stations to city center
-                                                   seed=0,  # Random seed
+                                                   seed=12,  # Random seed
                                                    grid_mode=False,
                                                    max_inter_city_rails=2,
                                                    tracks_in_city=5,
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 226b7f89..d1911d4b 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -728,23 +728,19 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
         for current_node in np.arange(len(node_positions)):
             direction = 0
             connected_to_city = []
+            neighbours = _closest_neigh_in_direction(current_node, node_positions)
             for nbr_connection_points in connection_info[current_node]:
                 if nbr_connection_points > 0:
-                    neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions)
+                    neighb_idx = neighbours[direction]
                 else:
                     direction += 1
                     continue
 
-                if neighb_idx is None or neighb_idx in connected_to_city:
-                    node_dist = []
-                    for av_node in node_positions:
-                        node_dist.append(distance_on_rail(node_positions[current_node], av_node))
-                    i = 1
-                    neighbours = np.argsort(node_dist)
-                    neighb_idx = neighbours[i]
-                    while neighb_idx in connected_to_city:
-                        i += 1
-                        neighb_idx = neighbours[i]
+                # If no closest neighbour was found look for the next one clock wise to avoid connecting to previous node
+                tmp_direction = (direction + 1) % 4
+                while neighb_idx is None:
+                    neighb_idx = neighbours[tmp_direction]
+                    tmp_direction = (tmp_direction - 1) % 4
 
                 connected_to_city.append(neighb_idx)
                 for tmp_out_connection_point in connection_points[current_node][direction]:
@@ -882,33 +878,29 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
         for cell in rails_to_fix:
             grid_map.fix_transitions(cell)
 
-    def _closest_neigh_in_direction(current_node, direction, node_positions):
-        # Sort available neighbors according to their distance.
-
+    def _closest_neigh_in_direction(current_node, node_positions):
+        """
+        Returns indices of closest neighbours in every direction NESW
+        :param current_node: Index of node in node_positions list
+        :param node_positions: list of all points being considered
+        :return: list of index of closest neighbours in all directions
+        """
         node_dist = []
+        closest_neighb = [None for i in range(4)]
         for av_node in range(len(node_positions)):
             node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node]))
         sorted_neighbours = np.argsort(node_dist)
-
+        direction_set = 0
         for neighb in sorted_neighbours[1:]:
-            distance_0 = np.abs(node_positions[current_node][0] - node_positions[neighb][0])
-            distance_1 = np.abs(node_positions[current_node][1] - node_positions[neighb][1])
-            if direction == 0:
-                if node_positions[neighb][0] < node_positions[current_node][0] and distance_1 <= distance_0:
-                    return neighb
-
-            if direction == 1:
-                if node_positions[neighb][1] > node_positions[current_node][1] and distance_0 <= distance_1:
-                    return neighb
-
-            if direction == 2:
-                if node_positions[neighb][0] > node_positions[current_node][0] and distance_1 <= distance_0:
-                    return neighb
-
-            if direction == 3:
-                if node_positions[neighb][1] < node_positions[current_node][1] and distance_0 <= distance_1:
-                    return neighb
-        return None
+            direction_to_neighb = direction_to_point(node_positions[current_node], node_positions[neighb])
+            if closest_neighb[direction_to_neighb] == None:
+                closest_neighb[direction_to_neighb] = neighb
+                direction_set += 1
+
+            if direction_set == 4:
+                return closest_neighb
+
+        return closest_neighb
 
     def argsort(seq):
         # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
-- 
GitLab