From 555badf440fa401587b25227760b9178de85e7eb Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Thu, 26 Sep 2019 17:24:13 -0400
Subject: [PATCH] fixed index error in city connection algorithm

---
 examples/flatland_2_0_example.py | 2 +-
 flatland/envs/rail_generators.py | 8 +++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 9cab89e8..12f84422 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -39,7 +39,7 @@ env = RailEnv(width=50,
                                                    num_neighb=3,  # Number of connections to other cities/intersections
                                                    seed=15,  # Random seed
                                                    grid_mode=True,
-                                                   nr_parallel_tracks=10,
+                                                   nr_parallel_tracks=1,
                                                    connection_points_per_side=2,
                                                    max_nr_connection_directions=4,
                                                    ),
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index e72271ad..ee218dc6 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -742,14 +742,15 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
         :return:
         """
         boarder_connections = set()
-        # Start at some node
         for current_node in np.arange(len(node_positions)):
             direction = 0
             for nbr_connection_points in connection_info[current_node]:
                 if nbr_connection_points > 0:
                     neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions)
-                    print(current_node, direction, neighb_idx, connection_info[current_node])
+                    print(current_node, node_positions[current_node], direction, neighb_idx,
+                          connection_info[current_node])
                 else:
+                    direction += 1
                     continue
 
                 if neighb_idx is not None:
@@ -771,6 +772,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                         boarder_connections.add((tmp_out_connection_point, current_node))
                         boarder_connections.add((neighb_connection_point, neighb_idx))
                 direction += 1
+        return boarder_connections
 
 
     def _build_cities(node_positions, connection_points, rail_trans, grid_map):
@@ -923,7 +925,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                     return neighb
 
             if direction == 3:
-                if node_positions[neighb][0] < node_positions[current_node][0] and distance_0 <= distance_1:
+                if node_positions[neighb][1] < node_positions[current_node][1] and distance_0 <= distance_1:
                     return neighb
         return None
 
-- 
GitLab