From 311b98141d20229c4c5d1dfd11bccdc44803c59f Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Mon, 3 Jun 2019 16:08:07 +0200
Subject: [PATCH] 25 skeleton

---
 flatland/core/env.py                     | 15 +++++
 flatland/core/env_observation_builder.py |  1 -
 flatland/core/env_prediction_builder.py  | 45 +++++++++++++
 flatland/envs/predictions.py             | 23 +++++++
 flatland/envs/rail_env.py                | 20 +++++-
 tests/test_env_prediction_builder.py     | 81 ++++++++++++++++++++++++
 6 files changed, 183 insertions(+), 2 deletions(-)
 create mode 100644 flatland/core/env_prediction_builder.py
 create mode 100644 flatland/envs/predictions.py
 create mode 100644 tests/test_env_prediction_builder.py

diff --git a/flatland/core/env.py b/flatland/core/env.py
index 3c25beea..32691f50 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -84,6 +84,21 @@ class Environment:
         """
         raise NotImplementedError()
 
+    def predict(self):
+        """
+        Predictions step.
+
+        Returns predictions for the agents.
+        The returns are dicts mapping from agent_id strings to values.
+
+        Returns
+        -------
+        predictions : dict
+            New predictions for each ready agent.
+
+        """
+        raise NotImplementedError()
+
     def render(self):
         """
         Perform rendering of the environment.
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 3cef545c..f85afee4 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -19,7 +19,6 @@ class ObservationBuilder:
 
     def __init__(self):
         self.observation_space = ()
-        pass
 
     def _set_env(self, env):
         self.env = env
diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
new file mode 100644
index 00000000..321ede95
--- /dev/null
+++ b/flatland/core/env_prediction_builder.py
@@ -0,0 +1,45 @@
+"""
+PredictionBuilder objects are objects that can be passed to environments designed for customizability.
+The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]).
+If predictions are not required in every step or not for all agents, then
+
++ Reset() is called after each environment reset, to allow for pre-computing relevant data.
+
++ Get() is called whenever an step has to be computed, potentially for each agent independently in
+case of multi-agent environments.
+"""
+
+
+class PredictionBuilder:
+    """
+    PredictionBuilder base class.
+    """
+
+    def __init__(self):
+        pass
+
+    def _set_env(self, env):
+        self.env = env
+
+    def reset(self):
+        """
+        Called after each environment reset.
+        """
+        raise NotImplementedError()
+
+    def get(self, handle=0):
+        """
+        Called whenever an observation has to be computed for the `env' environment, possibly
+        for each agent independently (agent id `handle').
+
+        Parameters
+        -------
+        handle : int (optional)
+            Handle of the agent for which to compute the observation vector.
+
+        Returns
+        -------
+        function
+            An prediction structure, specific to the corresponding environment.
+        """
+        raise NotImplementedError()
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
new file mode 100644
index 00000000..e57bf753
--- /dev/null
+++ b/flatland/envs/predictions.py
@@ -0,0 +1,23 @@
+"""
+Collection of environment-specific PredictionBuilder.
+"""
+
+from flatland.core.env_prediction_builder import PredictionBuilder
+
+
+class DummyPredictorForRailEnv(PredictionBuilder):
+    """
+    DummyPredictorForRailEnv object.
+
+    This object returns predictions for agents in the RailEnv environment.
+    The prediction acts as if no other agent is in the environment and always takes the forward action.
+    """
+
+    def __init__(self):
+        pass
+
+    def reset(self):
+        pass
+
+    def get(self, handle=0):
+        return {}
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index a23b6ac7..6ac31878 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -51,7 +51,9 @@ class RailEnv(Environment):
                  height,
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
-                 obs_builder_object=TreeObsForRailEnv(max_depth=2)):
+                 obs_builder_object=TreeObsForRailEnv(max_depth=2),
+                 prediction_builder_object=None
+                 ):
         """
         Environment init.
 
@@ -94,6 +96,11 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder._set_env(self)
 
+        self.prediction_builder = prediction_builder_object
+        if self.prediction_builder:
+            self.prediction_builder._set_env(self)
+
+
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
 
@@ -296,6 +303,12 @@ class RailEnv(Environment):
         self.actions = [0] * self.get_num_agents()
         return self._get_observations(), self.rewards_dict, self.dones, {}
 
+    def predict(self):
+        if not self.prediction_builder:
+            return {}
+        return  self.prediction_builder.get()
+
+
     def check_action(self, agent, action):
         transition_isValid = None
         possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
@@ -332,6 +345,11 @@ class RailEnv(Environment):
             self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
         return self.obs_dict
 
+    def _get_predictions(self):
+        if not self.prediction_builder:
+            return {}
+        return {}
+
     def render(self):
         # TODO:
         pass
diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py
new file mode 100644
index 00000000..a43f7883
--- /dev/null
+++ b/tests/test_env_prediction_builder.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import numpy as np
+
+from flatland.envs.observations import GlobalObsForRailEnv
+from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
+from flatland.envs.predictions import DummyPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.generators import rail_from_GridTransitionMap_generator
+
+"""Tests for `flatland` package."""
+
+
+def test_predictions():
+    # We instantiate a very simple rail network on a 7x10 grid:
+    #        |
+    #        |
+    #        |
+    # _ _ _ /_\ _ _  _  _ _ _
+    #               \ /
+    #                |
+    #                |
+    #                |
+
+    cells = [int('0000000000000000', 2),  # empty cell - Case 0
+             int('1000000000100000', 2),  # Case 1 - straight
+             int('1001001000100000', 2),  # Case 2 - simple switch
+             int('1000010000100001', 2),  # Case 3 - diamond drossing
+             int('1001011000100001', 2),  # Case 4 - single slip switch
+             int('1100110000110011', 2),  # Case 5 - double slip switch
+             int('0101001000000010', 2),  # Case 6 - symmetrical switch
+             int('0010000000000000', 2)]  # Case 7 - dead end
+
+    transitions = Grid4Transitions([])
+    empty = cells[0]
+
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+
+    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
+    double_switch_north_horizontal_straight = transitions.rotate_transition(
+        double_switch_south_horizontal_straight, 180)
+
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [double_switch_north_horizontal_straight] +
+         [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=1,
+                  obs_builder_object=GlobalObsForRailEnv(),
+                  prediction_builder_object=DummyPredictorForRailEnv()
+                  )
+
+    env.reset()
+
+    predictions = env.predict()
+
+
+def main():
+    test_predictions()
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab