From d6a4565967feb3f4a1a92d168b1401c59a367fee Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Tue, 1 Oct 2019 09:07:02 -0400
Subject: [PATCH] fixed bug in connection

---
 examples/flatland_2_0_example.py | 2 +-
 flatland/envs/rail_generators.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 7916c4e5..3ca6ab9d 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -30,7 +30,7 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 
 env = RailEnv(width=50,
               height=50,
-              rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map (where train stations are)
+              rail_generator=sparse_rail_generator(num_cities=2,  # Number of cities in map (where train stations are)
                                                    seed=1,  # Random seed
                                                    grid_mode=False,
                                                    max_inter_city_rails=2,
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index ba12736f..a1205580 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -753,7 +753,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                         tmp_direction = (out_direction - 1) % 4
                     while neighb_idx is None:
                         neighb_idx = neighbours[tmp_direction]
-                        tmp_direction = (out_direction + 1) % 4
+                        tmp_direction = (tmp_direction + 1) % 4
                     min_connection_dist = np.inf
                     for dir in range(4):
                         current_points = connection_points[neighb_idx][dir]
@@ -946,7 +946,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
 
             if direction_set == 4:
                 return closest_neighb
-
+        print(closest_neighb)
         return closest_neighb
 
     def argsort(seq):
-- 
GitLab