From 3ce07cf2782661cd0c511bea8a398bac49c00e08 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 17 Sep 2019 15:38:56 +0200 Subject: [PATCH] update --- .../Simple_Realistic_Railway_Generator.py | 89 +++++++++++++------ flatland/core/grid/grid4_astar.py | 22 ++++- 2 files changed, 82 insertions(+), 29 deletions(-) diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index 09f3b538..c9c3bdc6 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -1,5 +1,6 @@ import copy import os +import time import warnings import numpy as np @@ -138,25 +139,58 @@ def realistic_rail_generator(num_cities=5, def create_switches_at_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, station_tracks: IntVector2DArrayType, - nodes_added: IntVector2DArrayType) -> IntVector2DArrayType: - for city_loop in range(len(station_tracks)): - datas = station_tracks[city_loop] - if len(datas) > 1: - a = datas[0] - if len(a) > 0: - start_node = a[np.random.choice(len(a) - 2) + 1] - for i in np.arange(1, len(datas)): - b = datas[i] - if len(b) > 2: - x = np.random.choice(len(b) - 2) + 1 - end_node = b[x] - connection = connect_rail(rail_trans, grid_map, start_node, end_node) - if len(connection) == 0: - if print_out_info: - print("create_switches_at_stations : connect_rail -> no path found") - nodes_added.append(start_node) - nodes_added.append(end_node) - start_node = b[np.random.choice(len(b) - 2) + 1] + nodes_added: IntVector2DArrayType, + intern_nbr_of_switches_per_station_track: int) -> IntVector2DArrayType: + + for k in range(intern_nbr_of_switches_per_station_track): + for city_loop in range(len(station_tracks)): + datas = station_tracks[city_loop] + if len(datas) > 1: + track = datas[0] + if len(track) > 3: + if k % 2 == 0: + x = 1 + else: + x = len(track) - 2 + + start_node = track[x] + for i in np.arange(1, len(datas)): + track = datas[i] + if len(track) > 3: + if k % 2 == 0: + x = x + 2 + if len(track) <= x: + x = 1 + else: + x = x - 2 + if x < 2: + x = len(track) - 2 + end_node = track[x] + connection = connect_rail(rail_trans, grid_map, start_node, end_node) + print(start_node, end_node, "-->", connection) + if len(connection) == 0: + if print_out_info: + print("create_switches_at_stations : connect_rail -> no path found") + if len(datas[i-1])>0: + start_node = datas[i-1][0] + end_node = datas[i][0] + connection = connect_rail(rail_trans, grid_map, start_node, end_node) + + + nodes_added.append(start_node) + nodes_added.append(end_node) + + + + if k % 2 == 0: + x = x + 2 + if len(track) <= x: + x = 1 + else: + x = x - 2 + if x < 2: + x = len(track) - 2 + start_node = track[x] return nodes_added @@ -382,12 +416,13 @@ def realistic_rail_generator(num_cities=5, # build switches # TODO remove true/false block if True: - create_switches_at_stations(rail_trans, grid_map, station_tracks, nodes_added) + create_switches_at_stations(rail_trans, grid_map, station_tracks, nodes_added, + intern_nbr_of_switches_per_station_track) # ---------------------------------------------------------------------------------- # connect stations # TODO remove true/false block - if True: + if False: if do_random_connect_stations: connect_random_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, intern_connect_max_nbr_of_shortes_city) @@ -452,19 +487,19 @@ if os.path.exists("./../render_output/"): np.random.seed(itrials) env = RailEnv(width=40 + np.random.choice(100), height=40 + np.random.choice(100), - rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10), + rail_generator=realistic_rail_generator(num_cities=1000, city_size=10 + np.random.choice(10), - allowed_rotation_angles=[0], - max_number_of_station_tracks=np.random.choice(4) + 4, - nbr_of_switches_per_station_track=np.random.choice(4) + 2, - connect_max_nbr_of_shortes_city=2, + allowed_rotation_angles=np.arange(-180, 180, 15), + max_number_of_station_tracks=1 + np.random.choice(4), + nbr_of_switches_per_station_track=2, + connect_max_nbr_of_shortes_city=2 + np.random.choice(4), do_random_connect_stations=False, # Number of cities in map seed=itrials, # Random seed print_out_info=True ), schedule_generator=sparse_schedule_generator(), - number_of_agents=100, + number_of_agents=0, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index 779ee984..91fec9a9 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -1,3 +1,6 @@ +import numpy as np +from matplotlib import pyplot as plt + from flatland.core.grid.grid_utils import IntVector2D from flatland.core.grid.grid_utils import IntVector2DArrayType from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d @@ -31,12 +34,16 @@ class AStarNode: def a_star(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D) -> IntVector2DArrayType: + start: IntVector2D, end: IntVector2D, + a_star_distance_function=Vec2d.get_manhattan_distance) -> IntVector2DArrayType: """ Returns a list of tuples as a path from the given start to end. If no path is found, returns path to closest point to end. """ rail_shape = grid_map.grid.shape + + tmp = np.zeros(rail_shape) - 10 + start_node = AStarNode(None, start) end_node = AStarNode(None, end) open_nodes = set() @@ -64,6 +71,14 @@ def a_star(rail_trans: RailEnvTransitions, while current is not None: path.append(current.pos) current = current.parent + + if False: + plt.ion() + plt.clf() + plt.imshow(tmp, interpolation='nearest') + plt.draw() + plt.pause(1e-17) + # return reversed path return path[::-1] @@ -73,6 +88,7 @@ def a_star(rail_trans: RailEnvTransitions, prev_pos = current_node.parent.pos else: prev_pos = None + for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: # update the "current" pos node_pos = Vec2d.add(current_node.pos, new_pos) @@ -98,9 +114,11 @@ def a_star(rail_trans: RailEnvTransitions, # create the f, g, and h values child.g = current_node.g + 1.0 # this heuristic avoids diagonal paths - child.h = Vec2d.get_manhattan_distance(child.pos, end_node.pos) + child.h = a_star_distance_function(child.pos, end_node.pos) child.f = child.g + child.h + tmp[child.pos[0]][child.pos[1]] = child.f + # already in the open list? if child in open_nodes: continue -- GitLab