Skip to content
Snippets Groups Projects
Commit 3ce07cf2 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

update

parent 9babe928
No related branches found
No related tags found
No related merge requests found
import copy
import os
import time
import warnings
import numpy as np
......@@ -138,25 +139,58 @@ def realistic_rail_generator(num_cities=5,
def create_switches_at_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
station_tracks: IntVector2DArrayType,
nodes_added: IntVector2DArrayType) -> IntVector2DArrayType:
for city_loop in range(len(station_tracks)):
datas = station_tracks[city_loop]
if len(datas) > 1:
a = datas[0]
if len(a) > 0:
start_node = a[np.random.choice(len(a) - 2) + 1]
for i in np.arange(1, len(datas)):
b = datas[i]
if len(b) > 2:
x = np.random.choice(len(b) - 2) + 1
end_node = b[x]
connection = connect_rail(rail_trans, grid_map, start_node, end_node)
if len(connection) == 0:
if print_out_info:
print("create_switches_at_stations : connect_rail -> no path found")
nodes_added.append(start_node)
nodes_added.append(end_node)
start_node = b[np.random.choice(len(b) - 2) + 1]
nodes_added: IntVector2DArrayType,
intern_nbr_of_switches_per_station_track: int) -> IntVector2DArrayType:
for k in range(intern_nbr_of_switches_per_station_track):
for city_loop in range(len(station_tracks)):
datas = station_tracks[city_loop]
if len(datas) > 1:
track = datas[0]
if len(track) > 3:
if k % 2 == 0:
x = 1
else:
x = len(track) - 2
start_node = track[x]
for i in np.arange(1, len(datas)):
track = datas[i]
if len(track) > 3:
if k % 2 == 0:
x = x + 2
if len(track) <= x:
x = 1
else:
x = x - 2
if x < 2:
x = len(track) - 2
end_node = track[x]
connection = connect_rail(rail_trans, grid_map, start_node, end_node)
print(start_node, end_node, "-->", connection)
if len(connection) == 0:
if print_out_info:
print("create_switches_at_stations : connect_rail -> no path found")
if len(datas[i-1])>0:
start_node = datas[i-1][0]
end_node = datas[i][0]
connection = connect_rail(rail_trans, grid_map, start_node, end_node)
nodes_added.append(start_node)
nodes_added.append(end_node)
if k % 2 == 0:
x = x + 2
if len(track) <= x:
x = 1
else:
x = x - 2
if x < 2:
x = len(track) - 2
start_node = track[x]
return nodes_added
......@@ -382,12 +416,13 @@ def realistic_rail_generator(num_cities=5,
# build switches
# TODO remove true/false block
if True:
create_switches_at_stations(rail_trans, grid_map, station_tracks, nodes_added)
create_switches_at_stations(rail_trans, grid_map, station_tracks, nodes_added,
intern_nbr_of_switches_per_station_track)
# ----------------------------------------------------------------------------------
# connect stations
# TODO remove true/false block
if True:
if False:
if do_random_connect_stations:
connect_random_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added,
intern_connect_max_nbr_of_shortes_city)
......@@ -452,19 +487,19 @@ if os.path.exists("./../render_output/"):
np.random.seed(itrials)
env = RailEnv(width=40 + np.random.choice(100),
height=40 + np.random.choice(100),
rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10),
rail_generator=realistic_rail_generator(num_cities=1000,
city_size=10 + np.random.choice(10),
allowed_rotation_angles=[0],
max_number_of_station_tracks=np.random.choice(4) + 4,
nbr_of_switches_per_station_track=np.random.choice(4) + 2,
connect_max_nbr_of_shortes_city=2,
allowed_rotation_angles=np.arange(-180, 180, 15),
max_number_of_station_tracks=1 + np.random.choice(4),
nbr_of_switches_per_station_track=2,
connect_max_nbr_of_shortes_city=2 + np.random.choice(4),
do_random_connect_stations=False,
# Number of cities in map
seed=itrials, # Random seed
print_out_info=True
),
schedule_generator=sparse_schedule_generator(),
number_of_agents=100,
number_of_agents=0,
obs_builder_object=GlobalObsForRailEnv())
# reset to initialize agents_static
......
import numpy as np
from matplotlib import pyplot as plt
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.grid.grid_utils import IntVector2DArrayType
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
......@@ -31,12 +34,16 @@ class AStarNode:
def a_star(rail_trans: RailEnvTransitions,
grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D) -> IntVector2DArrayType:
start: IntVector2D, end: IntVector2D,
a_star_distance_function=Vec2d.get_manhattan_distance) -> IntVector2DArrayType:
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
"""
rail_shape = grid_map.grid.shape
tmp = np.zeros(rail_shape) - 10
start_node = AStarNode(None, start)
end_node = AStarNode(None, end)
open_nodes = set()
......@@ -64,6 +71,14 @@ def a_star(rail_trans: RailEnvTransitions,
while current is not None:
path.append(current.pos)
current = current.parent
if False:
plt.ion()
plt.clf()
plt.imshow(tmp, interpolation='nearest')
plt.draw()
plt.pause(1e-17)
# return reversed path
return path[::-1]
......@@ -73,6 +88,7 @@ def a_star(rail_trans: RailEnvTransitions,
prev_pos = current_node.parent.pos
else:
prev_pos = None
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
# update the "current" pos
node_pos = Vec2d.add(current_node.pos, new_pos)
......@@ -98,9 +114,11 @@ def a_star(rail_trans: RailEnvTransitions,
# create the f, g, and h values
child.g = current_node.g + 1.0
# this heuristic avoids diagonal paths
child.h = Vec2d.get_manhattan_distance(child.pos, end_node.pos)
child.h = a_star_distance_function(child.pos, end_node.pos)
child.f = child.g + child.h
tmp[child.pos[0]][child.pos[1]] = child.f
# already in the open list?
if child in open_nodes:
continue
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment