From 9ed202f09ba72acb8706c39b2f505a2a855cfcb6 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Mon, 19 Aug 2019 08:58:00 +0200
Subject: [PATCH] realitic rail generator finish and polished

---
 flatland/envs/generators.py                   | 95 ++++++++++++++-----
 flatland/utils/graphics_pil.py                | 10 +-
 flatland/utils/rendertools.py                 |  8 +-
 ...test_flatland_env_sparse_rail_generator.py | 18 +++-
 4 files changed, 95 insertions(+), 36 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 2826ee51..ac9b0c22 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -543,7 +543,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
     return generator
 
 
-def realistic_rail_generator(nr_start_goal=1, seed=0):
+def realistic_rail_generator(nr_start_goal=1, seed=0,max_add_dead_end = 3):
     """
     Parameters
     -------
@@ -746,33 +746,80 @@ def realistic_rail_generator(nr_start_goal=1, seed=0):
                 data = []
                 for x_loop in range(int(len(x) / 2)):
                     start = (
-                        max(0, min(off_set + nbr_track_loop + 1, height - 1)), max(0, min(x[2 * x_loop], width - 1)))
+                        max(0, min(off_set + nbr_track_loop + 1, height - 1)),
+                        max(0, min(x[2 * x_loop], width - 1)))
                     goal = (
                         max(0, min(off_set + nbr_track_loop + 1, height - 1)),
                         max(0, min(x[2 * x_loop + 1], width - 1)))
-                    d = np.arange(x[2 * x_loop] + 1, x[2 * x_loop + 1] - 1, 2)
-                    data.extend(d)
-
-                    new_path = connect_rail(rail_trans, rail_array, start, goal)
-                    if len(new_path) > 0:
-                        c = (off_set + nbr_track_loop, x[2 * x_loop] + 1)
-                        make_switch_e_w(width, height, grid_map, c)
-                        c = (off_set + nbr_track_loop, x[2 * x_loop + 1] + 1)
-                        make_switch_w_e(width, height, grid_map, c)
-
-                    add_pos = (int((start[0] + goal[0]) / 2), int((start[1] + goal[1]) / 2))
-                    if nbr_track_loop % 2 == 0:
-                        agents_positions_forward.append(add_pos)
-                        agents_directions_forward.append(([1, 3][off_set_loop % 2]))
-                        idx_forward.append(idx_target)
-                    else:
-                        agents_positions_backward.append(add_pos)
-                        agents_directions_backward.append(([1, 3][off_set_loop % 2]))
-                        idx_backward.append(idx_target)
 
-                    add_pos = (int((start[0] + goal[0]) / 2), int((2 * start[1] + goal[1]) / 3), idx_target)
-                    agents_targets.append(add_pos)
-                    idx_target += 1
+                    if (off_set + nbr_track_loop + 1 == start[0]) and (off_set + nbr_track_loop + 1 == goal[0]):
+                        d = np.arange(x[2 * x_loop] + 1, x[2 * x_loop + 1] - 1, 2)
+                        data.extend(d)
+
+                        new_path = connect_rail(rail_trans, rail_array, start, goal)
+                        if len(new_path) > 0:
+                            c = (off_set + nbr_track_loop, x[2 * x_loop] + 1)
+                            make_switch_e_w(width, height, grid_map, c)
+                            c = (off_set + nbr_track_loop, x[2 * x_loop + 1] + 1)
+                            make_switch_w_e(width, height, grid_map, c)
+
+                            add_pos = (int((start[0] + goal[0]) / 2), int((start[1] + goal[1]) / 2))
+                            if off_set_loop % 2 == 0:
+                                agents_positions_forward.append(add_pos)
+                                agents_directions_forward.append(([1, 3][nbr_track_loop % 2]))
+                                idx_forward.append(idx_target)
+                            else:
+                                agents_positions_backward.append(add_pos)
+                                agents_directions_backward.append(([3, 1][nbr_track_loop % 2]))
+                                idx_backward.append(idx_target)
+
+                            add_pos = (int((start[0] + goal[0]) / 2), int((2 * start[1] + goal[1]) / 3), idx_target)
+                            agents_targets.append(add_pos)
+                            idx_target += 1
+
+            # add dead-end
+            if True:
+                n = int(np.random.choice(np.arange(max_add_dead_end), 1)[0])
+                for pos_y in np.random.choice(np.arange(width - 7) + 3, n):
+                    pos_x = off_set
+                    pos_x1 = max(0, min(pos_x + 1, height - 1))
+                    if np.random.random() > 0.5:
+                        if pos_x + 1 < height - 1:
+                            start_track = (pos_x1, pos_y)
+                            goal_track = (pos_x1, pos_y + 1)
+                            ok = True
+                            for k in range(4):
+                                ok &= grid_map.grid[pos_x1][pos_y + (k-1) ] == 0
+                            if ok:
+                                new_path = connect_rail(rail_trans, rail_array, start_track, goal_track)
+                                if len(new_path) > 0:
+                                    c = (pos_x1 - 1, pos_y + 1)
+                                    make_switch_e_w(width, height, grid_map, c)
+                                    add_pos = goal_track  # (int((start_track[0] + goal_track[0]) / 2), int((start_track[1] + goal_track[1]) / 2))
+                                    agents_positions_forward.append(add_pos)
+                                    agents_directions_forward.append(3)
+                                    idx_forward.append(idx_target)
+                                    agents_targets.append((goal_track[0], goal_track[1], idx_target))
+                                    idx_target += 1
+                    else:
+                        pos_x = max(0, min(pos_x + 1, height - 1))
+                        if pos_x + 1 < height - 1:
+                            start_track = (pos_x1, pos_y - 1)
+                            goal_track = (pos_x1, pos_y - 2)
+                            ok = True
+                            for k in range(4):
+                                ok &= grid_map.grid[pos_x1][pos_y - k] == 0
+                            if ok:
+                                new_path = connect_rail(rail_trans, rail_array, start_track, goal_track)
+                                if len(new_path) > 0:
+                                    c = (pos_x1 - 1, pos_y)
+                                    make_switch_w_e(width, height, grid_map, c)
+                                    add_pos = goal_track  # (int((start_track[0] + goal_track[0]) / 2), int((start_track[1] + goal_track[1]) / 2))
+                                    agents_positions_backward.append(add_pos)
+                                    agents_directions_backward.append(1)
+                                    idx_backward.append(idx_target)
+                                    agents_targets.append((goal_track[0], goal_track[1], idx_target))
+                                    idx_target += 1
 
         agents_position = []
         agents_target = []
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index d8dbbfff..6333909e 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -41,7 +41,7 @@ class PILGL(GraphicsLayer):
     SELECTED_AGENT_LAYER = 4
     SELECTED_TARGET_LAYER = 5
 
-    def __init__(self, width, height, jupyter=False):
+    def __init__(self, width, height, jupyter=False, screen_width=800,screen_height=600):
         self.yxBase = (0, 0)
         self.linewidth = 4
         self.n_agent_colors = 1  # overridden in loadAgent
@@ -57,8 +57,8 @@ class PILGL(GraphicsLayer):
             #       way to compute the screen width and height
             #       In the meantime, we are harcoding the 800x600
             #       assumption
-            self.screen_width = 800
-            self.screen_height = 600
+            self.screen_width = screen_width
+            self.screen_height = screen_height
             w = (self.screen_width - self.width - 10) / (self.width + 1 + self.linewidth)
             h = (self.screen_height - self.height - 10) / (self.height + 1 + self.linewidth)
             self.nPixCell = int(max(1, np.ceil(min(w, h))))
@@ -263,9 +263,9 @@ class PILGL(GraphicsLayer):
 
 
 class PILSVG(PILGL):
-    def __init__(self, width, height, jupyter=False):
+    def __init__(self, width, height, jupyter=False, screen_width=800,screen_height=600):
         oSuper = super()
-        oSuper.__init__(width, height, jupyter)
+        oSuper.__init__(width, height, jupyter,screen_width,screen_height)
 
         self.lwAgents = []
         self.agents_prev = []
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 8974126a..00c1e1b9 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -39,7 +39,7 @@ class RenderTool(object):
     theta = np.linspace(0, np.pi / 2, 5)
     arc = array([np.cos(theta), np.sin(theta)]).T  # from [1,0] to [0,1]
 
-    def __init__(self, env, gl="PILSVG", jupyter=False, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND):
+    def __init__(self, env, gl="PILSVG", jupyter=False, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, screen_width=800,screen_height=600):
         self.env = env
         self.frame_nr = 0
         self.start_time = time.time()
@@ -48,12 +48,12 @@ class RenderTool(object):
         self.agent_render_variant = agent_render_variant
 
         if gl == "PIL":
-            self.gl = PILGL(env.width, env.height, jupyter)
+            self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width,screen_height=screen_height)
         elif gl == "PILSVG":
-            self.gl = PILSVG(env.width, env.height, jupyter)
+            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width,screen_height=screen_height)
         else:
             print("[", gl, "] not found, switch to PILSVG")
-            self.gl = PILSVG(env.width, env.height, jupyter)
+            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width,screen_height=screen_height)
 
         self.new_rail = True
         self.update_background()
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index 92744080..87b5ab4a 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -1,3 +1,4 @@
+import os
 import time
 
 import numpy as np
@@ -5,10 +6,10 @@ import numpy as np
 from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
-from flatland.utils.rendertools import RenderTool
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 
 
-def test_realistic_rail_generator():
+def test_realistic_rail_generator(vizualization_folder_name=None):
     for test_loop in range(20):
         num_agents = np.random.randint(10, 30)
         env = RailEnv(width=np.random.randint(40, 80),
@@ -17,8 +18,16 @@ def test_realistic_rail_generator():
                       number_of_agents=num_agents,
                       obs_builder_object=GlobalObsForRailEnv())
         # reset to initialize agents_static
-        env_renderer = RenderTool(env, gl="PILSVG", )
+        env_renderer = RenderTool(env, gl="PILSVG", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, screen_height=1200,
+                                  screen_width=1600)
         env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
+
+        if vizualization_folder_name is not None:
+            env_renderer.gl.save_image(
+                os.path.join(
+                    vizualization_folder_name,
+                    "flatland_frame_{:04d}.png".format(test_loop)
+                ))
         env_renderer.close_window()
 
 
@@ -39,3 +48,6 @@ def test_sparse_rail_generator():
     env_renderer = RenderTool(env, gl="PILSVG", )
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
     time.sleep(2)
+
+
+test_realistic_rail_generator(vizualization_folder_name="./rendering")
-- 
GitLab