From 853dd42c29e9ed6e646a4f14536dda85d9b58940 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 20 Aug 2019 09:43:30 +0200
Subject: [PATCH] realistic generator supports single or two track backbone

---
 flatland/envs/generators.py                    | 18 ++++++++++++++++--
 .../test_flatland_env_sparse_rail_generator.py |  2 ++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 1dedec19..24671d79 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -705,6 +705,20 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t
             goal_track = (off_set, width - 1 - int((off_set_loop) % 2) * int(two_track_back_bone) * int(track_2))
             new_path = connect_rail(rail_trans, rail_array, start_track, goal_track)
 
+            if track_2:
+                if np.random.random() < 0.75:
+                    c = (off_set, 3)
+                    if np.random.random() < 0.5:
+                        make_switch_e_w(width, height, grid_map, c)
+                    else:
+                        make_switch_w_e(width, height, grid_map, c)
+                if np.random.random() < 0.5:
+                    c = (off_set, width - 3)
+                    if np.random.random() < 0.5:
+                        make_switch_e_w(width, height, grid_map, c)
+                    else:
+                        make_switch_w_e(width, height, grid_map, c)
+
             # track one (full track : left right)
             for two_track_back_bone_loop in range(1 + int(track_2) * int(two_track_back_bone)):
                 if off_set_loop > 0:
@@ -832,7 +846,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t
                     agents_targets.append(add_pos)
                     idx_target += 1
 
-            for pos_y in np.random.choice(np.arange(width - 7) + 3, add_max_dead_end, False):
+            for pos_y in np.random.choice(np.arange(width - 7) + 3, min(width - 7, add_max_dead_end), False):
                 pos_x = off_set + 1 + int(two_track_back_bone)
                 if pos_x < height - 1:
                     ok = True
@@ -1054,7 +1068,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                                     width - 1)
                 while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[
                     trainstation_node] or \
-                        rail_array[(station_x, station_y)] != 0:
+                    rail_array[(station_x, station_y)] != 0:
                     station_x = np.clip(
                         node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
                         0,
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index eac83381..3ac0636e 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -50,3 +50,5 @@ def test_sparse_rail_generator():
     # reset to initialize agents_static
     env_renderer = RenderTool(env, gl="PILSVG", )
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
+
+test_realistic_rail_generator()
-- 
GitLab