From 3ce07cf2782661cd0c511bea8a398bac49c00e08 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 17 Sep 2019 15:38:56 +0200
Subject: [PATCH] update

---
 .../Simple_Realistic_Railway_Generator.py     | 89 +++++++++++++------
 flatland/core/grid/grid4_astar.py             | 22 ++++-
 2 files changed, 82 insertions(+), 29 deletions(-)

diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py
index 09f3b538..c9c3bdc6 100644
--- a/examples/Simple_Realistic_Railway_Generator.py
+++ b/examples/Simple_Realistic_Railway_Generator.py
@@ -1,5 +1,6 @@
 import copy
 import os
+import time
 import warnings
 
 import numpy as np
@@ -138,25 +139,58 @@ def realistic_rail_generator(num_cities=5,
 
     def create_switches_at_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                                     station_tracks: IntVector2DArrayType,
-                                    nodes_added: IntVector2DArrayType) -> IntVector2DArrayType:
-        for city_loop in range(len(station_tracks)):
-            datas = station_tracks[city_loop]
-            if len(datas) > 1:
-                a = datas[0]
-                if len(a) > 0:
-                    start_node = a[np.random.choice(len(a) - 2) + 1]
-                    for i in np.arange(1, len(datas)):
-                        b = datas[i]
-                        if len(b) > 2:
-                            x = np.random.choice(len(b) - 2) + 1
-                            end_node = b[x]
-                            connection = connect_rail(rail_trans, grid_map, start_node, end_node)
-                            if len(connection) == 0:
-                                if print_out_info:
-                                    print("create_switches_at_stations : connect_rail -> no path found")
-                            nodes_added.append(start_node)
-                            nodes_added.append(end_node)
-                            start_node = b[np.random.choice(len(b) - 2) + 1]
+                                    nodes_added: IntVector2DArrayType,
+                                    intern_nbr_of_switches_per_station_track: int) -> IntVector2DArrayType:
+
+        for k in range(intern_nbr_of_switches_per_station_track):
+            for city_loop in range(len(station_tracks)):
+                datas = station_tracks[city_loop]
+                if len(datas) > 1:
+                    track = datas[0]
+                    if len(track) > 3:
+                        if k % 2 == 0:
+                            x = 1
+                        else:
+                            x = len(track) - 2
+
+                        start_node = track[x]
+                        for i in np.arange(1, len(datas)):
+                            track = datas[i]
+                            if len(track) > 3:
+                                if k % 2 == 0:
+                                    x = x + 2
+                                    if len(track) <= x:
+                                        x = 1
+                                else:
+                                    x = x - 2
+                                    if x < 2:
+                                        x = len(track) - 2
+                                end_node = track[x]
+                                connection = connect_rail(rail_trans, grid_map, start_node, end_node)
+                                print(start_node, end_node, "-->", connection)
+                                if len(connection) == 0:
+                                    if print_out_info:
+                                        print("create_switches_at_stations : connect_rail -> no path found")
+                                        if len(datas[i-1])>0:
+                                            start_node = datas[i-1][0]
+                                        end_node = datas[i][0]
+                                        connection = connect_rail(rail_trans, grid_map, start_node, end_node)
+
+
+                                nodes_added.append(start_node)
+                                nodes_added.append(end_node)
+
+
+
+                                if k % 2 == 0:
+                                    x = x + 2
+                                    if len(track) <= x:
+                                        x = 1
+                                else:
+                                    x = x - 2
+                                    if x < 2:
+                                        x = len(track) - 2
+                                start_node = track[x]
 
         return nodes_added
 
@@ -382,12 +416,13 @@ def realistic_rail_generator(num_cities=5,
         # build switches
         # TODO remove true/false block
         if True:
-            create_switches_at_stations(rail_trans, grid_map, station_tracks, nodes_added)
+            create_switches_at_stations(rail_trans, grid_map, station_tracks, nodes_added,
+                                        intern_nbr_of_switches_per_station_track)
 
         # ----------------------------------------------------------------------------------
         # connect stations
         # TODO remove true/false block
-        if True:
+        if False:
             if do_random_connect_stations:
                 connect_random_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added,
                                         intern_connect_max_nbr_of_shortes_city)
@@ -452,19 +487,19 @@ if os.path.exists("./../render_output/"):
         np.random.seed(itrials)
         env = RailEnv(width=40 + np.random.choice(100),
                       height=40 + np.random.choice(100),
-                      rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10),
+                      rail_generator=realistic_rail_generator(num_cities=1000,
                                                               city_size=10 + np.random.choice(10),
-                                                              allowed_rotation_angles=[0],
-                                                              max_number_of_station_tracks=np.random.choice(4) + 4,
-                                                              nbr_of_switches_per_station_track=np.random.choice(4) + 2,
-                                                              connect_max_nbr_of_shortes_city=2,
+                                                              allowed_rotation_angles=np.arange(-180, 180, 15),
+                                                              max_number_of_station_tracks=1 + np.random.choice(4),
+                                                              nbr_of_switches_per_station_track=2,
+                                                              connect_max_nbr_of_shortes_city=2 + np.random.choice(4),
                                                               do_random_connect_stations=False,
                                                               # Number of cities in map
                                                               seed=itrials,  # Random seed
                                                               print_out_info=True
                                                               ),
                       schedule_generator=sparse_schedule_generator(),
-                      number_of_agents=100,
+                      number_of_agents=0,
                       obs_builder_object=GlobalObsForRailEnv())
 
         # reset to initialize agents_static
diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py
index 779ee984..91fec9a9 100644
--- a/flatland/core/grid/grid4_astar.py
+++ b/flatland/core/grid/grid4_astar.py
@@ -1,3 +1,6 @@
+import numpy as np
+from matplotlib import pyplot as plt
+
 from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.grid.grid_utils import IntVector2DArrayType
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
@@ -31,12 +34,16 @@ class AStarNode:
 
 def a_star(rail_trans: RailEnvTransitions,
            grid_map: GridTransitionMap,
-           start: IntVector2D, end: IntVector2D) -> IntVector2DArrayType:
+           start: IntVector2D, end: IntVector2D,
+           a_star_distance_function=Vec2d.get_manhattan_distance) -> IntVector2DArrayType:
     """
     Returns a list of tuples as a path from the given start to end.
     If no path is found, returns path to closest point to end.
     """
     rail_shape = grid_map.grid.shape
+
+    tmp = np.zeros(rail_shape) - 10
+
     start_node = AStarNode(None, start)
     end_node = AStarNode(None, end)
     open_nodes = set()
@@ -64,6 +71,14 @@ def a_star(rail_trans: RailEnvTransitions,
             while current is not None:
                 path.append(current.pos)
                 current = current.parent
+
+            if False:
+                plt.ion()
+                plt.clf()
+                plt.imshow(tmp, interpolation='nearest')
+                plt.draw()
+                plt.pause(1e-17)
+
             # return reversed path
             return path[::-1]
 
@@ -73,6 +88,7 @@ def a_star(rail_trans: RailEnvTransitions,
             prev_pos = current_node.parent.pos
         else:
             prev_pos = None
+
         for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
             # update the "current" pos
             node_pos = Vec2d.add(current_node.pos, new_pos)
@@ -98,9 +114,11 @@ def a_star(rail_trans: RailEnvTransitions,
             # create the f, g, and h values
             child.g = current_node.g + 1.0
             # this heuristic avoids diagonal paths
-            child.h = Vec2d.get_manhattan_distance(child.pos, end_node.pos)
+            child.h = a_star_distance_function(child.pos, end_node.pos)
             child.f = child.g + child.h
 
+            tmp[child.pos[0]][child.pos[1]] = child.f
+
             # already in the open list?
             if child in open_nodes:
                 continue
-- 
GitLab