Skip to content
Snippets Groups Projects
Commit 11feb9b0 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

refactoring and clean up

parent 8aa2ef7b
No related branches found
No related tags found
No related merge requests found
...@@ -230,13 +230,13 @@ def realistic_rail_generator(num_cities=5, ...@@ -230,13 +230,13 @@ def realistic_rail_generator(num_cities=5,
org_s_nodes: IntVector2DArrayType, org_s_nodes: IntVector2DArrayType,
org_e_nodes: IntVector2DArrayType, org_e_nodes: IntVector2DArrayType,
nodes_added: IntVector2DArrayType, nodes_added: IntVector2DArrayType,
inter_connect_max_nbr_of_shortes_city: int): intern_connect_max_nbr_of_shortes_city: int):
city_edges = [] city_edges = []
s_nodes = copy.deepcopy(org_s_nodes) s_nodes = copy.deepcopy(org_s_nodes)
e_nodes = copy.deepcopy(org_e_nodes) e_nodes = copy.deepcopy(org_e_nodes)
for k in range(inter_connect_max_nbr_of_shortes_city): for nbr_connected in range(intern_connect_max_nbr_of_shortes_city):
for city_loop in range(len(s_nodes)): for city_loop in range(len(s_nodes)):
sns = s_nodes[city_loop] sns = s_nodes[city_loop]
for start_node in sns: for start_node in sns:
...@@ -278,13 +278,14 @@ def realistic_rail_generator(num_cities=5, ...@@ -278,13 +278,14 @@ def realistic_rail_generator(num_cities=5,
grid_map.grid[start_node] = tmp_trans_sn grid_map.grid[start_node] = tmp_trans_sn
grid_map.grid[end_node] = tmp_trans_en grid_map.grid[end_node] = tmp_trans_en
connect_sub_graphs(rail_trans, grid_map, org_s_nodes, org_e_nodes, city_edges, nodes_added) connect_sub_graphs(rail_trans, grid_map, org_s_nodes, org_e_nodes, city_edges, nodes_added)
def connect_random_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, def connect_random_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start_nodes_added: IntVector2DArrayType, start_nodes_added: IntVector2DArrayType,
end_nodes_added: IntVector2DArrayType, end_nodes_added: IntVector2DArrayType,
nodes_added: IntVector2DArrayType, nodes_added: IntVector2DArrayType,
inter_connect_max_nbr_of_shortes_city: int): intern_connect_max_nbr_of_shortes_city: int):
x = np.arange(len(start_nodes_added)) x = np.arange(len(start_nodes_added))
random_city_idx = np.random.choice(x, len(x), False) random_city_idx = np.random.choice(x, len(x), False)
...@@ -298,14 +299,10 @@ def realistic_rail_generator(num_cities=5, ...@@ -298,14 +299,10 @@ def realistic_rail_generator(num_cities=5,
e_nodes = end_nodes_added[idx_b] e_nodes = end_nodes_added[idx_b]
max_input_output = max(len(s_nodes), len(e_nodes)) max_input_output = max(len(s_nodes), len(e_nodes))
max_input_output = min(inter_connect_max_nbr_of_shortes_city, max_input_output) max_input_output = min(intern_connect_max_nbr_of_shortes_city, max_input_output)
if do_random_connect_stations: idx_s_nodes = np.random.choice(np.arange(len(s_nodes)), len(s_nodes), False)
idx_s_nodes = np.random.choice(np.arange(len(s_nodes)), len(s_nodes), False) idx_e_nodes = np.random.choice(np.arange(len(e_nodes)), len(e_nodes), False)
idx_e_nodes = np.random.choice(np.arange(len(e_nodes)), len(e_nodes), False)
else:
idx_s_nodes = np.arange(len(s_nodes))
idx_e_nodes = np.arange(len(e_nodes))
if len(idx_s_nodes) < max_input_output: if len(idx_s_nodes) < max_input_output:
idx_s_nodes = np.append(idx_s_nodes, np.random.choice(np.arange(len(s_nodes)), max_input_output - len( idx_s_nodes = np.append(idx_s_nodes, np.random.choice(np.arange(len(s_nodes)), max_input_output - len(
...@@ -315,10 +312,10 @@ def realistic_rail_generator(num_cities=5, ...@@ -315,10 +312,10 @@ def realistic_rail_generator(num_cities=5,
np.random.choice(np.arange(len(idx_e_nodes)), max_input_output - len( np.random.choice(np.arange(len(idx_e_nodes)), max_input_output - len(
idx_e_nodes))) idx_e_nodes)))
if len(idx_s_nodes) > inter_connect_max_nbr_of_shortes_city: if len(idx_s_nodes) > intern_connect_max_nbr_of_shortes_city:
idx_s_nodes = np.random.choice(idx_s_nodes, inter_connect_max_nbr_of_shortes_city, False) idx_s_nodes = np.random.choice(idx_s_nodes, intern_connect_max_nbr_of_shortes_city, False)
if len(idx_e_nodes) > inter_connect_max_nbr_of_shortes_city: if len(idx_e_nodes) > intern_connect_max_nbr_of_shortes_city:
idx_e_nodes = np.random.choice(idx_e_nodes, inter_connect_max_nbr_of_shortes_city, False) idx_e_nodes = np.random.choice(idx_e_nodes, intern_connect_max_nbr_of_shortes_city, False)
for i in range(max_input_output): for i in range(max_input_output):
start_node = s_nodes[idx_s_nodes[i]] start_node = s_nodes[idx_s_nodes[i]]
...@@ -360,12 +357,12 @@ def realistic_rail_generator(num_cities=5, ...@@ -360,12 +357,12 @@ def realistic_rail_generator(num_cities=5,
if print_out_info: if print_out_info:
print("intern_nbr_of_switches_per_station_track:", intern_nbr_of_switches_per_station_track) print("intern_nbr_of_switches_per_station_track:", intern_nbr_of_switches_per_station_track)
inter_connect_max_nbr_of_shortes_city = connect_max_nbr_of_shortes_city intern_connect_max_nbr_of_shortes_city = connect_max_nbr_of_shortes_city
if connect_max_nbr_of_shortes_city < 1: if connect_max_nbr_of_shortes_city < 1:
warnings.warn("min inter_connect_max_nbr_of_shortes_city requried to be > 1!") warnings.warn("min intern_connect_max_nbr_of_shortes_city requried to be > 1!")
inter_connect_max_nbr_of_shortes_city = 1 intern_connect_max_nbr_of_shortes_city = 1
if print_out_info: if print_out_info:
print("inter_connect_max_nbr_of_shortes_city:", inter_connect_max_nbr_of_shortes_city) print("intern_connect_max_nbr_of_shortes_city:", intern_connect_max_nbr_of_shortes_city)
agent_start_targets_nodes = [] agent_start_targets_nodes = []
...@@ -395,10 +392,10 @@ def realistic_rail_generator(num_cities=5, ...@@ -395,10 +392,10 @@ def realistic_rail_generator(num_cities=5,
if True: if True:
if do_random_connect_stations: if do_random_connect_stations:
connect_random_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, connect_random_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added,
inter_connect_max_nbr_of_shortes_city) intern_connect_max_nbr_of_shortes_city)
else: else:
connect_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, connect_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added,
inter_connect_max_nbr_of_shortes_city) intern_connect_max_nbr_of_shortes_city)
# ---------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------
# fix all transition at starting / ending points (mostly add a dead end, if missing) # fix all transition at starting / ending points (mostly add a dead end, if missing)
...@@ -459,13 +456,13 @@ if os.path.exists("./../render_output/"): ...@@ -459,13 +456,13 @@ if os.path.exists("./../render_output/"):
height=40 + np.random.choice(100), height=40 + np.random.choice(100),
rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10), rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10),
city_size=10 + np.random.choice(10), city_size=10 + np.random.choice(10),
allowed_rotation_angles=np.arange(-180, 180, 15), allowed_rotation_angles=[0],
max_number_of_station_tracks=np.random.choice(4) + 4, max_number_of_station_tracks=np.random.choice(4) + 4,
nbr_of_switches_per_station_track=np.random.choice(4) + 2, nbr_of_switches_per_station_track=np.random.choice(4) + 2,
connect_max_nbr_of_shortes_city=2, connect_max_nbr_of_shortes_city=2,
do_random_connect_stations=np.random.choice(1) == 0, do_random_connect_stations=False,
# Number of cities in map # Number of cities in map
seed=int(time.time()), # Random seed seed=0*int(time.time()), # Random seed
print_out_info=True print_out_info=True
), ),
schedule_generator=sparse_schedule_generator(), schedule_generator=sparse_schedule_generator(),
......
...@@ -9,14 +9,14 @@ from flatland.core.transition_map import GridTransitionMap ...@@ -9,14 +9,14 @@ from flatland.core.transition_map import GridTransitionMap
class AStarNode: class AStarNode:
"""A node class for A* Pathfinding""" """A node class for A* Pathfinding"""
def __init__(self, parent=None, pos=None): def __init__(self, parent: IntVector2D = None, pos: IntVector2D = None):
self.parent = parent self.parent: IntVector2D = parent
self.pos = pos self.pos: IntVector2D = pos
self.g = 0 self.g: float = 0.0
self.h = 0 self.h: float = 0.0
self.f = 0 self.f: float = 0.0
def __eq__(self, other): def __eq__(self, other: IntVector2D):
return self.pos == other.pos return self.pos == other.pos
def __hash__(self): def __hash__(self):
...@@ -95,7 +95,7 @@ def a_star(rail_trans: RailEnvTransitions, ...@@ -95,7 +95,7 @@ def a_star(rail_trans: RailEnvTransitions,
continue continue
# create the f, g, and h values # create the f, g, and h values
child.g = current_node.g + 1 child.g = current_node.g + 1.0
# this heuristic avoids diagonal paths # this heuristic avoids diagonal paths
child.h = Vec2d.get_manhattan_distance(child.pos, end_node.pos) child.h = Vec2d.get_manhattan_distance(child.pos, end_node.pos)
child.f = child.g + child.h child.f = child.g + child.h
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment