From 0b4c3f901d233abb5bc2d7d83c221cd2a4736d43 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 17 Sep 2019 22:10:33 +0200
Subject: [PATCH] v0.2 realistic generator

---
 .../Simple_Realistic_Railway_Generator.py     | 30 +++++++++++++++----
 1 file changed, 25 insertions(+), 5 deletions(-)

diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py
index ec47bdc5..082e9b67 100644
--- a/examples/Simple_Realistic_Railway_Generator.py
+++ b/examples/Simple_Realistic_Railway_Generator.py
@@ -12,7 +12,7 @@ from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
 from flatland.envs.schedule_generators import sparse_schedule_generator
-from flatland.utils.rendertools import AgentRenderVariant, RenderTool
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 
 FloatArrayType = []
 
@@ -124,8 +124,9 @@ def realistic_rail_generator(num_cities=5,
                     end_nodes_added[city_loop].append(end_node)
 
                     # place in the center of path a station slot
-                    #station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))])
-                    station_slots[city_loop].extend(connection)
+                    # station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))])
+                    for c_loop in range(len(connection)):
+                        station_slots[city_loop].append(connection[c_loop])
                     station_slots_cnt += len(connection)
 
                     station_tracks[city_loop][track_id] = connection
@@ -139,7 +140,8 @@ def realistic_rail_generator(num_cities=5,
 
         return nodes_added, station_slots, start_nodes_added, end_nodes_added, station_tracks
 
-    def create_switches_at_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
+    def create_switches_at_stations(rail_trans: RailEnvTransitions,
+                                    grid_map: GridTransitionMap,
                                     station_tracks: IntVector2DArrayType,
                                     nodes_added: IntVector2DArrayType,
                                     intern_nbr_of_switches_per_station_track: int) -> IntVector2DArrayType:
@@ -361,6 +363,20 @@ def realistic_rail_generator(num_cities=5,
                     if print_out_info:
                         print("connect_random_stations : connect_nodes -> no path found")
 
+    def remove_switch_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
+                               train_stations: IntVector2DArrayType):
+        tmp_train_stations = copy.deepcopy(train_stations)
+        for city_loop in range(len(train_stations)):
+            for n in tmp_train_stations[city_loop]:
+                do_remove = True
+                trans = rail_trans.transition_list[1]
+                for _ in range(4):
+                    trans = rail_trans.rotate_transition(trans, rotation=90)
+                    if grid_map.grid[n] == trans:
+                        do_remove = False
+                if do_remove:
+                    train_stations[city_loop].remove(n)
+
     def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
@@ -435,6 +451,10 @@ def realistic_rail_generator(num_cities=5,
         for i in range(len(nodes_added)):
             grid_map.fix_transitions(nodes_added[i])
 
+        # ----------------------------------------------------------------------------------
+        # remove stations where rail is a switch
+        remove_switch_stations(rail_trans, grid_map, train_stations)
+
         # ----------------------------------------------------------------------------------
         # Slot availability in node
         node_available_start = []
@@ -481,7 +501,7 @@ def realistic_rail_generator(num_cities=5,
 
 
 if os.path.exists("./../render_output/"):
-    for itrials in np.arange(1,1000,1):
+    for itrials in np.arange(1, 1000, 1):
         print(itrials, "generate new city")
         np.random.seed(itrials)
         env = RailEnv(width=40 + np.random.choice(100),
-- 
GitLab