Commit 594adf63 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

get shortestpath return type changed

parent d1971b42
Pipeline #2140 failed with stages
in 18 minutes and 40 seconds
......@@ -98,7 +98,6 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPath
distance = math.inf
while (position != a.target):
next_actions = get_valid_move_actions_(direction, position, distance_map.rail)
best_next_action = None
for next_action in next_actions:
next_action_distance = distance_map.get()[
......@@ -111,5 +110,9 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPath
position = best_next_action.next_position
direction = best_next_action.next_direction
if position == a.target:
shortest_paths[a.handle].append(
ShortestPathElement(position, direction,
RailEnvNextAction(RailEnvActions.DO_NOTHING, position, direction)))
return shortest_paths
return shortest_paths
......@@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
from flatland.envs.rail_env_utils import get_shortest_paths
from flatland.envs.rail_env_utils import get_shortest_paths, ShortestPathElement
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
......@@ -145,11 +145,18 @@ def test_shortest_path_predictor(rendering=False):
paths = get_shortest_paths(env.distance_map)[0]
assert paths == [
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), next_direction=0),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), next_direction=0),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9), next_direction=1)]
ShortestPathElement((5, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6),
next_direction=0)),
ShortestPathElement((4, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6),
next_direction=0)),
ShortestPathElement((3, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7),
next_direction=1)),
ShortestPathElement((3, 7), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8),
next_direction=1)),
ShortestPathElement((3, 8), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9),
next_direction=1)),
ShortestPathElement((3, 9), 1, RailEnvNextAction(action=RailEnvActions.DO_NOTHING, next_position=(3, 9),
next_direction=1))]
# extract the data
predictions = env.obs_builder.predictions
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment