diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index 11abf7e2da57db7b87dd20ddfa7f8d32d0806c3c..7f155cf2c72f14231993c644df29ae215bd29903 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -5,7 +5,7 @@ import warnings import numpy as np -from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d, IntVector2DArrayType from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail @@ -15,7 +15,6 @@ from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import AgentRenderVariant, RenderTool -IntVector2DArrayType = [] FloatArrayType = [] diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index feb72313f21b9ecc989688d63ba02ccf3a458107..3b4e69380a0cb3f3e69d2b6a08ab145f7a20ce04 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -1,7 +1,11 @@ from flatland.core.grid.grid4_utils import validate_new_transition +from flatland.core.grid.grid_utils import IntVector2D +from flatland.core.grid.grid_utils import IntVector2DArrayType +from flatland.core.grid.rail_env_grid import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap -class AStarNode(): +class AStarNode: """A node class for A* Pathfinding""" def __init__(self, parent=None, pos=None): @@ -25,12 +29,13 @@ class AStarNode(): self.f = other.f -def a_star(rail_trans, rail_array, start, end): +def a_star(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D) -> \ + (IntVector2DArrayType): """ Returns a list of tuples as a path from the given start to end. If no path is found, returns path to closest point to end. """ - rail_shape = rail_array.shape + rail_shape = grid_map.grid.shape start_node = AStarNode(None, start) end_node = AStarNode(None, end) open_nodes = set() @@ -73,7 +78,8 @@ def a_star(rail_trans, rail_array, start, end): continue # validate positions - if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos): + if not validate_new_transition(rail_trans, grid_map.grid, prev_pos, current_node.pos, node_pos, + end_node.pos): continue # create new node diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py index b7bf988fb754b30d4784e0feef36176fa5aa28df..ffbf79a7b99e1628d52814638c753811ad9f7282 100644 --- a/flatland/core/grid/grid_utils.py +++ b/flatland/core/grid/grid_utils.py @@ -5,6 +5,8 @@ import numpy as np Vector2D = Tuple[float, float] IntVector2D = Tuple[int, int] +IntVector2DArrayType = [] + class Vec2dOperations: diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index c1bdedebd3edef2c22ef971a538fa55c5c1482a6..0ead334f4ac8a87f53c0f10bb74855050dc89e62 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -20,7 +20,7 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi Creates a new path [start,end] in grid_map, based on rail_trans. """ # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, grid_map.grid, start, end) + path = a_star(rail_trans, grid_map, start, end) if len(path) < 2: return [] current_dir = get_direction(path[0], path[1])