From 1588496816810b7563aba21669688445efaa866c Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Mon, 17 Jun 2019 14:42:59 +0200
Subject: [PATCH] 66 shortest-path-predictor: cleanup and unit test

---
 flatland/core/env_prediction_builder.py |   8 +-
 flatland/core/transitions.py            |  28 ++++
 flatland/envs/observations.py           |  28 ++--
 flatland/envs/predictions.py            |  59 ++++----
 tests/test_env_observation_builder.py   |   8 --
 tests/test_env_prediction_builder.py    | 179 +++++++++++++++++++++---
 tests/test_environments.py              |   5 -
 tests/test_player.py                    |   4 -
 8 files changed, 244 insertions(+), 75 deletions(-)

diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index 060dbfc3..5ce69a81 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -13,6 +13,7 @@ case of multi-agent environments.
 class PredictionBuilder:
     """
     PredictionBuilder base class.
+
     """
 
     def __init__(self, max_depth: int = 20):
@@ -27,12 +28,15 @@ class PredictionBuilder:
         """
         pass
 
-    def get(self, handle=0):
+    def get(self, custom_args=None, handle=0):
         """
-        Called whenever predict is called on the environment.
+        Called whenever get_many in the observation build is called.
 
         Parameters
         -------
+        custom_args: dict
+            Implementation-dependent custom arguments, see the sub-classes.
+
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index a6d1bb07..6c38a39c 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -3,6 +3,7 @@ The transitions module defines the base Transitions class and a
 derived GridTransitions class, which allows for the specification of
 possible transitions over a 2D grid.
 """
+from enum import IntEnum
 
 import numpy as np
 
@@ -129,6 +130,16 @@ class Transitions:
         """
         raise NotImplementedError()
 
+    def get_direction_enum(self) -> IntEnum:
+        raise NotImplementedError()
+
+
+class Grid4TransitionsEnum(IntEnum):
+    NORTH = 0
+    EAST = 1
+    SOUTH = 2
+    WEST = 3
+
 
 class Grid4Transitions(Transitions):
     """
@@ -323,6 +334,20 @@ class Grid4Transitions(Transitions):
         cell_transition = value
         return cell_transition
 
+    def get_direction_enum(self) -> IntEnum:
+        return Grid4TransitionsEnum
+
+
+class Grid8TransitionsEnum(IntEnum):
+    NORTH = 0
+    NORTH_EAST = 1
+    EAST = 2
+    SOUTH_EAST = 3
+    SOUTH = 4
+    SOUTH_WEST = 5
+    WEST = 6
+    NORTH_WEST = 7
+
 
 class Grid8Transitions(Transitions):
     """
@@ -504,6 +529,9 @@ class Grid8Transitions(Transitions):
 
         return cell_transition
 
+    def get_direction_enum(self) -> IntEnum:
+        return Grid8TransitionsEnum
+
 
 class RailEnvTransitions(Grid4Transitions):
     """
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 4b0049f6..d7fdcee7 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -48,16 +48,19 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.agents_previous_reset = agents
 
         if compute_distance_map:
-            self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
-                                                        self.env.height,
-                                                        self.env.width,
-                                                        4))
-            self.max_dist = np.zeros(nAgents)
+            self._compute_distance_map()
 
-            self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
-
-            # Update local lookup table for all agents' target locations
-            self.location_has_target = {tuple(agent.target): 1 for agent in agents}
+    def _compute_distance_map(self):
+        agents = self.env.agents
+        nAgents = len(agents)
+        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
+                                                    self.env.height,
+                                                    self.env.width,
+                                                    4))
+        self.max_dist = np.zeros(nAgents)
+        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
+        # Update local lookup table for all agents' target locations
+        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
 
     def _distance_map_walker(self, position, target_nr):
         """
@@ -177,7 +180,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         if self.predictor:
             self.predicted_pos = {}
             self.predicted_dir = {}
-            self.predictions = self.predictor.get(self.distance_map)
+            self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
             for t in range(len(self.predictions[0])):
                 pos_list = []
                 dir_list = []
@@ -796,8 +799,3 @@ class LocalObsForRailEnv(ObservationBuilder):
         direction = self._get_one_hot_for_agent_direction(agent)
 
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction
-
-# class LocalObsForRailEnvImproved(ObservationBuilder):
-#     """
-#     Returns a local observation around the given agent
-#     """
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index b6fe8631..3910fa1b 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -16,24 +16,28 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, distancemap, handle=None):
+    def get(self, custom_args=None, handle=None):
         """
-        Called whenever predict is called on the environment.
+        Called whenever get_many in the observation build is called.
 
         Parameters
         -------
+        custom_args: dict
+            Not used in this dummy implementation.
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
         Returns
         -------
-        function
-            Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
+        np.array
+            Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1) x 5 elements:
             - time_offset
             - position axis 0
             - position axis 1
             - direction
             - action taken to come here
+            The prediction at 0 is the current position, direction etc.
+
         """
         agents = self.env.agents
         if handle:
@@ -46,12 +50,12 @@ class DummyPredictorForRailEnv(PredictionBuilder):
             _agent_initial_position = agent.position
             _agent_initial_direction = agent.direction
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
-            prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
+            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
             for index in range(1, self.max_depth + 1):
                 action_done = False
                 # if we're at the target, stop moving...
                 if agent.position == agent.target:
-                    prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
+                    prediction[index] = [index, *agent.target, agent.direction,
                                          RailEnvActions.STOP_MOVING]
 
                     continue
@@ -63,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
                         # performed
                         agent.position = new_position
                         agent.direction = new_direction
-                        prediction[index] = [index, new_position[0], new_position[1], new_direction, action]
+                        prediction[index] = [index, *new_position, new_direction, action]
                         action_done = True
                         break
                 if not action_done:
@@ -76,51 +80,55 @@ class DummyPredictorForRailEnv(PredictionBuilder):
 
 class ShortestPathPredictorForRailEnv(PredictionBuilder):
     """
-    DummyPredictorForRailEnv object.
+    ShortestPathPredictorForRailEnv object.
 
-    This object returns predictions for agents in the RailEnv environment.
+    This object returns shortest-path predictions for agents in the RailEnv environment.
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, distancemap, handle=None):
+    def get(self, custom_args=None, handle=None):
         """
-        Called whenever predict is called on the environment.
+        Called whenever get_many in the observation build is called.
+        Requires distance_map to extract the shortest path.
 
         Parameters
         -------
+        custom_args: dict
+            - distance_map : dict
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
         Returns
         -------
-        function
-            Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
+        np.array
+            Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1) x 5 elements:
             - time_offset
             - position axis 0
             - position axis 1
             - direction
             - action taken to come here
+            The prediction at 0 is the current position, direction etc.
         """
         agents = self.env.agents
         if handle:
             agents = [self.env.agents[handle]]
+        assert custom_args
+        distance_map = custom_args.get('distance_map')
+        assert distance_map is not None
 
         prediction_dict = {}
-        agent_idx = 0
         for agent in agents:
             _agent_initial_position = agent.position
             _agent_initial_direction = agent.direction
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
-            prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
+            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
             for index in range(1, self.max_depth + 1):
                 # if we're at the target, stop moving...
                 if agent.position == agent.target:
-                    prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
-                                         RailEnvActions.STOP_MOVING]
+                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                     continue
                 if not agent.moving:
-                    prediction[index] = [index, agent.position[0], agent.position[1], agent.direction,
-                                         RailEnvActions.STOP_MOVING]
+                    prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
                     continue
                 # Take shortest possible path
                 cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
@@ -130,24 +138,25 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                     new_position = self._new_position(agent.position, new_direction)
                 elif np.sum(cell_transitions) > 1:
                     min_dist = np.inf
-                    for direct in range(4):
-                        if cell_transitions[direct] == 1:
-                            target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
+                    for direction in range(4):
+                        if cell_transitions[direction] == 1:
+                            target_dist = distance_map[agent.handle, agent.position[0], agent.position[1], direction]
                             if target_dist < min_dist:
                                 min_dist = target_dist
-                                new_direction = direct
+                                new_direction = direction
                     new_position = self._new_position(agent.position, new_direction)
 
                 agent.position = new_position
                 agent.direction = new_direction
-                prediction[index] = [index, new_position[0], new_position[1], new_direction, 0]
+                prediction[index] = [index, *new_position, new_direction, RailEnvActions.MOVE_FORWARD]
                 action_done = True
                 if not action_done:
                     raise Exception("Cannot move further. Something is wrong")
             prediction_dict[agent.handle] = prediction
+
+            # cleanup: reset initial position
             agent.position = _agent_initial_position
             agent.direction = _agent_initial_direction
-            agent_idx += 1
 
         return prediction_dict
 
diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py
index 2e86477b..ce224736 100644
--- a/tests/test_env_observation_builder.py
+++ b/tests/test_env_observation_builder.py
@@ -80,11 +80,3 @@ def test_global_obs():
     # If this assertion is wrong, it means that the observation returned
     # places the agent on an empty cell
     assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0)
-
-
-def main():
-    test_global_obs()
-
-
-if __name__ == "__main__":
-    main()
diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py
index 5f5cea35..4d4078c3 100644
--- a/tests/test_env_prediction_builder.py
+++ b/tests/test_env_prediction_builder.py
@@ -4,15 +4,18 @@
 import numpy as np
 
 from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
+from flatland.core.transitions import Grid4TransitionsEnum
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import DummyPredictorForRailEnv
+from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_env import RailEnvActions
+from flatland.utils.rendertools import RenderTool
 
 """Test predictions for `flatland` package."""
 
 
-def test_predictions():
+def make_simple_rail():
     # We instantiate a very simple rail network on a 7x10 grid:
     #        |
     #        |
@@ -22,7 +25,6 @@ def test_predictions():
     #                |
     #                |
     #                |
-
     cells = [int('0000000000000000', 2),  # empty cell - Case 0
              int('1000000000100000', 2),  # Case 1 - straight
              int('1001001000100000', 2),  # Case 2 - simple switch
@@ -31,22 +33,17 @@ def test_predictions():
              int('1100110000110011', 2),  # Case 5 - double slip switch
              int('0101001000000010', 2),  # Case 6 - symmetrical switch
              int('0010000000000000', 2)]  # Case 7 - dead end
-
     transitions = Grid4Transitions([])
     empty = cells[0]
-
     dead_end_from_south = cells[7]
     dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
     dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
     dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
-
     vertical_straight = cells[1]
     horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
-
     double_switch_south_horizontal_straight = horizontal_straight + cells[6]
     double_switch_north_horizontal_straight = transitions.rotate_transition(
         double_switch_south_horizontal_straight, 180)
-
     rail_map = np.array(
         [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
         [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
@@ -56,26 +53,36 @@ def test_predictions():
          [horizontal_straight] * 2 + [dead_end_from_west]] +
         [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
         [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
-
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
     rail.grid = rail_map
+    return rail, rail_map
+
+
+def test_dummy_predictor(rendering=False):
+    rail, rail_map = make_simple_rail()
+
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_GridTransitionMap_generator(rail),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
                   )
-
     env.reset()
 
     # set initial position and direction for testing...
     env.agents[0].position = (5, 6)
     env.agents[0].direction = 0
-    env.agents[0].target = (3., 0.)
+    env.agents[0].target = (3, 0)
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.renderEnv(show=True, show_observations=False)
+        input("Continue?")
 
+    # test assertions
     predictions = env.obs_builder.predictor.get(None)
-    positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
+    positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
     directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
     time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
     actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
@@ -139,9 +146,149 @@ def test_predictions():
     assert np.array_equal(actions, expected_actions)
 
 
-def main():
-    test_predictions()
+def test_shortest_path_predictor(rendering=False):
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+    env.reset()
+
+    agent = env.agents[0]
+    agent.position = (5, 6)  # south dead-end
+    agent.direction = 0  # north
+    agent.target = (3, 9)  # east dead-end
+
+    agent.moving = True
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.renderEnv(show=True, show_observations=False)
+        # input("Continue?")
+
+    agent = env.agents[0]
+    assert agent.position == (5, 6)
+    assert agent.direction == 0
+    assert agent.target == (3, 9)
+    assert agent.moving
+
+    env.obs_builder._compute_distance_map()
+
+    distance_map = env.obs_builder.distance_map
+    assert distance_map[agent.handle, agent.position[0], agent.position[
+        1], agent.direction] == 5.0, "found {} instead of {}".format(
+        distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
+
+    # test assertions
+    env.obs_builder.get_many()
+    predictions = env.obs_builder.predictions
+    positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
+    directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
+    time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
+    actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
+
+    expected_positions = [
+        [5, 6],
+        [4, 6],
+        [3, 6],
+        [3, 7],
+        [3, 8],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+    ]
+    expected_directions = [
+        [Grid4TransitionsEnum.NORTH],  # next is [5,6] heading north
+        [Grid4TransitionsEnum.NORTH],  # next is [4,6] heading north
+        [Grid4TransitionsEnum.NORTH],  # next is [3,6] heading north
+        [Grid4TransitionsEnum.EAST],  # next is [3,7] heading east
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+    ]
+
+    expected_time_offsets = np.array([
+        [0.],
+        [1.],
+        [2.],
+        [3.],
+        [4.],
+        [5.],
+        [6.],
+        [7.],
+        [8.],
+        [9.],
+        [10.],
+        [11.],
+        [12.],
+        [13.],
+        [14.],
+        [15.],
+        [16.],
+        [17.],
+        [18.],
+        [19.],
+        [20.],
+    ])
 
+    expected_actions = np.array([
+        [RailEnvActions.DO_NOTHING],  # next [5,6]
+        [RailEnvActions.MOVE_FORWARD],  # next [4,6]
+        [RailEnvActions.MOVE_FORWARD],  # next [3,6]
+        [RailEnvActions.MOVE_RIGHT],  # next [3,7]
+        [RailEnvActions.MOVE_FORWARD],  # next [3,8]
+        [RailEnvActions.MOVE_FORWARD],  # next [3,9]
+        [RailEnvActions.STOP_MOVING],  # at [3,9] == target
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+    ])
 
-if __name__ == "__main__":
-    main()
+    assert np.array_equal(positions, expected_positions), \
+        "positions {}, expected {}".format(positions, expected_positions)
+    assert np.array_equal(directions, expected_directions), \
+        "directions {}, expected {}".format(directions, expected_directions)
+    assert np.array_equal(time_offsets, expected_time_offsets), \
+        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
+    assert np.array_equal(actions, expected_actions), \
+        "actions {}, expected {}".format(actions, expected_actions)
diff --git a/tests/test_environments.py b/tests/test_environments.py
index 2131e08b..11f0acba 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -204,8 +204,3 @@ def test_dead_end():
 
     rail_env.reset()
     rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
-
-
-if __name__ == "__main__":
-    test_rail_environment_single_agent()
-    test_dead_end()
diff --git a/tests/test_player.py b/tests/test_player.py
index 21ff62c3..757fc90d 100644
--- a/tests/test_player.py
+++ b/tests/test_player.py
@@ -4,7 +4,3 @@ from examples.play_model import main
 def test_main():
     main(render=True, n_steps=20, n_trials=2, sGL="PIL")
     main(render=True, n_steps=20, n_trials=2, sGL="PILSVG")
-
-
-if __name__ == "__main__":
-    test_main()
-- 
GitLab