From a481cdaa52987ae6e953e27de9fcd28463c36ba3 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 30 Sep 2019 14:48:50 -0400
Subject: [PATCH] updated inner city construciton algorithm

---
 examples/flatland_2_0_example.py        |  2 +-
 flatland/envs/grid4_generators_utils.py |  2 +-
 flatland/envs/rail_generators.py        | 44 +++++++++++++++----------
 3 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 91e3c6ce..5c92fd1c 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -31,7 +31,7 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 env = RailEnv(width=50,
               height=50,
               rail_generator=sparse_rail_generator(num_cities=20,  # Number of cities in map (where train stations are)
-                                                   seed=0,  # Random seed
+                                                   seed=1,  # Random seed
                                                    grid_mode=False,
                                                    max_inter_city_rails=2,
                                                    max_tracks_in_city=4,
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index 166094aa..d0e75e27 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -90,7 +90,7 @@ def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
 
 
 def connect_cities(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
-                   start: IntVector2D, end: IntVector2D, forbidden_cells,
+                   start: IntVector2D, end: IntVector2D, forbidden_cells=None,
                    a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function,
                                    forbidden_cells)
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index c7d97803..d6e14459 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -783,23 +783,33 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
 
             # This part only works if we have keep same number of connection points for both directions
             # Also only works with two connection direction at each city
-            for boarder in range(4):
-                opposite_boarder = (boarder + 2) % 4
-                for track_id in range(len(inner_connection_points[current_city][boarder])):
-                    if track_id % 2 == 0:
-                        source = inner_connection_points[current_city][boarder][track_id]
-                        for target in inner_connection_points[current_city][opposite_boarder]:
-                            current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder)
-                            if target in all_outer_connection_points and source in \
-                                all_outer_connection_points and len(through_path_cells[current_city]) < 1:
-                                through_path_cells[current_city].extend(current_track)
-                    else:
-                        source = inner_connection_points[current_city][opposite_boarder][track_id]
-                        for target in inner_connection_points[current_city][boarder]:
-                            current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder)
-                            if target in all_outer_connection_points and source in \
-                                all_outer_connection_points and len(through_path_cells[current_city]) < 1:
-                                through_path_cells[current_city].extend(current_track)
+            for i in range(4):
+                if len(inner_connection_points[current_city][i]) > 0:
+                    boarder = i
+                    break
+
+            opposite_boarder = (boarder + 2) % 4
+            boarder_one = inner_connection_points[current_city][boarder]
+            boarder_two = inner_connection_points[current_city][opposite_boarder]
+            connect_cities(rail_trans, grid_map, boarder_one[0], boarder_one[-1])
+            connect_cities(rail_trans, grid_map, boarder_two[0], boarder_two[-1])
+
+            for track_id in range(len(inner_connection_points[current_city][boarder])):
+                if track_id % 2 == 0:
+                    source = inner_connection_points[current_city][boarder][track_id]
+                    target = inner_connection_points[current_city][opposite_boarder][track_id]
+                    current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder)
+                    if target in all_outer_connection_points and source in \
+                        all_outer_connection_points and len(through_path_cells[current_city]) < 1:
+                        through_path_cells[current_city].extend(current_track)
+                else:
+                    source = inner_connection_points[current_city][opposite_boarder][track_id]
+                    target = inner_connection_points[current_city][boarder][track_id]
+
+                    current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder)
+                    if target in all_outer_connection_points and source in \
+                        all_outer_connection_points and len(through_path_cells[current_city]) < 1:
+                        through_path_cells[current_city].extend(current_track)
 
         return through_path_cells
 
-- 
GitLab