Skip to content
Snippets Groups Projects
Commit c2fa69ff authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

refactoring and clean up

parent a0ca4034
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ import warnings ...@@ -5,7 +5,7 @@ import warnings
import numpy as np import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d, IntVector2DArrayType
from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail
...@@ -15,7 +15,6 @@ from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct ...@@ -15,7 +15,6 @@ from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import AgentRenderVariant, RenderTool from flatland.utils.rendertools import AgentRenderVariant, RenderTool
IntVector2DArrayType = []
FloatArrayType = [] FloatArrayType = []
......
from flatland.core.grid.grid4_utils import validate_new_transition from flatland.core.grid.grid4_utils import validate_new_transition
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.grid.grid_utils import IntVector2DArrayType
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
class AStarNode(): class AStarNode:
"""A node class for A* Pathfinding""" """A node class for A* Pathfinding"""
def __init__(self, parent=None, pos=None): def __init__(self, parent=None, pos=None):
...@@ -25,12 +29,13 @@ class AStarNode(): ...@@ -25,12 +29,13 @@ class AStarNode():
self.f = other.f self.f = other.f
def a_star(rail_trans, rail_array, start, end): def a_star(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D) -> \
(IntVector2DArrayType):
""" """
Returns a list of tuples as a path from the given start to end. Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end. If no path is found, returns path to closest point to end.
""" """
rail_shape = rail_array.shape rail_shape = grid_map.grid.shape
start_node = AStarNode(None, start) start_node = AStarNode(None, start)
end_node = AStarNode(None, end) end_node = AStarNode(None, end)
open_nodes = set() open_nodes = set()
...@@ -73,7 +78,8 @@ def a_star(rail_trans, rail_array, start, end): ...@@ -73,7 +78,8 @@ def a_star(rail_trans, rail_array, start, end):
continue continue
# validate positions # validate positions
if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos): if not validate_new_transition(rail_trans, grid_map.grid, prev_pos, current_node.pos, node_pos,
end_node.pos):
continue continue
# create new node # create new node
......
...@@ -5,6 +5,8 @@ import numpy as np ...@@ -5,6 +5,8 @@ import numpy as np
Vector2D = Tuple[float, float] Vector2D = Tuple[float, float]
IntVector2D = Tuple[int, int] IntVector2D = Tuple[int, int]
IntVector2DArrayType = []
class Vec2dOperations: class Vec2dOperations:
......
...@@ -20,7 +20,7 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi ...@@ -20,7 +20,7 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
Creates a new path [start,end] in grid_map, based on rail_trans. Creates a new path [start,end] in grid_map, based on rail_trans.
""" """
# in the worst case we will need to do a A* search, so we might as well set that up # in the worst case we will need to do a A* search, so we might as well set that up
path = a_star(rail_trans, grid_map.grid, start, end) path = a_star(rail_trans, grid_map, start, end)
if len(path) < 2: if len(path) < 2:
return [] return []
current_dir = get_direction(path[0], path[1]) current_dir = get_direction(path[0], path[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment