From 79a1e07677451b9839113129f92a28b00ee2bf7c Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 20 Aug 2019 14:57:11 +0200
Subject: [PATCH] realistic generator

---
 flatland/envs/generators.py                      | 16 ++++++++++------
 tests/test_flatland_env_sparse_rail_generator.py |  8 +++++---
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 4589633e..d21c6eb1 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -543,7 +543,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
     return generator
 
 
-def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_track_back_bone=True):
+def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=4, two_track_back_bone=True):
     """
     Parameters
     -------
@@ -670,7 +670,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t
 
         np.random.seed(seed + num_resets)
 
-        max_n_track_seg = np.random.choice([3, 4, 5]) + int(two_track_back_bone)
+        max_n_track_seg = np.random.choice(np.arange(3, int(height / 2))) + int(two_track_back_bone)
         x_offsets = np.arange(0, height, max_n_track_seg).astype(int)
 
         agents_positions = []
@@ -680,8 +680,8 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t
         for off_set_loop in range(len(x_offsets)):
             off_set = x_offsets[off_set_loop]
             # second track
-            data = np.arange(3, width - 4, 3)
-            n_track_seg = np.random.choice(max_n_track_seg) + 1
+            data = np.arange(2, width - 2)
+            n_track_seg = np.random.choice([1,2])
 
             track_2 = False
             if two_track_back_bone:
@@ -817,6 +817,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t
                                   0 + int(two_track_back_bone_loop)))
 
             for nbr_track_loop in range(max_n_track_seg - 1):
+                n_track_seg = 1
                 if len(data) < 2 * n_track_seg + 1:
                     break
                 x = np.sort(np.random.choice(data, 2 * n_track_seg, False)).astype(int)
@@ -827,7 +828,8 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t
                     goal = (
                         max(0, min(off_set + nbr_track_loop + 1, height - 1)),
                         max(0, min(x[2 * x_loop + 1], width - 1)))
-                    d = np.arange(x[2 * x_loop] + 1, x[2 * x_loop + 1] - 1, 2)
+
+                    d = np.arange(x[2 * x_loop] + 1, x[2 * x_loop + 1] - 1)
                     data.extend(d)
 
                     new_path = connect_rail(rail_trans, rail_array, start, goal)
@@ -843,6 +845,8 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t
                     add_pos = (int((start[0] + goal[0]) / 2), int((2 * start[1] + goal[1]) / 3))
                     agents_targets.append(add_pos)
 
+        for off_set_loop in range(len(x_offsets)):
+            off_set = x_offsets[off_set_loop]
             pos_ys = np.random.choice(np.arange(width - 7) + 3, min(width - 7, add_max_dead_end), False)
             for pos_y in pos_ys:
                 pos_x = off_set + 1 + int(two_track_back_bone)
@@ -855,7 +859,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t
                         c = (pos_x, pos_y - k + 1)
                         ok &= grid_map.grid[c[0]][c[1]] == 0
                     if ok:
-                        if np.random.random() < 0.95:
+                        if np.random.random() < 0.5:
                             start_track = (pos_x, pos_y)
                             goal_track = (pos_x, pos_y - 2)
                             new_path = connect_rail(rail_trans, rail_array, start_track, goal_track)
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index b3ee348c..710ae6d5 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -7,14 +7,16 @@ from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 
+
 def test_realistic_rail_generator(vizualization_folder_name=None):
     for test_loop in range(50):
         print("test_loop", test_loop)
         num_agents = np.random.randint(10, 30)
         env = RailEnv(width=np.random.randint(40, 80),
                       height=np.random.randint(10, 20),
-                      rail_generator=realistic_rail_generator(nr_start_goal=num_agents + 1, seed=test_loop,
-                                                              add_max_dead_end=20,
+                      rail_generator=realistic_rail_generator(nr_start_goal=num_agents + 1,
+                                                              seed=test_loop,
+                                                              add_max_dead_end=4,
                                                               two_track_back_bone=test_loop % 2 == 0),
                       number_of_agents=num_agents,
                       obs_builder_object=GlobalObsForRailEnv())
@@ -51,4 +53,4 @@ def test_sparse_rail_generator():
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
     env_renderer.close_window()
 
-#test_realistic_rail_generator("./../rendering/")
+test_realistic_rail_generator("./../rendering/")
-- 
GitLab