From b66688da63765dadabb7d9027c4cb8f87ed11928 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 30 Sep 2019 14:20:25 -0400
Subject: [PATCH] updated limits for city center positioning

---
 examples/flatland_2_0_example.py |  4 ++--
 flatland/envs/rail_generators.py | 33 +++++++++++++++++---------------
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 961e36a8..91e3c6ce 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -28,8 +28,8 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
                     1. / 3.: 0.25,  # Slow commuter train
                     1. / 4.: 0.25}  # Slow freight train
 
-env = RailEnv(width=100,
-              height=100,
+env = RailEnv(width=50,
+              height=50,
               rail_generator=sparse_rail_generator(num_cities=20,  # Number of cities in map (where train stations are)
                                                    seed=0,  # Random seed
                                                    grid_mode=False,
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 7ea93715..31eef79a 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -616,8 +616,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
             tries = 0
 
             while to_close:
-                x_tmp = node_radius + 1 + np.random.randint(height - 2 * node_radius - 1)
-                y_tmp = node_radius + 1 + np.random.randint(width - 2 * node_radius - 1)
+                x_tmp = node_radius + 1 + np.random.randint(height - 2 * (node_radius - 1))
+                y_tmp = node_radius + 1 + np.random.randint(width - 2 * (node_radius - 1))
                 to_close = False
                 # Check distance to nodes
                 for node_pos in node_positions:
@@ -770,19 +770,22 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                                            sublist]
             city_boarder = _city_boarder(node_positions[current_city], node_radius)
 
-            random_boarders = np.random.choice(np.arange(4), 4, False)
-            # TODO: Only look at the relevant boarders (Only two at the moment)
-            for boarder in random_boarders:
-                for source in inner_connection_points[current_city][boarder]:
-                    for other_boarder in random_boarders:
-                        if boarder != other_boarder and len(inner_connection_points[current_city][other_boarder]) > 0:
-                            for target in inner_connection_points[current_city][other_boarder]:
-                                current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder)
-                                if target in all_outer_connection_points and source in \
-                                    all_outer_connection_points and len(through_path_cells[current_city]) < 1:
-                                    through_path_cells[current_city].extend(current_track)
-                        else:
-                            continue
+            # This part only works if we have keep same number of connection points for both directions
+            # Also only works with two connection direction at each city
+            for boarder in range(4):
+                opposite_boarder = (boarder + 2) % 4
+                for track_id in range(len(inner_connection_points[current_city][boarder])):
+                    if track_id % 2 == 0:
+                        source = inner_connection_points[current_city][boarder][track_id]
+                        for target in inner_connection_points[current_city][opposite_boarder]:
+                            current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder)
+                            if target in all_outer_connection_points and source in \
+                                all_outer_connection_points and len(through_path_cells[current_city]) < 1:
+                                through_path_cells[current_city].extend(current_track)
+                    else:
+                        source = inner_connection_points[current_city][opposite_boarder][track_id]
+                        for target in inner_connection_points[current_city][boarder]:
+                            current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder)
 
         return through_path_cells
 
-- 
GitLab