From 594adf639edc67806fca35ebe0f9d0454358ccfa Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 24 Sep 2019 20:00:45 +0200 Subject: [PATCH] get shortestpath return type changed --- flatland/envs/rail_env_utils.py | 7 +++++-- tests/test_flatland_envs_predictions.py | 19 +++++++++++++------ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 02d1a3af..b23347da 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -98,7 +98,6 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPath distance = math.inf while (position != a.target): next_actions = get_valid_move_actions_(direction, position, distance_map.rail) - best_next_action = None for next_action in next_actions: next_action_distance = distance_map.get()[ @@ -111,5 +110,9 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPath position = best_next_action.next_position direction = best_next_action.next_direction + if position == a.target: + shortest_paths[a.handle].append( + ShortestPathElement(position, direction, + RailEnvNextAction(RailEnvActions.DO_NOTHING, position, direction))) - return shortest_paths + return shortest_paths diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 2df60017..d891ccaa 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction -from flatland.envs.rail_env_utils import get_shortest_paths +from flatland.envs.rail_env_utils import get_shortest_paths, ShortestPathElement from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool @@ -145,11 +145,18 @@ def test_shortest_path_predictor(rendering=False): paths = get_shortest_paths(env.distance_map)[0] assert paths == [ - RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), next_direction=0), - RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), next_direction=0), - RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7), next_direction=1), - RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8), next_direction=1), - RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9), next_direction=1)] + ShortestPathElement((5, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), + next_direction=0)), + ShortestPathElement((4, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), + next_direction=0)), + ShortestPathElement((3, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7), + next_direction=1)), + ShortestPathElement((3, 7), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8), + next_direction=1)), + ShortestPathElement((3, 8), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9), + next_direction=1)), + ShortestPathElement((3, 9), 1, RailEnvNextAction(action=RailEnvActions.DO_NOTHING, next_position=(3, 9), + next_direction=1))] # extract the data predictions = env.obs_builder.predictions -- GitLab