From d1971b42248627499af56b7c7397b4962d59073d Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 24 Sep 2019 19:52:43 +0200
Subject: [PATCH] get shortestpath return type changed

---
 flatland/envs/rail_env_utils.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py
index 8e3ed5b4..02d1a3af 100644
--- a/flatland/envs/rail_env_utils.py
+++ b/flatland/envs/rail_env_utils.py
@@ -1,5 +1,5 @@
 import math
-from typing import Tuple, Set, Dict, List
+from typing import Tuple, Set, Dict, List, NamedTuple
 
 import numpy as np
 
@@ -14,6 +14,10 @@ from flatland.envs.rail_generators import rail_from_file
 from flatland.envs.schedule_generators import schedule_from_file
 from flatland.utils.ordered_set import OrderedSet
 
+ShortestPathElement = \
+    NamedTuple('Path_Element',
+               [('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)])
+
 
 def load_flatland_environment_from_file(file_name, load_from_package=None, obs_builder_object=None):
     if obs_builder_object is None:
@@ -83,7 +87,7 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
     return valid_actions
 
 
-def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextAction]]:
+def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPathElement]]:
     # TODO: do we need to support unreachable targets?
     # TODO refactoring: unify with predictor (support agent.moving and max_depth)
     shortest_paths = dict()
@@ -97,12 +101,15 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextA
 
             best_next_action = None
             for next_action in next_actions:
-                next_action_distance = distance_map.get()[a.handle, next_action.next_position[0], next_action.next_position[1], next_action.next_direction]
+                next_action_distance = distance_map.get()[
+                    a.handle, next_action.next_position[0], next_action.next_position[1], next_action.next_direction]
                 if next_action_distance < distance:
                     best_next_action = next_action
                     distance = next_action_distance
+
+            shortest_paths[a.handle].append(ShortestPathElement(position, direction, best_next_action))
+
             position = best_next_action.next_position
             direction = best_next_action.next_direction
-            shortest_paths[a.handle].append(best_next_action)
 
     return shortest_paths
-- 
GitLab