diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index 97444eccc164884ff7720203197140d9f1a8c401..6e5e6a566b7f2e462d310969be8585efec431970 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -230,13 +230,13 @@ def realistic_rail_generator(num_cities=5, org_s_nodes: IntVector2DArrayType, org_e_nodes: IntVector2DArrayType, nodes_added: IntVector2DArrayType, - inter_connect_max_nbr_of_shortes_city: int): + intern_connect_max_nbr_of_shortes_city: int): city_edges = [] s_nodes = copy.deepcopy(org_s_nodes) e_nodes = copy.deepcopy(org_e_nodes) - for k in range(inter_connect_max_nbr_of_shortes_city): + for nbr_connected in range(intern_connect_max_nbr_of_shortes_city): for city_loop in range(len(s_nodes)): sns = s_nodes[city_loop] for start_node in sns: @@ -278,13 +278,14 @@ def realistic_rail_generator(num_cities=5, grid_map.grid[start_node] = tmp_trans_sn grid_map.grid[end_node] = tmp_trans_en + connect_sub_graphs(rail_trans, grid_map, org_s_nodes, org_e_nodes, city_edges, nodes_added) def connect_random_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start_nodes_added: IntVector2DArrayType, end_nodes_added: IntVector2DArrayType, nodes_added: IntVector2DArrayType, - inter_connect_max_nbr_of_shortes_city: int): + intern_connect_max_nbr_of_shortes_city: int): x = np.arange(len(start_nodes_added)) random_city_idx = np.random.choice(x, len(x), False) @@ -298,14 +299,10 @@ def realistic_rail_generator(num_cities=5, e_nodes = end_nodes_added[idx_b] max_input_output = max(len(s_nodes), len(e_nodes)) - max_input_output = min(inter_connect_max_nbr_of_shortes_city, max_input_output) + max_input_output = min(intern_connect_max_nbr_of_shortes_city, max_input_output) - if do_random_connect_stations: - idx_s_nodes = np.random.choice(np.arange(len(s_nodes)), len(s_nodes), False) - idx_e_nodes = np.random.choice(np.arange(len(e_nodes)), len(e_nodes), False) - else: - idx_s_nodes = np.arange(len(s_nodes)) - idx_e_nodes = np.arange(len(e_nodes)) + idx_s_nodes = np.random.choice(np.arange(len(s_nodes)), len(s_nodes), False) + idx_e_nodes = np.random.choice(np.arange(len(e_nodes)), len(e_nodes), False) if len(idx_s_nodes) < max_input_output: idx_s_nodes = np.append(idx_s_nodes, np.random.choice(np.arange(len(s_nodes)), max_input_output - len( @@ -315,10 +312,10 @@ def realistic_rail_generator(num_cities=5, np.random.choice(np.arange(len(idx_e_nodes)), max_input_output - len( idx_e_nodes))) - if len(idx_s_nodes) > inter_connect_max_nbr_of_shortes_city: - idx_s_nodes = np.random.choice(idx_s_nodes, inter_connect_max_nbr_of_shortes_city, False) - if len(idx_e_nodes) > inter_connect_max_nbr_of_shortes_city: - idx_e_nodes = np.random.choice(idx_e_nodes, inter_connect_max_nbr_of_shortes_city, False) + if len(idx_s_nodes) > intern_connect_max_nbr_of_shortes_city: + idx_s_nodes = np.random.choice(idx_s_nodes, intern_connect_max_nbr_of_shortes_city, False) + if len(idx_e_nodes) > intern_connect_max_nbr_of_shortes_city: + idx_e_nodes = np.random.choice(idx_e_nodes, intern_connect_max_nbr_of_shortes_city, False) for i in range(max_input_output): start_node = s_nodes[idx_s_nodes[i]] @@ -360,12 +357,12 @@ def realistic_rail_generator(num_cities=5, if print_out_info: print("intern_nbr_of_switches_per_station_track:", intern_nbr_of_switches_per_station_track) - inter_connect_max_nbr_of_shortes_city = connect_max_nbr_of_shortes_city + intern_connect_max_nbr_of_shortes_city = connect_max_nbr_of_shortes_city if connect_max_nbr_of_shortes_city < 1: - warnings.warn("min inter_connect_max_nbr_of_shortes_city requried to be > 1!") - inter_connect_max_nbr_of_shortes_city = 1 + warnings.warn("min intern_connect_max_nbr_of_shortes_city requried to be > 1!") + intern_connect_max_nbr_of_shortes_city = 1 if print_out_info: - print("inter_connect_max_nbr_of_shortes_city:", inter_connect_max_nbr_of_shortes_city) + print("intern_connect_max_nbr_of_shortes_city:", intern_connect_max_nbr_of_shortes_city) agent_start_targets_nodes = [] @@ -395,10 +392,10 @@ def realistic_rail_generator(num_cities=5, if True: if do_random_connect_stations: connect_random_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, - inter_connect_max_nbr_of_shortes_city) + intern_connect_max_nbr_of_shortes_city) else: connect_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, - inter_connect_max_nbr_of_shortes_city) + intern_connect_max_nbr_of_shortes_city) # ---------------------------------------------------------------------------------- # fix all transition at starting / ending points (mostly add a dead end, if missing) @@ -459,13 +456,13 @@ if os.path.exists("./../render_output/"): height=40 + np.random.choice(100), rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10), city_size=10 + np.random.choice(10), - allowed_rotation_angles=np.arange(-180, 180, 15), + allowed_rotation_angles=[0], max_number_of_station_tracks=np.random.choice(4) + 4, nbr_of_switches_per_station_track=np.random.choice(4) + 2, connect_max_nbr_of_shortes_city=2, - do_random_connect_stations=np.random.choice(1) == 0, + do_random_connect_stations=False, # Number of cities in map - seed=int(time.time()), # Random seed + seed=0*int(time.time()), # Random seed print_out_info=True ), schedule_generator=sparse_schedule_generator(), diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index 5bec1ce454df9e6c41ad8c48fae243116f4ff222..c04d71dd57c8ea90d3beb592c4c83329d22e7b80 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -9,14 +9,14 @@ from flatland.core.transition_map import GridTransitionMap class AStarNode: """A node class for A* Pathfinding""" - def __init__(self, parent=None, pos=None): - self.parent = parent - self.pos = pos - self.g = 0 - self.h = 0 - self.f = 0 - - def __eq__(self, other): + def __init__(self, parent: IntVector2D = None, pos: IntVector2D = None): + self.parent: IntVector2D = parent + self.pos: IntVector2D = pos + self.g: float = 0.0 + self.h: float = 0.0 + self.f: float = 0.0 + + def __eq__(self, other: IntVector2D): return self.pos == other.pos def __hash__(self): @@ -95,7 +95,7 @@ def a_star(rail_trans: RailEnvTransitions, continue # create the f, g, and h values - child.g = current_node.g + 1 + child.g = current_node.g + 1.0 # this heuristic avoids diagonal paths child.h = Vec2d.get_manhattan_distance(child.pos, end_node.pos) child.f = child.g + child.h