From 960361f1434461bfc3e5f443157d74ed2673ee1f Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Tue, 4 Jun 2019 11:38:20 +0200
Subject: [PATCH] 25 predictor draft

---
 flatland/core/env_prediction_builder.py |  11 ++-
 flatland/envs/predictions.py            |  61 ++++++++++++--
 flatland/envs/rail_env.py               |  88 ++++++++++----------
 tests/test_env_prediction_builder.py    | 102 +++++++++++++++++++++++-
 4 files changed, 201 insertions(+), 61 deletions(-)

diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index 321ede95..9f5e4dc5 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -15,8 +15,8 @@ class PredictionBuilder:
     PredictionBuilder base class.
     """
 
-    def __init__(self):
-        pass
+    def __init__(self, max_depth: int = 20):
+        self.max_depth = max_depth  # prediction horizon: number of time steps per prediction (current step included)
 
     def _set_env(self, env):
         self.env = env
@@ -25,12 +25,11 @@ class PredictionBuilder:
         """
         Called after each environment reset.
         """
-        raise NotImplementedError()
+        pass
 
     def get(self, handle=0):
         """
-        Called whenever an observation has to be computed for the `env' environment, possibly
-        for each agent independently (agent id `handle').
+        Called whenever a prediction is requested from the environment (see `predict` on the environment).
 
         Parameters
         -------
@@ -40,6 +39,6 @@ class PredictionBuilder:
         Returns
         -------
         function
-            An prediction structure, specific to the corresponding environment.
+            A prediction structure, specific to the corresponding environment.
         """
         raise NotImplementedError()
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index e57bf753..95c1a984 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -2,6 +2,8 @@
 Collection of environment-specific PredictionBuilder.
 """
 
+import numpy as np
+
 from flatland.core.env_prediction_builder import PredictionBuilder
 
 
@@ -13,11 +15,58 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def __init__(self):
-        pass
+    def get(self, handle=None):
+        """
+        Predict each agent's path for max_depth steps, preferring forward, then left, then right.
+
+        Parameters
+        -------
+        handle : int (optional)
+            Handle of the agent to predict for; None (the default) predicts for every agent.
+
+        Returns
+        -------
+        dict
+            Dictionary indexed by agent handle; per agent a (max_depth, 5) array with rows of:
+            - time_offset
+            - position axis 0
+            - position axis 1
+            - direction
+            - action taken to come here
+            Rows after the point where no action is possible are left as zeros.
+        """
+        # Test against None explicitly: handle 0 is a valid agent index but is falsy.
+        agents = self.env.agents if handle is None else [self.env.agents[handle]]
+        prediction_dict = {}
 
-    def reset(self):
-        pass
+        for agent in agents:
+            # 0: do nothing
+            # 1: turn left and move to the next cell
+            # 2: move to the next cell in front of the agent
+            # 3: turn right and move to the next cell
+            action_priorities = [2, 1, 3]
+            _agent_initial_position = agent.position
+            _agent_initial_direction = agent.direction
+            prediction = np.zeros(shape=(self.max_depth, 5))
+            prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
+            try:
+                for index in range(1, self.max_depth):
+                    for action in action_priorities:
+                        _, new_cell_isValid, new_direction, new_position, transition_isValid = \
+                            self.env._check_action_on_agent(action, agent)
+                        if new_cell_isValid and transition_isValid:
+                            # simulate the move on the real agent so the next step starts from here
+                            agent.position = new_position
+                            agent.direction = new_direction
+                            prediction[index] = [index, new_position[0], new_position[1], new_direction, action]
+                            break
+                    else:
+                        # nothing changes once stuck, so later steps would fail identically
+                        print("Cannot move further.")
+                        break
+            finally:  # always undo the simulated moves, even if the environment check raised
+                agent.position = _agent_initial_position
+                agent.direction = _agent_initial_direction
+            prediction_dict[agent.handle] = prediction
 
-    def get(self, handle=0):
-        return {}
+        return prediction_dict
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 6ac31878..190a14ea 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -219,51 +219,9 @@ class RailEnv(Environment):
                 return
 
             if action > 0:
-                # pos = agent.position #  self.agents_position[i]
-                # direction = agent.direction # self.agents_direction[i]
-
-                # compute number of possible transitions in the current
-                # cell used to check for invalid actions
-
-                new_direction, transition_isValid = self.check_action(agent, action)
-
-                new_position = get_new_position(agent.position, new_direction)
-                # Is it a legal move?
-                # 1) transition allows the new_direction in the cell,
-                # 2) the new cell is not empty (case 0),
-                # 3) the cell is free, i.e., no agent is currently in that cell
-
-                # if (
-                #        new_position[1] >= self.width or
-                #        new_position[0] >= self.height or
-                #        new_position[0] < 0 or new_position[1] < 0):
-                #    new_cell_isValid = False
-
-                # if self.rail.get_transitions(new_position) == 0:
-                #     new_cell_isValid = False
-
-                new_cell_isValid = (
-                    np.array_equal(  # Check the new position is still in the grid
-                        new_position,
-                        np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
-                    and  # check the new position has some transitions (ie is not an empty cell)
-                    self.rail.get_transitions(new_position) > 0)
-
-                # If transition validity hasn't been checked yet.
-                if transition_isValid is None:
-                    transition_isValid = self.rail.get_transition(
-                        (*agent.position, agent.direction),
-                        new_direction)
-
-                # cell_isFree = True
-                # for j in range(self.number_of_agents):
-                #    if self.agents_position[j] == new_position:
-                #        cell_isFree = False
-                #        break
-                # Check the new position is not the same as any of the existing agent positions
-                # (including itself, for simplicity, since it is moving)
-                cell_isFree = not np.any(
-                    np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+                # the helper derives transition_isValid itself, so only action and agent are passed
+                cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
+                    self._check_action_on_agent(action, agent)
 
                 if all([new_cell_isValid, transition_isValid, cell_isFree]):
                     # move and change direction to face the new_direction that was
@@ -303,6 +261,46 @@ class RailEnv(Environment):
         self.actions = [0] * self.get_num_agents()
         return self._get_observations(), self.rewards_dict, self.dones, {}
 
+    def _check_action_on_agent(self, action, agent):
+        """
+        Evaluate what `action` would do to `agent`; this method itself does not move the agent.
+
+        Parameters
+        -------
+        action : int
+            Action to evaluate; it is passed on to `self.check_action`.
+        agent
+            Agent whose `position` and `direction` are used as the starting state.
+
+        Returns
+        -------
+        tuple
+            (cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid):
+            - cell_isFree: no agent (the moving one included) currently occupies new_position
+            - new_cell_isValid: new_position lies inside the grid and has some transitions
+            - new_direction, new_position: direction and cell resulting from the action
+            - transition_isValid: the transition allows new_direction in the current cell
+        """
+        new_direction, transition_isValid = self.check_action(agent, action)
+        new_position = get_new_position(agent.position, new_direction)
+
+        # Clipping to the grid bounds leaves an in-grid position unchanged.
+        clipped_position = np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])
+        # The transitions lookup only runs for in-grid positions (short-circuit `and`);
+        # a cell without any transitions is an empty cell.
+        new_cell_isValid = (np.array_equal(new_position, clipped_position)
+                            and self.rail.get_transitions(new_position) > 0)
+
+        # None means check_action has not checked the transition validity yet.
+        if transition_isValid is None:
+            transition_isValid = self.rail.get_transition((*agent.position, agent.direction), new_direction)
+
+        # Compare against every agent's position, the moving agent's own included.
+        occupied = np.equal(new_position, [other.position for other in self.agents]).all(1)
+        cell_isFree = not np.any(occupied)
+
+        return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
+
     def predict(self):
         if not self.prediction_builder:
             return {}
diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py
index a43f7883..35a6a27b 100644
--- a/tests/test_env_prediction_builder.py
+++ b/tests/test_env_prediction_builder.py
@@ -3,13 +3,13 @@
 
 import numpy as np
 
-from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
+from flatland.envs.generators import rail_from_GridTransitionMap_generator
+from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
-from flatland.envs.generators import rail_from_GridTransitionMap_generator
 
-"""Tests for `flatland` package."""
+"""Test predictions for `flatland` package."""
 
 
 def test_predictions():
@@ -65,12 +65,106 @@ def test_predictions():
                   rail_generator=rail_from_GridTransitionMap_generator(rail),
                   number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv(),
-                  prediction_builder_object=DummyPredictorForRailEnv()
+                  prediction_builder_object=DummyPredictorForRailEnv(max_depth=20)
                   )
 
     env.reset()
 
+    # set initial position and direction for testing...
+    env.agents[0].position = (5, 6)
+    env.agents[0].direction = 0
+
     predictions = env.predict()
+    positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
+    directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
+    time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
+    actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
+
+    # compare against expected values
+    expected_positions = np.array([[5., 6.],
+                                   [4., 6.],
+                                   [3., 6.],
+                                   [3., 5.],
+                                   [3., 4.],
+                                   [3., 3.],
+                                   [3., 2.],
+                                   [3., 1.],
+                                   [3., 0.],
+                                   [3., 1.],
+                                   [3., 2.],
+                                   [3., 3.],
+                                   [3., 4.],
+                                   [3., 5.],
+                                   [3., 6.],
+                                   [3., 7.],
+                                   [3., 8.],
+                                   [3., 9.],
+                                   [3., 8.],
+                                   [3., 7.]])
+    expected_directions = np.array([[0.],
+                                    [0.],
+                                    [0.],
+                                    [3.],
+                                    [3.],
+                                    [3.],
+                                    [3.],
+                                    [3.],
+                                    [3.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [3.],
+                                    [3.]])
+    expected_time_offsets = np.array([[0.],
+                                      [1.],
+                                      [2.],
+                                      [3.],
+                                      [4.],
+                                      [5.],
+                                      [6.],
+                                      [7.],
+                                      [8.],
+                                      [9.],
+                                      [10.],
+                                      [11.],
+                                      [12.],
+                                      [13.],
+                                      [14.],
+                                      [15.],
+                                      [16.],
+                                      [17.],
+                                      [18.],
+                                      [19.]])
+    expected_actions = np.array([[0.],
+                                 [2.],
+                                 [2.],
+                                 [1.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.]])
+    assert np.array_equal(positions, expected_positions)
+    assert np.array_equal(directions, expected_directions)
+    assert np.array_equal(time_offsets, expected_time_offsets)
+    assert np.array_equal(actions, expected_actions)
 
 
 def main():
-- 
GitLab