diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 81ab13e967836ca8b631b7cab5871763f2e0fce8..2ed47a5f8c18894c8e0a2108b7bc0a8d73b783f9 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -99,10 +99,10 @@ class ObservePredictions(ObservationBuilder): return observation - def _set_env(self, env: Environment): - self.env = env + def set_env(self, env: Environment): + super().set_env(env) if self.predictor: - self.predictor._set_env(self.env) + self.predictor.set_env(self.env) def main(args): diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index daae9b7c829f0a420cbee0d041b04aff16eb6a19..2d4df089eed08ee17f3d5f89147735b1b8570a7d 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -25,8 +25,9 @@ class ObservationBuilder: def __init__(self): self.observation_space = () + self.env = None - def _set_env(self, env: Environment): + def set_env(self, env: Environment): self.env = env def reset(self): @@ -91,9 +92,6 @@ class DummyObservationBuilder(ObservationBuilder): def __init__(self): super().__init__() - def _set_env(self, env: Environment): - self.env = env - def reset(self): pass diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py index e591d6cfae141aba474eb5424446515c12df89bb..ba839c7c270c72997f4229e9690d961bee0496e3 100644 --- a/flatland/core/env_prediction_builder.py +++ b/flatland/core/env_prediction_builder.py @@ -19,8 +19,9 @@ class PredictionBuilder: def __init__(self, max_depth: int = 20): self.max_depth = max_depth + self.env = None - def _set_env(self, env: Environment): + def set_env(self, env: Environment): self.env = env def reset(self): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 6312cca3e51fa0e60d2493587ff5d699fcd355c4..de9ee2a45ebdeabec5202be4f12593a82b4e20e4 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -491,10 +491,10 @@ class TreeObsForRailEnv(ObservationBuilder): unfolded[label] = observation_tree return unfolded - def _set_env(self, env: Environment): - self.env = env + def set_env(self, env: Environment): + super().set_env(env) if self.predictor: - self.predictor._set_env(self.env) + self.predictor.set_env(self.env) def _reverse_dir(self, direction): return int((direction + 2) % 4) @@ -522,9 +522,8 @@ class GlobalObsForRailEnv(ObservationBuilder): self.observation_space = () super(GlobalObsForRailEnv, self).__init__() - def _set_env(self, env: Environment): - super()._set_env(env) - + def set_env(self, env: Environment): + super().set_env(env) self.observation_space = [4, self.env.height, self.env.width] def reset(self): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 1805e8c6b01a9fb88db082dc0e7de7909800c0b6..d0add3086014c7ad07c29e01588cba380c26cbd7 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -161,7 +161,7 @@ class RailEnv(Environment): self.rewards = [0] * number_of_agents self.done = False self.obs_builder = obs_builder_object - self.obs_builder._set_env(self) + self.obs_builder.set_env(self) self._max_episode_steps = max_episode_steps self._elapsed_steps = 0