From 82c81129e9885db9c698b03bf010eccc42e54cab Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Wed, 25 Sep 2019 13:55:40 +0200
Subject: [PATCH] directly inherit from ObservationBuilder and not from
 TreeObsForRailEnv when it is not necessary

---
 ...ion_example_02_SingleAgentNavigationObs.py | 14 +++++-------
 ...servation_example_03_ObservePredictions.py | 22 +++++++++----------
 examples/debugging_example_DELETE.py          | 14 +++++-------
 tests/test_flatland_malfunction.py            | 16 +++++---------
 4 files changed, 26 insertions(+), 40 deletions(-)

diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
index 29c0437e..e9c2a84e 100644
--- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py
+++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
@@ -6,8 +6,8 @@ from typing import List
 
 import numpy as np
 
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator
@@ -18,24 +18,20 @@ random.seed(100)
 np.random.seed(100)
 
 
-class SingleAgentNavigationObs(TreeObsForRailEnv):
+class SingleAgentNavigationObs(ObservationBuilder):
     """
-    We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
+    We build a representation vector with 3 binary components, indicating which of the 3 available directions
     for each agent (Left, Forward, Right) lead to the shortest path to its target.
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
 
     def __init__(self):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [3]
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index d7c44753..81ab13e9 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -6,8 +6,9 @@ from typing import Optional, List, Dict
 
 import numpy as np
 
+from flatland.core.env import Environment
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid_utils import coordinate_to_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -20,26 +21,18 @@ random.seed(100)
 np.random.seed(100)
 
 
-class ObservePredictions(TreeObsForRailEnv):
+class ObservePredictions(ObservationBuilder):
     """
     We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation.
-
-    We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    This is necessary so that we can pass the distance map to the ShortestPathPredictor
-
-    Here we also want to highlight how you can visualize your observation
     """
 
     def __init__(self, predictor):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [10]
         self.predictor = predictor
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
         '''
@@ -106,6 +99,11 @@ class ObservePredictions(TreeObsForRailEnv):
 
         return observation
 
+    def _set_env(self, env: Environment):
+        self.env = env
+        if self.predictor:
+            self.predictor._set_env(self.env)
+
 
 def main(args):
     try:
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 9f2ee252..7cb7d962 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -4,8 +4,8 @@ from typing import List
 
 import numpy as np
 
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator
@@ -15,24 +15,20 @@ random.seed(1)
 np.random.seed(1)
 
 
-class SingleAgentNavigationObs(TreeObsForRailEnv):
+class SingleAgentNavigationObs(ObservationBuilder):
     """
-    We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
+    We build a representation vector with 3 binary components, indicating which of the 3 available directions
     for each agent (Left, Forward, Right) lead to the shortest path to its target.
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
 
     def __init__(self):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [3]
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 8f7fe868..5104b06e 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -2,34 +2,30 @@ import random
 from typing import Dict, List
 
 import numpy as np
+from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
-from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
 
-class SingleAgentNavigationObs(TreeObsForRailEnv):
+class SingleAgentNavigationObs(ObservationBuilder):
     """
-    We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
+    We build a representation vector with 3 binary components, indicating which of the 3 available directions
     for each agent (Left, Forward, Right) lead to the shortest path to its target.
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
 
     def __init__(self):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [3]
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
-- 
GitLab