From a9716d8a66ccb70c2f02d4862c54c955864f74ed Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 19 Aug 2019 11:31:39 -0400
Subject: [PATCH] added realistic_mode for less random levels added more
 complex intersections

---
 flatland/core/transition_map.py               |   2 +
 flatland/envs/generators.py                   | 144 ++++++++----------
 flatland/utils/graphics_pil.py                |   6 +-
 flatland/utils/rendertools.py                 |   9 +-
 ...test_flatland_env_sparse_rail_generator.py |  20 +--
 5 files changed, 80 insertions(+), 101 deletions(-)

diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 8c0bcb6d..048593d9 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -412,6 +412,8 @@ class GridTransitionMap(TransitionMap):
         grcMax = self.grid.shape
 
         # loop over available outbound directions (indices) for rcPos
+        self.set_transitions(rcPos, 0)
+
         incomping_connections = np.zeros(4)
         for iDirOut in np.arange(4):
             gdRC = gDir2dRC[iDirOut]  # row,col increment
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index cbce5062..62e0616c 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -543,7 +543,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
     return generator
 
 
-def realistic_rail_generator(nr_start_goal=1, seed=0,max_add_dead_end = 7):
+def realistic_rail_generator(nr_start_goal=1, seed=0):
     """
     Parameters
     -------
@@ -746,80 +746,33 @@ def realistic_rail_generator(nr_start_goal=1, seed=0,max_add_dead_end = 7):
                 data = []
                 for x_loop in range(int(len(x) / 2)):
                     start = (
-                        max(0, min(off_set + nbr_track_loop + 1, height - 1)),
-                        max(0, min(x[2 * x_loop], width - 1)))
+                        max(0, min(off_set + nbr_track_loop + 1, height - 1)), max(0, min(x[2 * x_loop], width - 1)))
                     goal = (
                         max(0, min(off_set + nbr_track_loop + 1, height - 1)),
                         max(0, min(x[2 * x_loop + 1], width - 1)))
-
-                    if (off_set + nbr_track_loop + 1 == start[0]) and (off_set + nbr_track_loop + 1 == goal[0]):
-                        d = np.arange(x[2 * x_loop] + 1, x[2 * x_loop + 1] - 1, 2)
-                        data.extend(d)
-
-                        new_path = connect_rail(rail_trans, rail_array, start, goal)
-                        if len(new_path) > 0:
-                            c = (off_set + nbr_track_loop, x[2 * x_loop] + 1)
-                            make_switch_e_w(width, height, grid_map, c)
-                            c = (off_set + nbr_track_loop, x[2 * x_loop + 1] + 1)
-                            make_switch_w_e(width, height, grid_map, c)
-
-                            add_pos = (int((start[0] + goal[0]) / 2), int((start[1] + goal[1]) / 2))
-                            if off_set_loop % 2 == 0:
-                                agents_positions_forward.append(add_pos)
-                                agents_directions_forward.append(([1, 3][nbr_track_loop % 2]))
-                                idx_forward.append(idx_target)
-                            else:
-                                agents_positions_backward.append(add_pos)
-                                agents_directions_backward.append(([3, 1][nbr_track_loop % 2]))
-                                idx_backward.append(idx_target)
-
-                            add_pos = (int((start[0] + goal[0]) / 2), int((2 * start[1] + goal[1]) / 3), idx_target)
-                            agents_targets.append(add_pos)
-                            idx_target += 1
-
-            # add dead-end
-            if True:
-                n = max_add_dead_end#int(np.random.choice(np.arange(max_add_dead_end-2)+1, 1)[0])
-                for pos_y in np.random.choice(np.arange(width - 7) + 3, n):
-                    pos_x = off_set
-                    pos_x1 = max(0, min(pos_x + 1, height - 1))
-                    if np.random.random() > 0.5:
-                        if pos_x + 1 < height - 1:
-                            start_track = (pos_x1, pos_y)
-                            goal_track = (pos_x1, pos_y + 1)
-                            ok = True
-                            for k in range(4):
-                                ok &= grid_map.grid[pos_x1][pos_y + (k-1) ] == 0
-                            if ok:
-                                new_path = connect_rail(rail_trans, rail_array, start_track, goal_track)
-                                if len(new_path) > 0:
-                                    c = (pos_x1 - 1, pos_y + 1)
-                                    make_switch_e_w(width, height, grid_map, c)
-                                    add_pos = goal_track  # (int((start_track[0] + goal_track[0]) / 2), int((start_track[1] + goal_track[1]) / 2))
-                                    agents_positions_forward.append(add_pos)
-                                    agents_directions_forward.append(3)
-                                    idx_forward.append(idx_target)
-                                    agents_targets.append((goal_track[0], goal_track[1], idx_target))
-                                    idx_target += 1
+                    d = np.arange(x[2 * x_loop] + 1, x[2 * x_loop + 1] - 1, 2)
+                    data.extend(d)
+
+                    new_path = connect_rail(rail_trans, rail_array, start, goal)
+                    if len(new_path) > 0:
+                        c = (off_set + nbr_track_loop, x[2 * x_loop] + 1)
+                        make_switch_e_w(width, height, grid_map, c)
+                        c = (off_set + nbr_track_loop, x[2 * x_loop + 1] + 1)
+                        make_switch_w_e(width, height, grid_map, c)
+
+                    add_pos = (int((start[0] + goal[0]) / 2), int((start[1] + goal[1]) / 2))
+                    if nbr_track_loop % 2 == 0:
+                        agents_positions_forward.append(add_pos)
+                        agents_directions_forward.append(([1, 3][off_set_loop % 2]))
+                        idx_forward.append(idx_target)
                     else:
-                        pos_x = max(0, min(pos_x + 1, height - 1))
-                        if pos_x + 1 < height - 1:
-                            start_track = (pos_x1, pos_y - 1)
-                            goal_track = (pos_x1, pos_y - 2)
-                            ok = True
-                            for k in range(4):
-                                ok &= grid_map.grid[pos_x1][pos_y - k] == 0
-                            if ok:
-                                new_path = connect_rail(rail_trans, rail_array, start_track, goal_track)
-                                if len(new_path) > 0:
-                                    c = (pos_x1 - 1, pos_y)
-                                    make_switch_w_e(width, height, grid_map, c)
-                                    add_pos = goal_track  # (int((start_track[0] + goal_track[0]) / 2), int((start_track[1] + goal_track[1]) / 2))
-                                    agents_positions_backward.append(add_pos)
-                                    agents_directions_backward.append(1)
-                                    idx_backward.append(idx_target)
-                                    agents_targets.append((goal_track[0], goal_track[1], idx_target))
-                                    idx_target += 1
+                        agents_positions_backward.append(add_pos)
+                        agents_directions_backward.append(([1, 3][off_set_loop % 2]))
+                        idx_backward.append(idx_target)
+
+                    add_pos = (int((start[0] + goal[0]) / 2), int((2 * start[1] + goal[1]) / 3), idx_target)
+                    agents_targets.append(add_pos)
+                    idx_target += 1
 
         agents_position = []
         agents_target = []
@@ -860,7 +813,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0,max_add_dead_end = 7):
 
 
 def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstations=2, min_node_dist=20, node_radius=2,
-                          num_neighb=4, realistic_mode=False, seed=0):
+                          num_neighb=4, realistic_mode=False, enhance_intersection=False, seed=0):
     '''
 
     :param nr_train_stations:
@@ -897,6 +850,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row))
             x_positions = np.linspace(2, height - 2, nodes_per_row, dtype=int)
             y_positions = np.linspace(2, width - 2, nodes_per_col, dtype=int)
+
         for node_idx in range(num_cities + num_intersections):
             to_close = True
             tries = 0
@@ -946,7 +900,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                 delete_idx = np.where(available_cities == current_node)
 
                 available_cities = np.delete(available_cities, delete_idx, 0)
-            elif len(available_intersections) > 0 and len(available_cities) > 0:
+            elif current_node >= num_cities and len(available_cities) > 0:
                 available_nodes = available_cities
                 delete_idx = np.where(available_intersections == current_node)
                 available_intersections = np.delete(available_intersections, delete_idx, 0)
@@ -1008,6 +962,37 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                 # Check if connection was made
                 if len(connection) == 0:
                     train_stations[trainstation_node].pop(-1)
+
+        # Place passing lanes at intersections
+        # We currently place them uniformly distirbuted among all cities
+        if enhance_intersection:
+
+            for intersection in range(num_intersections):
+                intersect_x_1 = np.clip(intersection_positions[intersection][0] + np.random.randint(1, 3),
+                                        1,
+                                        height - 2)
+                intersect_y_1 = np.clip(intersection_positions[intersection][1] + np.random.randint(-3, 3),
+                                        2,
+                                        width - 2)
+                intersect_x_2 = np.clip(
+                    intersection_positions[intersection][0] + np.random.randint(-3, -1),
+                    1,
+                    height - 2)
+                intersect_y_2 = np.clip(
+                    intersection_positions[intersection][1] + np.random.randint(-3, 3),
+                    1,
+                    width - 2)
+
+                # Connect train station to the correct node
+                connect_nodes(rail_trans, rail_array, (intersect_x_1, intersect_y_1),
+                              (intersect_x_2, intersect_y_2))
+                connect_nodes(rail_trans, rail_array, intersection_positions[intersection],
+                              (intersect_x_1, intersect_y_1))
+                connect_nodes(rail_trans, rail_array, intersection_positions[intersection],
+                              (intersect_x_2, intersect_y_2))
+                grid_map.fix_transitions((intersect_x_1, intersect_y_1))
+                grid_map.fix_transitions((intersect_x_2, intersect_y_2))
+
         # Fix all nodes with illegal transition maps
         for current_node in node_positions:
             grid_map.fix_transitions(current_node)
@@ -1052,17 +1037,22 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             current_start_node = agent_start_targets_nodes[agent_idx][0]
             current_target_node = agent_start_targets_nodes[agent_idx][1]
             target_station_idx = np.random.randint(len(train_stations[current_target_node]))
-            start_station_idx = np.random.randint(len(train_stations[current_start_node]))
             target = train_stations[current_target_node][target_station_idx]
-            start = train_stations[current_start_node][start_station_idx]
             while (target[0], target[1]) in agents_target:
                 target_station_idx = np.random.randint(len(train_stations[current_target_node]))
                 target = train_stations[current_target_node][target_station_idx]
-            while (start[0], start[1]) in agents_position:
+            agents_target.append((target[0], target[1]))
+
+        for agent_idx in range(num_agents):
+            current_start_node = agent_start_targets_nodes[agent_idx][0]
+            current_target_node = agent_start_targets_nodes[agent_idx][1]
+            start_station_idx = np.random.randint(len(train_stations[current_start_node]))
+            start = train_stations[current_start_node][start_station_idx]
+
+            while (start[0], start[1]) in agents_position or (start[0], start[1]) in agents_target:
                 start_station_idx = np.random.randint(len(train_stations[current_start_node]))
                 start = train_stations[current_start_node][start_station_idx]
             agents_position.append((start[0], start[1]))
-            agents_target.append((target[0], target[1]))
 
             # Orient the agent correctly
             for orientation in range(4):
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 6333909e..47ff8e6c 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -41,7 +41,7 @@ class PILGL(GraphicsLayer):
     SELECTED_AGENT_LAYER = 4
     SELECTED_TARGET_LAYER = 5
 
-    def __init__(self, width, height, jupyter=False, screen_width=800,screen_height=600):
+    def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600):
         self.yxBase = (0, 0)
         self.linewidth = 4
         self.n_agent_colors = 1  # overridden in loadAgent
@@ -263,9 +263,9 @@ class PILGL(GraphicsLayer):
 
 
 class PILSVG(PILGL):
-    def __init__(self, width, height, jupyter=False, screen_width=800,screen_height=600):
+    def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600):
         oSuper = super()
-        oSuper.__init__(width, height, jupyter,screen_width,screen_height)
+        oSuper.__init__(width, height, jupyter, screen_width, screen_height)
 
         self.lwAgents = []
         self.agents_prev = []
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 00c1e1b9..5118af75 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -39,7 +39,8 @@ class RenderTool(object):
     theta = np.linspace(0, np.pi / 2, 5)
     arc = array([np.cos(theta), np.sin(theta)]).T  # from [1,0] to [0,1]
 
-    def __init__(self, env, gl="PILSVG", jupyter=False, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, screen_width=800,screen_height=600):
+    def __init__(self, env, gl="PILSVG", jupyter=False, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                 screen_width=800, screen_height=600):
         self.env = env
         self.frame_nr = 0
         self.start_time = time.time()
@@ -48,12 +49,12 @@ class RenderTool(object):
         self.agent_render_variant = agent_render_variant
 
         if gl == "PIL":
-            self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width,screen_height=screen_height)
+            self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
         elif gl == "PILSVG":
-            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width,screen_height=screen_height)
+            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
         else:
             print("[", gl, "] not found, switch to PILSVG")
-            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width,screen_height=screen_height)
+            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
 
         self.new_rail = True
         self.update_background()
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index 46675e3a..86e02c75 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -1,6 +1,3 @@
-import os
-import time
-
 import numpy as np
 
 from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator
@@ -22,15 +19,6 @@ def test_realistic_rail_generator(vizualization_folder_name=None):
                                   screen_width=1600)
         env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
 
-        if vizualization_folder_name is not None:
-            env_renderer.gl.save_image(
-                os.path.join(
-                    vizualization_folder_name,
-                    "flatland_frame_{:04d}.png".format(test_loop)
-                ))
-        env_renderer.close_window()
-
-
 def test_sparse_rail_generator():
     env = RailEnv(width=50,
                   height=50,
@@ -39,15 +27,13 @@ def test_sparse_rail_generator():
                                                        num_trainstations=50,  # Number of possible start/targets on map
                                                        min_node_dist=6,  # Minimal distance of nodes
                                                        node_radius=3,  # Proximity of stations to city center
-                                                       num_neighb=4,  # Number of connections to other cities
+                                                       num_neighb=3,  # Number of connections to other cities
                                                        seed=5,  # Random seed
+                                                       realistic_mode=True  # Ordered distribution of nodes
                                                        ),
-                  number_of_agents=45,
+                  number_of_agents=10,
                   obs_builder_object=GlobalObsForRailEnv())
     # reset to initialize agents_static
     env_renderer = RenderTool(env, gl="PILSVG", )
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
-    time.sleep(2)
-
 
-test_realistic_rail_generator(vizualization_folder_name="./rendering")
-- 
GitLab