From 44f1381454be45a5a57ca70d9fb991423d6fc9bb Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Thu, 26 Sep 2019 09:32:19 +0200 Subject: [PATCH] make set_env non private --- ...ustom_observation_example_03_ObservePredictions.py | 6 +++--- flatland/core/env_observation_builder.py | 6 ++---- flatland/core/env_prediction_builder.py | 3 ++- flatland/envs/observations.py | 11 +++++------ flatland/envs/rail_env.py | 2 +- 5 files changed, 13 insertions(+), 15 deletions(-) diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 81ab13e9..2ed47a5f 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -99,10 +99,10 @@ class ObservePredictions(ObservationBuilder): return observation - def _set_env(self, env: Environment): - self.env = env + def set_env(self, env: Environment): + super().set_env(env) if self.predictor: - self.predictor._set_env(self.env) + self.predictor.set_env(self.env) def main(args): diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index daae9b7c..2d4df089 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -25,8 +25,9 @@ class ObservationBuilder: def __init__(self): self.observation_space = () + self.env = None - def _set_env(self, env: Environment): + def set_env(self, env: Environment): self.env = env def reset(self): @@ -91,9 +92,6 @@ class DummyObservationBuilder(ObservationBuilder): def __init__(self): super().__init__() - def _set_env(self, env: Environment): - self.env = env - def reset(self): pass diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py index e591d6cf..ba839c7c 100644 --- a/flatland/core/env_prediction_builder.py +++ b/flatland/core/env_prediction_builder.py @@ -19,8 +19,9 @@ class PredictionBuilder: def __init__(self, max_depth: int = 20): self.max_depth = max_depth + self.env = None - def _set_env(self, env: Environment): + def set_env(self, env: Environment): self.env = env def reset(self): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 6312cca3..de9ee2a4 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -491,10 +491,10 @@ class TreeObsForRailEnv(ObservationBuilder): unfolded[label] = observation_tree return unfolded - def _set_env(self, env: Environment): - self.env = env + def set_env(self, env: Environment): + super().set_env(env) if self.predictor: - self.predictor._set_env(self.env) + self.predictor.set_env(self.env) def _reverse_dir(self, direction): return int((direction + 2) % 4) @@ -522,9 +522,8 @@ class GlobalObsForRailEnv(ObservationBuilder): self.observation_space = () super(GlobalObsForRailEnv, self).__init__() - def _set_env(self, env: Environment): - super()._set_env(env) - + def set_env(self, env: Environment): + super().set_env(env) self.observation_space = [4, self.env.height, self.env.width] def reset(self): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 1805e8c6..d0add308 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -161,7 +161,7 @@ class RailEnv(Environment): self.rewards = [0] * number_of_agents self.done = False self.obs_builder = obs_builder_object - self.obs_builder._set_env(self) + self.obs_builder.set_env(self) self._max_episode_steps = max_episode_steps self._elapsed_steps = 0 -- GitLab