From dc950fd6875fe46ca71bad05f279d008fb535ffd Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 30 Sep 2019 20:14:46 -0400
Subject: [PATCH] fixed track numbering function

---
 flatland/envs/rail_generators.py     | 29 +++++++++++++++++-----------
 flatland/envs/schedule_generators.py |  1 -
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index ac898d43..c2f4f438 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -768,18 +768,20 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                 for tmp_out_connection_point in connection_points[current_node][direction]:
                     # Find closest connection point
                     min_connection_dist = np.inf
-                    all_neighb_connection_points = [item for sublist in connection_points[neighb_idx] for item in
-                                                    sublist]
-
-                    for tmp_in_connection_point in all_neighb_connection_points:
-                        tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point,
+                    for dir in range(4):
+                        current_points = connection_points[neighb_idx][dir]
+                        for tmp_in_connection_point in current_points:
+                            tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point,
                                                     metric="Manhattan")
-                        if tmp_dist < min_connection_dist:
-                            min_connection_dist = tmp_dist
-                            neighb_connection_point = tmp_in_connection_point
+                            if tmp_dist < min_connection_dist:
+                                min_connection_dist = tmp_dist
+                                neighb_connection_point = tmp_in_connection_point
+                                neighbour_direction = dir
                     new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point,
                                               city_cells)
                     G.add_edge(current_node, neighb_idx, direction=direction, length=len(new_line))
+                    G.add_edge(neighb_idx, current_node, direction=neighbour_direction, length=len(new_line))
+
                     all_paths.extend(new_line)
                 direction += 1
 
@@ -1003,8 +1005,13 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
         :return:
         """
         if city_orientation % 2 == 0:
-            return np.abs(city_position[1] - position[1]) % 2
+            if city_position[1] - position[1] < 0:
+                return np.abs(city_position[1] - position[1]) % 2
+            else:
+                return (np.abs(city_position[1] - position[1]) + 1) % 2
         else:
-            return np.abs(city_position[0] - position[0]) % 2
-
+            if city_position[0] - position[0] < 0:
+                return np.abs(city_position[0] - position[0]) % 2
+            else:
+                return (np.abs(city_position[0] - position[0]) + 1) % 2
     return generator
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index cd2f3996..3a378387 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -88,7 +88,6 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
                 track_to_use = 0
             else:
                 track_to_use = 1
-
             for i in range(len(train_stations[current_start_node])):
                 if train_stations[current_start_node][i][1] == track_to_use:
                     start_station_idx = i
-- 
GitLab