Skip to content
Snippets Groups Projects
Commit 2b748303 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

refactoring and clean up

parent 34df9cab
No related branches found
No related tags found
No related merge requests found
......@@ -5,8 +5,7 @@ import warnings
import numpy as np
from flatland.core.grid.grid4_utils import mirror
from flatland.core.grid.grid_utils import Vec2dOperations as vec2d
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail
......@@ -17,62 +16,9 @@ from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
# TODO : remove (reuse existing code!!)
class GripMapOp:
def min_max_cut(min_v, max_v, v):
return max(min_v, min(max_v, v))
def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True):
gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) # NESW in RC
lrcStroke = [[GripMapOp.min_max_cut(0, height - 1, pt_from[0]),
GripMapOp.min_max_cut(0, width - 1, pt_from[1])],
[GripMapOp.min_max_cut(0, height - 1, pt_via[0]),
GripMapOp.min_max_cut(0, width - 1, pt_via[1])],
[GripMapOp.min_max_cut(0, height - 1, pt_to[0]),
GripMapOp.min_max_cut(0, width - 1, pt_to[1])]]
rc3Cells = np.array(lrcStroke[:3]) # the 3 cells
rcMiddle = rc3Cells[1] # the middle cell which we will update
bDeadend = np.all(lrcStroke[0] == lrcStroke[2]) # deadend means cell 0 == cell 2
# get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East
rc2Trans = np.diff(rc3Cells, axis=0)
# get the direction index for the 2 transitions
liTrans = []
for rcTrans in rc2Trans:
# gRCTrans - rcTrans gives an array of vector differences between our rcTrans
# and the 4 directions stored in gRCTrans.
# Where the vector difference is zero, we have a match...
# np.all detects where the whole row,col vector is zero.
# argwhere gives the index of the zero vector, ie the direction index
iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1))
if len(iTrans) > 0:
iTrans = iTrans[0][0]
liTrans.append(iTrans)
# check that we have two transitions
if len(liTrans) == 2:
# Set the transition
# Set the transition
# If this transition spans 3 cells, it is not a deadend, so remove any deadends.
# The user will need to resolve any conflicts.
grid_map.set_transition((*rcMiddle, liTrans[0]),
liTrans[1],
bAddRemove,
remove_deadends=not bDeadend)
# Also set the reverse transition
# use the reversed outbound transition for inbound
# and the reversed inbound transition for outbound
grid_map.set_transition((*rcMiddle, mirror(liTrans[1])),
mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
def realistic_rail_generator(num_cities=5,
city_size=10,
allowed_rotation_angles=[0, 90],
allowed_rotation_angles=None,
max_number_of_station_tracks=4,
nbr_of_switches_per_station_track=2,
connect_max_nbr_of_shortes_city=4,
......@@ -82,6 +28,7 @@ def realistic_rail_generator(num_cities=5,
"""
This is a level generator which generates a realistic rail configurations
:param print_out_info:
:param num_cities: Number of city node
:param city_size: Length of city measure in cells
:param allowed_rotation_angles: Rotate the city (around center)
......@@ -117,25 +64,25 @@ def realistic_rail_generator(num_cities=5,
generate_city_locations = [[(int(xs[i]), int(ys[i])), (int(xs[i]), int(ys[i]))] for i in range(len(xs))]
return generate_city_locations, max_num_cities
def do_orient_cities(generate_city_locations, intern_city_size, allowed_rotation_angles):
def do_orient_cities(generate_city_locations, intern_city_size, rotation_angles_set):
for i in range(len(generate_city_locations)):
# station main orientation (horizontal or vertical
rot_angle = np.random.choice(allowed_rotation_angles)
add_pos_val = vec2d.scale_pos(vec2d.rotate_pos((1, 0), rot_angle),
rot_angle = np.random.choice(rotation_angles_set)
add_pos_val = Vec2d.scale_pos(Vec2d.rotate_pos((1, 0), rot_angle),
(max(1, (intern_city_size - 3) / 2)))
generate_city_locations[i][0] = vec2d.add_pos(generate_city_locations[i][1], add_pos_val)
add_pos_val = vec2d.scale_pos(vec2d.rotate_pos((1, 0), 180 + rot_angle),
generate_city_locations[i][0] = Vec2d.add_pos(generate_city_locations[i][1], add_pos_val)
add_pos_val = Vec2d.scale_pos(Vec2d.rotate_pos((1, 0), 180 + rot_angle),
(max(1, (intern_city_size - 3) / 2)))
generate_city_locations[i][1] = vec2d.add_pos(generate_city_locations[i][1], add_pos_val)
generate_city_locations[i][1] = Vec2d.add_pos(generate_city_locations[i][1], add_pos_val)
return generate_city_locations
def create_stations_from_city_locations(rail_trans, rail_array, generate_city_locations,
intern_max_number_of_station_tracks):
nodes_added = []
start_nodes_added = [[] for i in range(len(generate_city_locations))]
end_nodes_added = [[] for i in range(len(generate_city_locations))]
station_slots = [[] for i in range(len(generate_city_locations))]
station_tracks = [[[] for j in range(intern_max_number_of_station_tracks)] for i in range(len(
start_nodes_added = [[] for _ in range(len(generate_city_locations))]
end_nodes_added = [[] for _ in range(len(generate_city_locations))]
station_slots = [[] for _ in range(len(generate_city_locations))]
station_tracks = [[[] for _ in range(intern_max_number_of_station_tracks)] for _ in range(len(
generate_city_locations))]
station_slots_cnt = 0
......@@ -147,13 +94,13 @@ def realistic_rail_generator(num_cities=5,
org_start_node = generate_city_locations[city_loop][0]
org_end_node = generate_city_locations[city_loop][1]
ortho_trans = vec2d.make_orthogonal_pos(
vec2d.normalize_pos(vec2d.subtract_pos(org_start_node, org_end_node)))
ortho_trans = Vec2d.make_orthogonal_pos(
Vec2d.normalize_pos(Vec2d.subtract_pos(org_start_node, org_end_node)))
s = (ct - number_of_connecting_tracks / 2.0)
start_node = vec2d.ceil_pos(
vec2d.add_pos(org_start_node, vec2d.scale_pos(ortho_trans, s)))
end_node = vec2d.ceil_pos(
vec2d.add_pos(org_end_node, vec2d.scale_pos(ortho_trans, s)))
start_node = Vec2d.ceil_pos(
Vec2d.add_pos(org_start_node, Vec2d.scale_pos(ortho_trans, s)))
end_node = Vec2d.ceil_pos(
Vec2d.add_pos(org_end_node, Vec2d.scale_pos(ortho_trans, s)))
connection = connect_from_nodes(rail_trans, rail_array, start_node, end_node)
if len(connection) > 0:
......@@ -174,8 +121,7 @@ def realistic_rail_generator(num_cities=5,
return nodes_added, station_slots, start_nodes_added, end_nodes_added, station_tracks
def create_switches_at_stations(rail_trans, rail_array, width, height, grid_map, station_tracks, nodes_added,
intern_nbr_of_switches_per_station_track):
def create_switches_at_stations(rail_trans, rail_array, station_tracks, nodes_added):
for city_loop in range(len(station_tracks)):
datas = station_tracks[city_loop]
......@@ -183,13 +129,14 @@ def realistic_rail_generator(num_cities=5,
a = datas[0]
if len(a) > 0:
start_node = a[np.random.choice(len(a) - 2) + 1]
b = []
for i in np.arange(1, len(datas)):
b = datas[i]
if len(b) > 2:
x = np.random.choice(len(b) - 2) + 1
end_node = b[x]
connection = connect_rail(rail_trans, rail_array, start_node, end_node)
if len(connection) == 0:
print("create_switches_at_stations : connect_rail -> no path found")
nodes_added.append(start_node)
nodes_added.append(end_node)
start_node = b[np.random.choice(len(b) - 2) + 1]
......@@ -230,8 +177,8 @@ def realistic_rail_generator(num_cities=5,
if len(graphids) > 0:
for i in range(len(graphids) - 1):
connection = []
cnt = 0
while len(connection) == 0 and cnt < 100:
iteration_counter = 0
while len(connection) == 0 and iteration_counter < 100:
s_nodes = copy.deepcopy(org_s_nodes)
e_nodes = copy.deepcopy(org_e_nodes)
start_nodes = s_nodes[graphids[i]]
......@@ -247,7 +194,7 @@ def realistic_rail_generator(num_cities=5,
if len(connection) > 0:
nodes_added.append(start_node)
nodes_added.append(end_node)
cnt += 1
iteration_counter += 1
def connect_stations(rail_trans, rail_array, org_s_nodes, org_e_nodes, nodes_added,
inter_connect_max_nbr_of_shortes_city):
......@@ -263,12 +210,13 @@ def realistic_rail_generator(num_cities=5,
for start_node in sns:
min_distance = np.inf
end_node = None
cl = 0
for city_loop_find_shortest in range(len(e_nodes)):
if city_loop_find_shortest == city_loop:
continue
ens = e_nodes[city_loop_find_shortest]
for en in ens:
d = vec2d.get_norm_pos(vec2d.subtract_pos(en, start_node))
d = Vec2d.get_norm_pos(Vec2d.subtract_pos(en, start_node))
if d < min_distance:
min_distance = d
end_node = en
......@@ -336,8 +284,8 @@ def realistic_rail_generator(num_cities=5,
for i in range(max_input_output):
start_node = s_nodes[idx_s_nodes[i]]
end_node = e_nodes[idx_e_nodes[i]]
new_trans = rail_array[start_node] = 0
new_trans = rail_array[end_node] = 0
rail_array[start_node] = 0
rail_array[end_node] = 0
connection = connect_nodes(rail_trans, rail_array, start_node, end_node)
if len(connection) > 0:
nodes_added.append(start_node)
......@@ -398,8 +346,7 @@ def realistic_rail_generator(num_cities=5,
# build switches
# TODO remove true/false block
if True:
create_switches_at_stations(rail_trans, rail_array, width, height, grid_map, station_tracks, nodes_added,
intern_nbr_of_switches_per_station_track)
create_switches_at_stations(rail_trans, rail_array, station_tracks, nodes_added)
# ----------------------------------------------------------------------------------
# connect stations
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment