Commit 854c9726 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

refactoring and preparation (city generator).

parent 0b4c3f90
import copy
import os
import warnings
from typing import Sequence, Optional
import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d, IntVector2DArrayType
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d, IntVector2DArray, IntVector2DDistance, \
IntVector2DArrayArray
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail
......@@ -17,19 +19,19 @@ from flatland.utils.rendertools import RenderTool, AgentRenderVariant
FloatArrayType = []
def realistic_rail_generator(num_cities=5,
city_size=10,
allowed_rotation_angles=None,
max_number_of_station_tracks=4,
nbr_of_switches_per_station_track=2,
connect_max_nbr_of_shortes_city=4,
do_random_connect_stations=False,
seed=0,
print_out_info=True) -> RailGenerator:
def realistic_rail_generator(num_cities: int = 5,
city_size: int = 10,
allowed_rotation_angles: Optional[Sequence[float]] = None,
max_number_of_station_tracks: int = 4,
nbr_of_switches_per_station_track: int = 2,
connect_max_nbr_of_shortes_city: int = 4,
do_random_connect_stations: bool = False,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
seed: int = 0,
print_out_info: bool = True) -> RailGenerator:
"""
This is a level generator which generates a realistic rail configurations
:param print_out_info:
:param num_cities: Number of city node
:param city_size: Length of city measure in cells
:param allowed_rotation_angles: Rotate the city (around center)
......@@ -37,8 +39,9 @@ def realistic_rail_generator(num_cities=5,
:param nbr_of_switches_per_station_track: number of switches per track (max)
:param connect_max_nbr_of_shortes_city: max number of connecting track between stations
:param do_random_connect_stations : if false connect the stations along the grid (top,left -> down,right), else rand
:param a_star_distance_function: Heuristic how the distance between two nodes get estimated in the "a-star" path
:param seed: Random Seed
:print_out_info : print debug info
:param print_out_info: print debug info if True
:return:
-------
numpy.ndarray of type numpy.uint16
......@@ -48,7 +51,7 @@ def realistic_rail_generator(num_cities=5,
def do_generate_city_locations(width: int,
height: int,
intern_city_size: int,
intern_max_number_of_station_tracks: int) -> (IntVector2DArrayType, int):
intern_max_number_of_station_tracks: int) -> (IntVector2DArray, int):
X = int(np.floor(max(1, height - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size))
Y = int(np.floor(max(1, width - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size))
......@@ -68,7 +71,7 @@ def realistic_rail_generator(num_cities=5,
generate_city_locations = [[(int(xs[i]), int(ys[i])), (int(xs[i]), int(ys[i]))] for i in range(len(xs))]
return generate_city_locations, max_num_cities
def do_orient_cities(generate_city_locations: IntVector2DArrayType, intern_city_size: int,
def do_orient_cities(generate_city_locations: IntVector2DArrayArray, intern_city_size: int,
rotation_angles_set: FloatArrayType):
for i in range(len(generate_city_locations)):
# station main orientation (horizontal or vertical
......@@ -83,12 +86,12 @@ def realistic_rail_generator(num_cities=5,
def create_stations_from_city_locations(rail_trans: RailEnvTransitions,
grid_map: GridTransitionMap,
generate_city_locations: IntVector2DArrayType,
intern_max_number_of_station_tracks: int) -> (IntVector2DArrayType,
IntVector2DArrayType,
IntVector2DArrayType,
IntVector2DArrayType,
IntVector2DArrayType):
generate_city_locations: IntVector2DArray,
intern_max_number_of_station_tracks: int) -> (IntVector2DArray,
IntVector2DArray,
IntVector2DArray,
IntVector2DArray,
IntVector2DArray):
nodes_added = []
start_nodes_added = [[] for _ in range(len(generate_city_locations))]
......@@ -115,7 +118,7 @@ def realistic_rail_generator(num_cities=5,
end_node = Vec2d.ceil(
Vec2d.add(org_end_node, Vec2d.scale(ortho_trans, s)))
connection = connect_from_nodes(rail_trans, grid_map, start_node, end_node)
connection = connect_from_nodes(rail_trans, grid_map, start_node, end_node, a_star_distance_function)
if len(connection) > 0:
nodes_added.append(start_node)
nodes_added.append(end_node)
......@@ -142,9 +145,9 @@ def realistic_rail_generator(num_cities=5,
def create_switches_at_stations(rail_trans: RailEnvTransitions,
grid_map: GridTransitionMap,
station_tracks: IntVector2DArrayType,
nodes_added: IntVector2DArrayType,
intern_nbr_of_switches_per_station_track: int) -> IntVector2DArrayType:
station_tracks: IntVector2DArray,
nodes_added: IntVector2DArray,
intern_nbr_of_switches_per_station_track: int) -> IntVector2DArray:
for k_loop in range(intern_nbr_of_switches_per_station_track):
for city_loop in range(len(station_tracks)):
......@@ -170,13 +173,14 @@ def realistic_rail_generator(num_cities=5,
if x < 2:
x = len(track) - 1
end_node = track[x]
connection = connect_rail(rail_trans, grid_map, start_node, end_node)
connection = connect_rail(rail_trans, grid_map, start_node, end_node,
a_star_distance_function)
if len(connection) == 0:
if print_out_info:
print("create_switches_at_stations : connect_rail -> no path found")
start_node = datas[i][0]
end_node = datas[i - 1][0]
connect_rail(rail_trans, grid_map, start_node, end_node)
connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function)
nodes_added.append(start_node)
nodes_added.append(end_node)
......@@ -226,10 +230,10 @@ def realistic_rail_generator(num_cities=5,
return graph, np.unique(graph_ids).astype(int)
def connect_sub_graphs(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
org_s_nodes: IntVector2DArrayType,
org_e_nodes: IntVector2DArrayType,
city_edges: IntVector2DArrayType,
nodes_added: IntVector2DArrayType):
org_s_nodes: IntVector2DArray,
org_e_nodes: IntVector2DArray,
city_edges: IntVector2DArray,
nodes_added: IntVector2DArray):
_, graphids = calc_nbr_of_graphs(city_edges)
if len(graphids) > 0:
for i in range(len(graphids) - 1):
......@@ -247,7 +251,7 @@ def realistic_rail_generator(num_cities=5,
# TODO : will be generated.
grid_map.grid[start_node] = 0
grid_map.grid[end_node] = 0
connection = connect_rail(rail_trans, grid_map, start_node, end_node)
connection = connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function)
if len(connection) > 0:
nodes_added.append(start_node)
nodes_added.append(end_node)
......@@ -259,9 +263,9 @@ def realistic_rail_generator(num_cities=5,
def connect_stations(rail_trans: RailEnvTransitions,
grid_map: GridTransitionMap,
org_s_nodes: IntVector2DArrayType,
org_e_nodes: IntVector2DArrayType,
nodes_added: IntVector2DArrayType,
org_s_nodes: IntVector2DArray,
org_e_nodes: IntVector2DArray,
nodes_added: IntVector2DArray,
intern_connect_max_nbr_of_shortes_city: int):
city_edges = []
......@@ -291,7 +295,7 @@ def realistic_rail_generator(num_cities=5,
tmp_trans_en = grid_map.grid[end_node]
grid_map.grid[start_node] = 0
grid_map.grid[end_node] = 0
connection = connect_rail(rail_trans, grid_map, start_node, end_node)
connection = connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function)
if len(connection) > 0:
s_nodes[city_loop].remove(start_node)
e_nodes[cl].remove(end_node)
......@@ -313,9 +317,9 @@ def realistic_rail_generator(num_cities=5,
connect_sub_graphs(rail_trans, grid_map, org_s_nodes, org_e_nodes, city_edges, nodes_added)
def connect_random_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start_nodes_added: IntVector2DArrayType,
end_nodes_added: IntVector2DArrayType,
nodes_added: IntVector2DArrayType,
start_nodes_added: IntVector2DArray,
end_nodes_added: IntVector2DArray,
nodes_added: IntVector2DArray,
intern_connect_max_nbr_of_shortes_city: int):
if len(start_nodes_added) < 1:
return
......@@ -355,7 +359,7 @@ def realistic_rail_generator(num_cities=5,
end_node = e_nodes[idx_e_nodes[i]]
grid_map.grid[start_node] = 0
grid_map.grid[end_node] = 0
connection = connect_nodes(rail_trans, grid_map, start_node, end_node)
connection = connect_nodes(rail_trans, grid_map, start_node, end_node, a_star_distance_function)
if len(connection) > 0:
nodes_added.append(start_node)
nodes_added.append(end_node)
......@@ -364,7 +368,7 @@ def realistic_rail_generator(num_cities=5,
print("connect_random_stations : connect_nodes -> no path found")
def remove_switch_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
train_stations: IntVector2DArrayType):
train_stations: IntVector2DArray):
tmp_train_stations = copy.deepcopy(train_stations)
for city_loop in range(len(train_stations)):
for n in tmp_train_stations[city_loop]:
......@@ -481,7 +485,7 @@ def realistic_rail_generator(num_cities=5,
if (tries + 1) % 10 == 0:
start_node = np.random.choice(avail_start_nodes)
if tries > 100:
warnings.warn("Could not set trainstations, removing agent!")
warnings.warn("Could not set train_stations, removing agent!")
found_agent_pair = False
break
if found_agent_pair:
......@@ -508,13 +512,13 @@ if os.path.exists("./../render_output/"):
height=40 + np.random.choice(100),
rail_generator=realistic_rail_generator(num_cities=5 + np.random.choice(10),
city_size=10 + np.random.choice(5),
allowed_rotation_angles=np.arange(0, 360, 90),
max_number_of_station_tracks=1 + np.random.choice(4),
allowed_rotation_angles=np.arange(0, 360, 6),
max_number_of_station_tracks=4 + np.random.choice(4),
nbr_of_switches_per_station_track=2 + np.random.choice(2),
connect_max_nbr_of_shortes_city=2 + np.random.choice(4),
do_random_connect_stations=itrials % 2 == 0,
# Number of cities in map
seed=itrials, # Random seed
a_star_distance_function=Vec2d.get_euclidean_distance,
seed=itrials,
print_out_info=False
),
schedule_generator=sparse_schedule_generator(),
......
import numpy as np
from matplotlib import pyplot as plt
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.grid.grid_utils import IntVector2DArrayType
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
from flatland.core.grid.grid_utils import IntVector2DArray
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
class AStarNode:
"""A node class for A* Pathfinding"""
def __init__(self, parent: IntVector2D = None, pos: IntVector2D = None):
self.parent: IntVector2D = parent
def __init__(self, pos: IntVector2D, parent=None):
self.parent = parent
self.pos: IntVector2D = pos
self.g = 0.0
self.h = 0.0
self.f = 0.0
def __eq__(self, other: IntVector2D):
def __eq__(self, other):
"""
Parameters
----------
other : AStarNode
"""
return self.pos == other.pos
def __hash__(self):
......@@ -32,10 +36,9 @@ class AStarNode:
self.f = other.f
def a_star(rail_trans: RailEnvTransitions,
grid_map: GridTransitionMap,
def a_star(grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D,
a_star_distance_function=Vec2d.get_manhattan_distance) -> IntVector2DArrayType:
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
......@@ -44,8 +47,8 @@ def a_star(rail_trans: RailEnvTransitions,
tmp = np.zeros(rail_shape) - 10
start_node = AStarNode(None, start)
end_node = AStarNode(None, end)
start_node = AStarNode(start, None)
end_node = AStarNode(end, None)
open_nodes = set()
closed_nodes = set()
open_nodes.add(start_node)
......@@ -72,13 +75,6 @@ def a_star(rail_trans: RailEnvTransitions,
path.append(current.pos)
current = current.parent
if False:
plt.ion()
plt.clf()
plt.imshow(tmp, interpolation='nearest')
plt.draw()
plt.pause(1e-17)
# return reversed path
return path[::-1]
......@@ -91,7 +87,7 @@ def a_star(rail_trans: RailEnvTransitions,
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
# update the "current" pos
node_pos = Vec2d.add(current_node.pos, new_pos)
node_pos: IntVector2D = Vec2d.add(current_node.pos, new_pos)
# is node_pos inside the grid?
if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
......@@ -102,7 +98,7 @@ def a_star(rail_trans: RailEnvTransitions,
continue
# create new node
new_node = AStarNode(current_node, node_pos)
new_node = AStarNode(node_pos, current_node)
children.append(new_node)
# loop through children
......
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import IntVector2DArrayType
from flatland.core.grid.grid_utils import IntVector2DArray
def get_direction(pos1: IntVector2DArrayType, pos2: IntVector2DArrayType) -> Grid4TransitionsEnum:
def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4TransitionsEnum:
"""
Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions.
......
from typing import Tuple
from typing import Tuple, Callable, List
import numpy as np
Vector2D = Tuple[float, float]
IntVector2D = Tuple[int, int]
IntVector2DArrayType = []
IntVector2DArray = List[IntVector2D]
IntVector2DArrayArray = List[List[IntVector2D]]
Vector2DArray = List[Vector2D]
Vector2DArrayArray = List[List[Vector2D]]
IntVector2DDistance = Callable[[IntVector2D, IntVector2D], float]
class Vec2dOperations:
......@@ -73,42 +79,30 @@ class Vec2dOperations:
"""
return np.sqrt(node[0] * node[0] + node[1] * node[1])
@staticmethod
def get_manhattan_norm(node: Vector2D) -> float:
def get_euclidean_distance(node_a: Vector2D, node_b: Vector2D) -> float:
"""
calculates the euclidean norm of the 2d vector
:param node: tuple with coordinate (x,y) or 2d vector
:return:
-------
returns the manhatten norm
returns the euclidean distance
"""
return abs(node[0] * node[0]) + abs(node[1] * node[1])
@staticmethod
def get_euclidean_distance(node_a: Vector2D,node_b: Vector2D) -> float:
"""
calculates the euclidean norm of the 2d vector
:param node: tuple with coordinate (x,y) or 2d vector
:return:
-------
returnss the manhatten distance
"""
return Vec2dOperations.get_norm(Vec2dOperations.subtract(node_b,node_a))
return Vec2dOperations.get_norm(Vec2dOperations.subtract(node_b, node_a))
@staticmethod
def get_manhattan_distance(node_a: Vector2D, node_b: Vector2D) -> float:
"""
calculates the euclidean norm of the 2d vector
calculates the manhattan distance of the 2d vector
:param node: tuple with coordinate (x,y) or 2d vector
:return:
-------
returnss the manhatten distance
returns the manhattan distance
"""
return Vec2dOperations.get_manhattan_norm(Vec2dOperations.subtract(node_b, node_a))
delta = (Vec2dOperations.subtract(node_b, node_a))
return np.abs(delta[0]) + np.abs(delta[1])
@staticmethod
def normalize(node: Vector2D) -> Tuple[float, float]:
......
......@@ -8,7 +8,7 @@ from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position, get_direction
from flatland.core.grid.grid_utils import IntVector2DArrayType
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions
......@@ -302,7 +302,7 @@ class GridTransitionMap(TransitionMap):
self.height = new_height
self.grid = new_grid
def is_dead_end(self, rcPos: IntVector2DArrayType):
def is_dead_end(self, rcPos: IntVector2DArray):
"""
Check if the cell is a dead-end.
......@@ -322,7 +322,7 @@ class GridTransitionMap(TransitionMap):
tmp = tmp >> 1
return nbits == 1
def is_simple_turn(self, rcPos: IntVector2DArrayType):
def is_simple_turn(self, rcPos: IntVector2DArray):
"""
Check if the cell is a left/right simple turn
......@@ -349,7 +349,7 @@ class GridTransitionMap(TransitionMap):
return is_simple_turn(tmp)
def check_path_exists(self, start: IntVector2DArrayType, direction: int, end: IntVector2DArrayType):
def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
# print("_path_exists({},{},{}".format(start, direction, end))
# BFS - Check if a path exists between the 2 nodes
......@@ -373,7 +373,7 @@ class GridTransitionMap(TransitionMap):
return False
def cell_neighbours_valid(self, rcPos: IntVector2DArrayType, check_this_cell=False):
def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
......@@ -425,7 +425,7 @@ class GridTransitionMap(TransitionMap):
return True
def fix_neighbours(self, rcPos: IntVector2DArrayType, check_this_cell=False):
def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
......@@ -478,7 +478,7 @@ class GridTransitionMap(TransitionMap):
return True
def fix_transitions(self, rcPos: IntVector2DArrayType):
def fix_transitions(self, rcPos: IntVector2DArray):
"""
Fixes broken transitions
"""
......@@ -543,8 +543,8 @@ class GridTransitionMap(TransitionMap):
self.set_transitions((rcPos[0], rcPos[1]), transition)
return True
def validate_new_transition(self, prev_pos: IntVector2DArrayType, current_pos: IntVector2DArrayType,
new_pos: IntVector2DArrayType, end_pos: IntVector2DArrayType):
def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
new_pos: IntVector2D, end_pos: IntVector2D):
# start by getting direction used to get to current node
# and direction from current node to possible child node
......
......@@ -7,7 +7,8 @@ a GridTransitionMap object.
from flatland.core.grid.grid4_astar import a_star
from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
......@@ -15,12 +16,13 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
start: IntVector2D,
end: IntVector2D,
flip_start_node_trans=False,
flip_end_node_trans=False):
flip_end_node_trans=False,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
"""
Creates a new path [start,end] in grid_map, based on rail_trans.
"""
# in the worst case we will need to do a A* search, so we might as well set that up
path = a_star(rail_trans, grid_map, start, end)
path = a_star(grid_map, start, end, a_star_distance_function)
if len(path) < 2:
return []
current_dir = get_direction(path[0], path[1])
......@@ -67,18 +69,25 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
return path
def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D):
return connect_basic_operation(rail_trans, grid_map, start, end, True, True)
def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
return connect_basic_operation(rail_trans, grid_map, start, end, True, True, a_star_distance_function)
def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D):
return connect_basic_operation(rail_trans, grid_map, start, end, False, False)
def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function)
def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D,
end: IntVector2D):
return connect_basic_operation(rail_trans, grid_map, start, end, False, True)
def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
return connect_basic_operation(rail_trans, grid_map, start, end, False, True, a_star_distance_function)
def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D):
return connect_basic_operation(rail_trans, grid_map, start, end, True, False)
def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment