From f6e81e1a9cb0ac7c5e774022af7f1a02d1e8c13d Mon Sep 17 00:00:00 2001
From: Christian Eichenberger <christian.markus.eichenberger@sbb.ch>
Date: Mon, 17 Jun 2019 14:33:42 +0000
Subject: [PATCH] Resolve "shortest-path-predictor"

---
 flatland/core/env_prediction_builder.py |   8 +-
 flatland/core/transitions.py            |  28 ++++
 flatland/envs/env_utils.py              |  11 +-
 flatland/envs/observations.py           |  37 +++--
 flatland/envs/predictions.py            |  97 +++++++------
 tests/test_env_observation_builder.py   |   8 --
 tests/test_env_prediction_builder.py    | 179 +++++++++++++++++++++---
 tests/test_environments.py              |   5 -
 tests/test_player.py                    |   4 -
 9 files changed, 275 insertions(+), 102 deletions(-)

diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index 060dbfc3..5ce69a81 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -13,6 +13,7 @@ case of multi-agent environments.
 class PredictionBuilder:
     """
     PredictionBuilder base class.
+
     """
 
     def __init__(self, max_depth: int = 20):
@@ -27,12 +28,15 @@ class PredictionBuilder:
         """
         pass
 
-    def get(self, handle=0):
+    def get(self, custom_args=None, handle=0):
         """
-        Called whenever predict is called on the environment.
+        Called whenever get_many in the observation builder is called.
 
         Parameters
         -------
+        custom_args: dict
+            Implementation-dependent custom arguments, see the sub-classes.
+
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index a6d1bb07..6c38a39c 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -3,6 +3,7 @@ The transitions module defines the base Transitions class and a
 derived GridTransitions class, which allows for the specification of
 possible transitions over a 2D grid.
 """
+from enum import IntEnum
 
 import numpy as np
 
@@ -129,6 +130,16 @@ class Transitions:
         """
         raise NotImplementedError()
 
+    def get_direction_enum(self) -> IntEnum:
+        raise NotImplementedError()
+
+
+class Grid4TransitionsEnum(IntEnum):
+    NORTH = 0
+    EAST = 1
+    SOUTH = 2
+    WEST = 3
+
 
 class Grid4Transitions(Transitions):
     """
@@ -323,6 +334,20 @@ class Grid4Transitions(Transitions):
         cell_transition = value
         return cell_transition
 
+    def get_direction_enum(self) -> IntEnum:
+        return Grid4TransitionsEnum
+
+
+class Grid8TransitionsEnum(IntEnum):
+    NORTH = 0
+    NORTH_EAST = 1
+    EAST = 2
+    SOUTH_EAST = 3
+    SOUTH = 4
+    SOUTH_WEST = 5
+    WEST = 6
+    NORTH_WEST = 7
+
 
 class Grid8Transitions(Transitions):
     """
@@ -504,6 +529,9 @@ class Grid8Transitions(Transitions):
 
         return cell_transition
 
+    def get_direction_enum(self) -> IntEnum:
+        return Grid8TransitionsEnum
+
 
 class RailEnvTransitions(Grid4Transitions):
     """
diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index c9595b76..ee2c2637 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -7,6 +7,8 @@ a GridTransitionMap object.
 
 import numpy as np
 
+from flatland.core.transitions import Grid4TransitionsEnum
+
 
 def get_direction(pos1, pos2):
     """
@@ -253,13 +255,14 @@ def distance_on_rail(pos1, pos2):
 
 
 def get_new_position(position, movement):
-    if movement == 0:  # NORTH
+    """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
+    if movement == Grid4TransitionsEnum.NORTH:
         return (position[0] - 1, position[1])
-    elif movement == 1:  # EAST
+    elif movement == Grid4TransitionsEnum.EAST:
         return (position[0], position[1] + 1)
-    elif movement == 2:  # SOUTH
+    elif movement == Grid4TransitionsEnum.SOUTH:
         return (position[0] + 1, position[1])
-    elif movement == 3:  # WEST
+    elif movement == Grid4TransitionsEnum.WEST:
         return (position[0], position[1] - 1)
 
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 4b0049f6..a7f91f14 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -6,6 +6,7 @@ from collections import deque
 import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.transitions import Grid4TransitionsEnum
 from flatland.envs.env_utils import coordinate_to_position
 
 
@@ -48,16 +49,19 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.agents_previous_reset = agents
 
         if compute_distance_map:
-            self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
-                                                        self.env.height,
-                                                        self.env.width,
-                                                        4))
-            self.max_dist = np.zeros(nAgents)
+            self._compute_distance_map()
 
-            self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
-
-            # Update local lookup table for all agents' target locations
-            self.location_has_target = {tuple(agent.target): 1 for agent in agents}
+    def _compute_distance_map(self):
+        agents = self.env.agents
+        nAgents = len(agents)
+        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
+                                                    self.env.height,
+                                                    self.env.width,
+                                                    4))
+        self.max_dist = np.zeros(nAgents)
+        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
+        # Update local lookup table for all agents' target locations
+        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
 
     def _distance_map_walker(self, position, target_nr):
         """
@@ -159,13 +163,13 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
         Utility function that converts a compass movement over a 2D grid to new positions (r, c).
         """
-        if movement == 0:  # NORTH
+        if movement == Grid4TransitionsEnum.NORTH:
             return (position[0] - 1, position[1])
-        elif movement == 1:  # EAST
+        elif movement == Grid4TransitionsEnum.EAST:
             return (position[0], position[1] + 1)
-        elif movement == 2:  # SOUTH
+        elif movement == Grid4TransitionsEnum.SOUTH:
             return (position[0] + 1, position[1])
-        elif movement == 3:  # WEST
+        elif movement == Grid4TransitionsEnum.WEST:
             return (position[0], position[1] - 1)
 
     def get_many(self, handles=[]):
@@ -177,7 +181,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         if self.predictor:
             self.predicted_pos = {}
             self.predicted_dir = {}
-            self.predictions = self.predictor.get(self.distance_map)
+            self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
             for t in range(len(self.predictions[0])):
                 pos_list = []
                 dir_list = []
@@ -796,8 +800,3 @@ class LocalObsForRailEnv(ObservationBuilder):
         direction = self._get_one_hot_for_agent_direction(agent)
 
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction
-
-# class LocalObsForRailEnvImproved(ObservationBuilder):
-#     """
-#     Returns a local observation around the given agent
-#     """
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index b6fe8631..3fda7378 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
 import numpy as np
 
 from flatland.core.env_prediction_builder import PredictionBuilder
+from flatland.envs.env_utils import get_new_position
 from flatland.envs.rail_env import RailEnvActions
 
 
@@ -16,24 +17,28 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, distancemap, handle=None):
+    def get(self, custom_args=None, handle=None):
         """
-        Called whenever predict is called on the environment.
+        Called whenever get_many in the observation builder is called.
 
         Parameters
         -------
+        custom_args: dict
+            Not used in this dummy implementation.
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
         Returns
         -------
-        function
-            Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
+        dict
+            Dictionary indexed by agent handle; for each agent an array of shape (max_depth + 1, 5) with columns:
             - time_offset
             - position axis 0
             - position axis 1
             - direction
             - action taken to come here
+            The prediction at 0 is the current position, direction etc.
+
         """
         agents = self.env.agents
         if handle:
@@ -46,13 +51,12 @@ class DummyPredictorForRailEnv(PredictionBuilder):
             _agent_initial_position = agent.position
             _agent_initial_direction = agent.direction
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
-            prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
+            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
             for index in range(1, self.max_depth + 1):
                 action_done = False
                 # if we're at the target, stop moving...
                 if agent.position == agent.target:
-                    prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
-                                         RailEnvActions.STOP_MOVING]
+                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
 
                     continue
                 for action in action_priorities:
@@ -63,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
                         # performed
                         agent.position = new_position
                         agent.direction = new_direction
-                        prediction[index] = [index, new_position[0], new_position[1], new_direction, action]
+                        prediction[index] = [index, *new_position, new_direction, action]
                         action_done = True
                         break
                 if not action_done:
@@ -76,90 +80,95 @@ class DummyPredictorForRailEnv(PredictionBuilder):
 
 class ShortestPathPredictorForRailEnv(PredictionBuilder):
     """
-    DummyPredictorForRailEnv object.
+    ShortestPathPredictorForRailEnv object.
 
-    This object returns predictions for agents in the RailEnv environment.
+    This object returns shortest-path predictions for agents in the RailEnv environment.
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, distancemap, handle=None):
+    def get(self, custom_args=None, handle=None):
         """
-        Called whenever predict is called on the environment.
+        Called whenever get_many in the observation builder is called.
+        Requires distance_map to extract the shortest path.
 
         Parameters
         -------
+        custom_args: dict
+            - distance_map : np.ndarray indexed by [agent handle, row, column, direction]
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
         Returns
         -------
-        function
-            Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
+        dict
+            Dictionary indexed by agent handle; for each agent an array of shape (max_depth + 1, 5) with columns:
             - time_offset
             - position axis 0
             - position axis 1
             - direction
             - action taken to come here
+            The prediction at 0 is the current position, direction etc.
         """
         agents = self.env.agents
         if handle:
             agents = [self.env.agents[handle]]
+        assert custom_args
+        distance_map = custom_args.get('distance_map')
+        assert distance_map is not None
 
         prediction_dict = {}
-        agent_idx = 0
         for agent in agents:
             _agent_initial_position = agent.position
             _agent_initial_direction = agent.direction
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
-            prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
+            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
             for index in range(1, self.max_depth + 1):
                 # if we're at the target, stop moving...
                 if agent.position == agent.target:
-                    prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
-                                         RailEnvActions.STOP_MOVING]
+                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                     continue
                 if not agent.moving:
-                    prediction[index] = [index, agent.position[0], agent.position[1], agent.direction,
-                                         RailEnvActions.STOP_MOVING]
+                    prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
                     continue
                 # Take shortest possible path
                 cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
 
+                new_position = None
+                new_direction = None
                 if np.sum(cell_transitions) == 1:
                     new_direction = np.argmax(cell_transitions)
-                    new_position = self._new_position(agent.position, new_direction)
+                    new_position = get_new_position(agent.position, new_direction)
                 elif np.sum(cell_transitions) > 1:
                     min_dist = np.inf
-                    for direct in range(4):
-                        if cell_transitions[direct] == 1:
-                            target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
+                    for direction in range(4):
+                        if cell_transitions[direction] == 1:
+                            target_dist = distance_map[agent.handle, agent.position[0], agent.position[1], direction]
                             if target_dist < min_dist:
                                 min_dist = target_dist
-                                new_direction = direct
-                    new_position = self._new_position(agent.position, new_direction)
+                                new_direction = direction
+                    new_position = get_new_position(agent.position, new_direction)
+                else:
+                    raise Exception("No transition possible {}".format(cell_transitions))
+
+                # which action to take for the transition?
+                action = None
+                for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]:
+                    _, _, _new_direction, _new_position, _ = self.env._check_action_on_agent(_action, agent)
+                    if np.array_equal(_new_position, new_position):
+                        action = _action
+                        break
+                assert action is not None
 
+                # update the agent's position and direction
                 agent.position = new_position
                 agent.direction = new_direction
-                prediction[index] = [index, new_position[0], new_position[1], new_direction, 0]
-                action_done = True
-                if not action_done:
-                    raise Exception("Cannot move further. Something is wrong")
+
+                # prediction is ready
+                prediction[index] = [index, *new_position, new_direction, action]
             prediction_dict[agent.handle] = prediction
+
+            # cleanup: reset initial position
             agent.position = _agent_initial_position
             agent.direction = _agent_initial_direction
-            agent_idx += 1
 
         return prediction_dict
-
-    def _new_position(self, position, movement):
-        """
-        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
-        """
-        if movement == 0:  # NORTH
-            return (position[0] - 1, position[1])
-        elif movement == 1:  # EAST
-            return (position[0], position[1] + 1)
-        elif movement == 2:  # SOUTH
-            return (position[0] + 1, position[1])
-        elif movement == 3:  # WEST
-            return (position[0], position[1] - 1)
diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py
index 2e86477b..ce224736 100644
--- a/tests/test_env_observation_builder.py
+++ b/tests/test_env_observation_builder.py
@@ -80,11 +80,3 @@ def test_global_obs():
     # If this assertion is wrong, it means that the observation returned
     # places the agent on an empty cell
     assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0)
-
-
-def main():
-    test_global_obs()
-
-
-if __name__ == "__main__":
-    main()
diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py
index 5f5cea35..f34829d7 100644
--- a/tests/test_env_prediction_builder.py
+++ b/tests/test_env_prediction_builder.py
@@ -4,15 +4,18 @@
 import numpy as np
 
 from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
+from flatland.core.transitions import Grid4TransitionsEnum
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import DummyPredictorForRailEnv
+from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_env import RailEnvActions
+from flatland.utils.rendertools import RenderTool
 
 """Test predictions for `flatland` package."""
 
 
-def test_predictions():
+def make_simple_rail():
     # We instantiate a very simple rail network on a 7x10 grid:
     #        |
     #        |
@@ -22,7 +25,6 @@ def test_predictions():
     #                |
     #                |
     #                |
-
     cells = [int('0000000000000000', 2),  # empty cell - Case 0
              int('1000000000100000', 2),  # Case 1 - straight
              int('1001001000100000', 2),  # Case 2 - simple switch
@@ -31,22 +33,17 @@ def test_predictions():
              int('1100110000110011', 2),  # Case 5 - double slip switch
              int('0101001000000010', 2),  # Case 6 - symmetrical switch
              int('0010000000000000', 2)]  # Case 7 - dead end
-
     transitions = Grid4Transitions([])
     empty = cells[0]
-
     dead_end_from_south = cells[7]
     dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
     dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
     dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
-
     vertical_straight = cells[1]
     horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
-
     double_switch_south_horizontal_straight = horizontal_straight + cells[6]
     double_switch_north_horizontal_straight = transitions.rotate_transition(
         double_switch_south_horizontal_straight, 180)
-
     rail_map = np.array(
         [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
         [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
@@ -56,26 +53,36 @@ def test_predictions():
          [horizontal_straight] * 2 + [dead_end_from_west]] +
         [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
         [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
-
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
     rail.grid = rail_map
+    return rail, rail_map
+
+
+def test_dummy_predictor(rendering=False):
+    rail, rail_map = make_simple_rail()
+
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_GridTransitionMap_generator(rail),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
                   )
-
     env.reset()
 
     # set initial position and direction for testing...
     env.agents[0].position = (5, 6)
     env.agents[0].direction = 0
-    env.agents[0].target = (3., 0.)
+    env.agents[0].target = (3, 0)
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.renderEnv(show=True, show_observations=False)
+        input("Continue?")
 
+    # test assertions
     predictions = env.obs_builder.predictor.get(None)
-    positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
+    positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
     directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
     time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
     actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
@@ -139,9 +146,149 @@ def test_predictions():
     assert np.array_equal(actions, expected_actions)
 
 
-def main():
-    test_predictions()
+def test_shortest_path_predictor(rendering=False):
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+    env.reset()
+
+    agent = env.agents[0]
+    agent.position = (5, 6)  # south dead-end
+    agent.direction = 0  # north
+    agent.target = (3, 9)  # east dead-end
+
+    agent.moving = True
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.renderEnv(show=True, show_observations=False)
+        input("Continue?")
+
+    agent = env.agents[0]
+    assert agent.position == (5, 6)
+    assert agent.direction == 0
+    assert agent.target == (3, 9)
+    assert agent.moving
+
+    env.obs_builder._compute_distance_map()
+
+    distance_map = env.obs_builder.distance_map
+    assert distance_map[agent.handle, agent.position[0], agent.position[
+        1], agent.direction] == 5.0, "found {} instead of {}".format(
+        distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
+
+    # test assertions
+    env.obs_builder.get_many()
+    predictions = env.obs_builder.predictions
+    positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
+    directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
+    time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
+    actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
+
+    expected_positions = [
+        [5, 6],
+        [4, 6],
+        [3, 6],
+        [3, 7],
+        [3, 8],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+        [3, 9],
+    ]
+    expected_directions = [
+        [Grid4TransitionsEnum.NORTH],  # next is [5,6] heading north
+        [Grid4TransitionsEnum.NORTH],  # next is [4,6] heading north
+        [Grid4TransitionsEnum.NORTH],  # next is [3,6] heading north
+        [Grid4TransitionsEnum.EAST],  # next is [3,7] heading east
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+        [Grid4TransitionsEnum.EAST],
+    ]
+
+    expected_time_offsets = np.array([
+        [0.],
+        [1.],
+        [2.],
+        [3.],
+        [4.],
+        [5.],
+        [6.],
+        [7.],
+        [8.],
+        [9.],
+        [10.],
+        [11.],
+        [12.],
+        [13.],
+        [14.],
+        [15.],
+        [16.],
+        [17.],
+        [18.],
+        [19.],
+        [20.],
+    ])
 
+    expected_actions = np.array([
+        [RailEnvActions.DO_NOTHING],  # next [5,6]
+        [RailEnvActions.MOVE_FORWARD],  # next [4,6]
+        [RailEnvActions.MOVE_FORWARD],  # next [3,6]
+        [RailEnvActions.MOVE_RIGHT],  # next [3,7]
+        [RailEnvActions.MOVE_FORWARD],  # next [3,8]
+        [RailEnvActions.MOVE_FORWARD],  # next [3,9]
+        [RailEnvActions.STOP_MOVING],  # at [3,9] == target
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+        [RailEnvActions.STOP_MOVING],
+    ])
 
-if __name__ == "__main__":
-    main()
+    assert np.array_equal(positions, expected_positions), \
+        "positions {}, expected {}".format(positions, expected_positions)
+    assert np.array_equal(directions, expected_directions), \
+        "directions {}, expected {}".format(directions, expected_directions)
+    assert np.array_equal(time_offsets, expected_time_offsets), \
+        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
+    assert np.array_equal(actions, expected_actions), \
+        "actions {}, expected {}".format(actions, expected_actions)
diff --git a/tests/test_environments.py b/tests/test_environments.py
index 2131e08b..11f0acba 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -204,8 +204,3 @@ def test_dead_end():
 
     rail_env.reset()
     rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
-
-
-if __name__ == "__main__":
-    test_rail_environment_single_agent()
-    test_dead_end()
diff --git a/tests/test_player.py b/tests/test_player.py
index 21ff62c3..757fc90d 100644
--- a/tests/test_player.py
+++ b/tests/test_player.py
@@ -4,7 +4,3 @@ from examples.play_model import main
 def test_main():
     main(render=True, n_steps=20, n_trials=2, sGL="PIL")
     main(render=True, n_steps=20, n_trials=2, sGL="PILSVG")
-
-
-if __name__ == "__main__":
-    test_main()
-- 
GitLab