From 7848676a406fc5be3759b3e13f69b3ef7d7489f5 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 6 Oct 2019 10:09:00 -0400
Subject: [PATCH] fixed train station placement in cities to utilize all tracks

---
 examples/flatland_2_0_example.py     | 10 +++++-----
 flatland/envs/rail_generators.py     |  8 +++++---
 flatland/envs/schedule_generators.py |  2 +-
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index a5a5946b..3ad3a7ec 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -30,17 +30,17 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
                     1. / 3.: 0.25,  # Slow commuter train
                     1. / 4.: 0.25}  # Slow freight train
 
-env = RailEnv(width=50,
-              height=50,
-              rail_generator=sparse_rail_generator(max_num_cities=10,
+env = RailEnv(width=100,
+              height=100,
+              rail_generator=sparse_rail_generator(max_num_cities=20,
                                                    # Number of cities in map (where train stations are)
                                                    seed=1,  # Random seed
                                                    grid_mode=False,
                                                    max_rails_between_cities=3,
-                                                   max_rails_in_city=2,
+                                                   max_rails_in_city=8,
                                                    ),
               schedule_generator=sparse_schedule_generator(speed_ration_map),
-              number_of_agents=10,
+              number_of_agents=100,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=TreeObservation,
               remove_agents_at_target=True
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 2f974ae9..6ec01616 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -822,22 +822,24 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                 source = inner_connection_points[current_city][boarder][track_id]
                 target = inner_connection_points[current_city][opposite_boarder][track_id]
                 current_track = connect_straight_line_in_grid_map(grid_map, source, target, rail_trans)
-
+                free_rails[current_city].append(current_track)
             for track_id in range(nr_of_connection_points):
                 source = inner_connection_points[current_city][boarder][track_id]
                 target = inner_connection_points[current_city][opposite_boarder][track_id]
+
+                # Connect parallel tracks with each other
                 fix_inner_nodes(
                     grid_map, source, rail_trans)
                 fix_inner_nodes(
                     grid_map, target, rail_trans)
+
+                # Connect outer tracks to inner tracks
                 if start_idx <= track_id < start_idx + number_of_out_rails:
                     source_outer = outer_connection_points[current_city][boarder][track_id - start_idx]
                     target_outer = outer_connection_points[current_city][opposite_boarder][track_id - start_idx]
                     connect_straight_line_in_grid_map(grid_map, source, source_outer, rail_trans)
                     connect_straight_line_in_grid_map(grid_map, target, target_outer, rail_trans)
 
-
-                free_rails[current_city].append(current_track)
         return free_rails
 
     def _set_trainstation_positions(city_positions: IntVector2DArray, city_radius: int,
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index 5cd44cf0..07f10cec 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -78,11 +78,11 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
             target_city = agent_start_targets_cities[agent_idx][1]
             start = random.choice(train_stations[start_city])
             target = random.choice(train_stations[target_city])
+
             while start[1] % 2 != 0:
                 start = random.choice(train_stations[start_city])
             while target[1] % 2 != 1:
                 target = random.choice(train_stations[target_city])
-
             agent_orientation = (agent_start_targets_cities[agent_idx][2] + 2 * start[1]) % 4
             if not rail.check_path_exists(start[0], agent_orientation, target[0]):
                 agent_orientation = (agent_orientation + 2) % 4
-- 
GitLab