From 44f1381454be45a5a57ca70d9fb991423d6fc9bb Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Thu, 26 Sep 2019 09:32:19 +0200
Subject: [PATCH] make set_env non private

---
 ...ustom_observation_example_03_ObservePredictions.py |  6 +++---
 flatland/core/env_observation_builder.py              |  6 ++----
 flatland/core/env_prediction_builder.py               |  3 ++-
 flatland/envs/observations.py                         | 11 +++++------
 flatland/envs/rail_env.py                             |  2 +-
 5 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index 81ab13e9..2ed47a5f 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -99,10 +99,10 @@ class ObservePredictions(ObservationBuilder):
 
         return observation
 
-    def _set_env(self, env: Environment):
-        self.env = env
+    def set_env(self, env: Environment):
+        super().set_env(env)
         if self.predictor:
-            self.predictor._set_env(self.env)
+            self.predictor.set_env(self.env)
 
 
 def main(args):
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index daae9b7c..2d4df089 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -25,8 +25,9 @@ class ObservationBuilder:
 
     def __init__(self):
         self.observation_space = ()
+        self.env = None
 
-    def _set_env(self, env: Environment):
+    def set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
@@ -91,9 +92,6 @@ class DummyObservationBuilder(ObservationBuilder):
     def __init__(self):
         super().__init__()
 
-    def _set_env(self, env: Environment):
-        self.env = env
-
     def reset(self):
         pass
 
diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index e591d6cf..ba839c7c 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -19,8 +19,9 @@ class PredictionBuilder:
 
     def __init__(self, max_depth: int = 20):
         self.max_depth = max_depth
+        self.env = None
 
-    def _set_env(self, env: Environment):
+    def set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 6312cca3..de9ee2a4 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -491,10 +491,10 @@ class TreeObsForRailEnv(ObservationBuilder):
                 unfolded[label] = observation_tree
         return unfolded
 
-    def _set_env(self, env: Environment):
-        self.env = env
+    def set_env(self, env: Environment):
+        super().set_env(env)
         if self.predictor:
-            self.predictor._set_env(self.env)
+            self.predictor.set_env(self.env)
 
     def _reverse_dir(self, direction):
         return int((direction + 2) % 4)
@@ -522,9 +522,8 @@ class GlobalObsForRailEnv(ObservationBuilder):
         self.observation_space = ()
         super(GlobalObsForRailEnv, self).__init__()
 
-    def _set_env(self, env: Environment):
-        super()._set_env(env)
-
+    def set_env(self, env: Environment):
+        super().set_env(env)
         self.observation_space = [4, self.env.height, self.env.width]
 
     def reset(self):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 1805e8c6..d0add308 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -161,7 +161,7 @@ class RailEnv(Environment):
         self.rewards = [0] * number_of_agents
         self.done = False
         self.obs_builder = obs_builder_object
-        self.obs_builder._set_env(self)
+        self.obs_builder.set_env(self)
 
         self._max_episode_steps = max_episode_steps
         self._elapsed_steps = 0
-- 
GitLab