From 2be1a6c4db0587b0bfb191fe9f94da4a31825f86 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 13 Jun 2019 08:47:43 +0200
Subject: [PATCH] fix master

---
 flatland/core/env.py                 | 15 ---------------
 flatland/envs/observations.py        |  4 ----
 flatland/envs/rail_env.py            |  2 --
 flatland/utils/editor.py             |  3 ++-
 flatland/utils/graphics_pil.py       |  2 +-
 tests/test_env_prediction_builder.py |  7 +++----
 6 files changed, 6 insertions(+), 27 deletions(-)

diff --git a/flatland/core/env.py b/flatland/core/env.py
index 3618d965..1bc5b6f3 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -84,21 +84,6 @@ class Environment:
         """
         raise NotImplementedError()
 
-    def predict(self):
-        """
-        Predictions step.
-
-        Returns predictions for the agents.
-        The returns are dicts mapping from agent_id strings to values.
-
-        Returns
-        -------
-        predictions : dict
-            New predictions for each ready agent.
-
-        """
-        raise NotImplementedError()
-
     def get_agent_handles(self):
         """
         Returns a list of agents' handles to be used as keys in the step()
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index a3d88d77..76bed8a4 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -173,10 +173,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         Called whenever an observation has to be computed for the `env' environment, for each agent with handle
         in the `handles' list.
         """
-
-        # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object.
-        if self.predictor:
-            print(self.predictor.get(0))
         observations = {}
         for h in handles:
             observations[h] = self.get(h)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 5d20a5d9..7773f86c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -292,7 +292,6 @@ class RailEnv(Environment):
             np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
         return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
 
-
     def check_action(self, agent, action):
         transition_isValid = None
         possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
@@ -324,7 +323,6 @@ class RailEnv(Environment):
         self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
         return self.obs_dict
 
-
     def get_full_state_msg(self):
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index d4e5c38e..81565d62 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -323,7 +323,8 @@ class Controller(object):
     def restartAgents(self, event):
         self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value)
         if self.model.init_agents_static is not None:
-            self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in self.model.init_agents_static]
+            self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in
+                                            self.model.init_agents_static]
             self.model.env.agents = None
             self.model.init_agents_static = None
             self.player = None
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index bca964c9..0f67421e 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -396,7 +396,7 @@ class PILSVG(PILGL):
         }
 
         # "paint" color of the train images we load - this is the color we will change.
-        # a3BaseColor = self.rgb_s2i("0091ea")
+        # a3BaseColor = self.rgb_s2i("0091ea") \#  noqa: E800
         # temporary workaround for trains / agents renamed with different colour:
         a3BaseColor = self.rgb_s2i("d50000")
 
diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py
index 35a6a27b..be065d3d 100644
--- a/tests/test_env_prediction_builder.py
+++ b/tests/test_env_prediction_builder.py
@@ -5,7 +5,7 @@ import numpy as np
 
 from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
-from flatland.envs.observations import GlobalObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 
@@ -64,8 +64,7 @@ def test_predictions():
                   height=rail_map.shape[0],
                   rail_generator=rail_from_GridTransitionMap_generator(rail),
                   number_of_agents=1,
-                  obs_builder_object=GlobalObsForRailEnv(),
-                  prediction_builder_object=DummyPredictorForRailEnv(max_depth=20)
+                  obs_builder_object=TreeObsForRailEnv(max_depth=20, predictor=DummyPredictorForRailEnv(max_depth=20)),
                   )
 
     env.reset()
@@ -74,7 +73,7 @@ def test_predictions():
     env.agents[0].position = (5, 6)
     env.agents[0].direction = 0
 
-    predictions = env.predict()
+    predictions = env.obs_builder.predictor.get()
     positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
     directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
     time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
-- 
GitLab