From d1971b42248627499af56b7c7397b4962d59073d Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 24 Sep 2019 19:52:43 +0200 Subject: [PATCH] get shortestpath return type changed --- flatland/envs/rail_env_utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 8e3ed5b4..02d1a3af 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -1,5 +1,5 @@ import math -from typing import Tuple, Set, Dict, List +from typing import Tuple, Set, Dict, List, NamedTuple import numpy as np @@ -14,6 +14,10 @@ from flatland.envs.rail_generators import rail_from_file from flatland.envs.schedule_generators import schedule_from_file from flatland.utils.ordered_set import OrderedSet +ShortestPathElement = \ + NamedTuple('Path_Element', + [('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)]) + def load_flatland_environment_from_file(file_name, load_from_package=None, obs_builder_object=None): if obs_builder_object is None: @@ -83,7 +87,7 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, return valid_actions -def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextAction]]: +def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPathElement]]: # TODO: do we need to support unreachable targets? # TODO refactoring: unify with predictor (support agent.moving and max_depth) shortest_paths = dict() @@ -97,12 +101,15 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextA best_next_action = None for next_action in next_actions: - next_action_distance = distance_map.get()[a.handle, next_action.next_position[0], next_action.next_position[1], next_action.next_direction] + next_action_distance = distance_map.get()[ + a.handle, next_action.next_position[0], next_action.next_position[1], next_action.next_direction] if next_action_distance < distance: best_next_action = next_action distance = next_action_distance + + shortest_paths[a.handle].append(ShortestPathElement(position, direction, best_next_action)) + position = best_next_action.next_position direction = best_next_action.next_direction - shortest_paths[a.handle].append(best_next_action) return shortest_paths -- GitLab