Skip to content
Snippets Groups Projects
Commit 1bba6852 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

renamed ShortestPathElement -> WalkingElement (naming fits better)

parent cf66525c
No related branches found
No related tags found
No related merge requests found
...@@ -14,7 +14,7 @@ from flatland.envs.rail_generators import rail_from_file ...@@ -14,7 +14,7 @@ from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file from flatland.envs.schedule_generators import schedule_from_file
from flatland.utils.ordered_set import OrderedSet from flatland.utils.ordered_set import OrderedSet
ShortestPathElement = \ WalkingElement = \
NamedTuple('Path_Element', NamedTuple('Path_Element',
[('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)]) [('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)])
...@@ -87,7 +87,7 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, ...@@ -87,7 +87,7 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
return valid_actions return valid_actions
def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPathElement]]: def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[WalkingElement]]:
# TODO: do we need to support unreachable targets? # TODO: do we need to support unreachable targets?
# TODO refactoring: unify with predictor (support agent.moving and max_depth) # TODO refactoring: unify with predictor (support agent.moving and max_depth)
shortest_paths = dict() shortest_paths = dict()
...@@ -106,13 +106,13 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPath ...@@ -106,13 +106,13 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPath
best_next_action = next_action best_next_action = next_action
distance = next_action_distance distance = next_action_distance
shortest_paths[a.handle].append(ShortestPathElement(position, direction, best_next_action)) shortest_paths[a.handle].append(WalkingElement(position, direction, best_next_action))
position = best_next_action.next_position position = best_next_action.next_position
direction = best_next_action.next_direction direction = best_next_action.next_direction
shortest_paths[a.handle].append( shortest_paths[a.handle].append(
ShortestPathElement(position, direction, WalkingElement(position, direction,
RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction))) RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction)))
return shortest_paths return shortest_paths
...@@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum ...@@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
from flatland.envs.rail_env_utils import get_shortest_paths, ShortestPathElement from flatland.envs.rail_env_utils import get_shortest_paths, WalkingElement
from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
...@@ -145,18 +145,18 @@ def test_shortest_path_predictor(rendering=False): ...@@ -145,18 +145,18 @@ def test_shortest_path_predictor(rendering=False):
paths = get_shortest_paths(env.distance_map)[0] paths = get_shortest_paths(env.distance_map)[0]
assert paths == [ assert paths == [
ShortestPathElement((5, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), WalkingElement((5, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6),
next_direction=0)), next_direction=0)),
ShortestPathElement((4, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), WalkingElement((4, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6),
next_direction=0)), next_direction=0)),
ShortestPathElement((3, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7), WalkingElement((3, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7),
next_direction=1)), next_direction=1)),
ShortestPathElement((3, 7), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8), WalkingElement((3, 7), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8),
next_direction=1)), next_direction=1)),
ShortestPathElement((3, 8), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9), WalkingElement((3, 8), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9),
next_direction=1)), next_direction=1)),
ShortestPathElement((3, 9), 1, RailEnvNextAction(action=RailEnvActions.STOP_MOVING, next_position=(3, 9), WalkingElement((3, 9), 1, RailEnvNextAction(action=RailEnvActions.STOP_MOVING, next_position=(3, 9),
next_direction=1))] next_direction=1))]
# extract the data # extract the data
predictions = env.obs_builder.predictions predictions = env.obs_builder.predictions
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment