Skip to content
Snippets Groups Projects
Commit 44f13814 authored by u229589's avatar u229589
Browse files

make set_env non private

parent 82c81129
No related branches found
No related tags found
No related merge requests found
...@@ -99,10 +99,10 @@ class ObservePredictions(ObservationBuilder): ...@@ -99,10 +99,10 @@ class ObservePredictions(ObservationBuilder):
return observation return observation
def _set_env(self, env: Environment): def set_env(self, env: Environment):
self.env = env super().set_env(env)
if self.predictor: if self.predictor:
self.predictor._set_env(self.env) self.predictor.set_env(self.env)
def main(args): def main(args):
......
...@@ -25,8 +25,9 @@ class ObservationBuilder: ...@@ -25,8 +25,9 @@ class ObservationBuilder:
def __init__(self): def __init__(self):
self.observation_space = () self.observation_space = ()
self.env = None
def _set_env(self, env: Environment): def set_env(self, env: Environment):
self.env = env self.env = env
def reset(self): def reset(self):
...@@ -91,9 +92,6 @@ class DummyObservationBuilder(ObservationBuilder): ...@@ -91,9 +92,6 @@ class DummyObservationBuilder(ObservationBuilder):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def _set_env(self, env: Environment):
self.env = env
def reset(self): def reset(self):
pass pass
......
...@@ -19,8 +19,9 @@ class PredictionBuilder: ...@@ -19,8 +19,9 @@ class PredictionBuilder:
def __init__(self, max_depth: int = 20): def __init__(self, max_depth: int = 20):
self.max_depth = max_depth self.max_depth = max_depth
self.env = None
def _set_env(self, env: Environment): def set_env(self, env: Environment):
self.env = env self.env = env
def reset(self): def reset(self):
......
...@@ -491,10 +491,10 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -491,10 +491,10 @@ class TreeObsForRailEnv(ObservationBuilder):
unfolded[label] = observation_tree unfolded[label] = observation_tree
return unfolded return unfolded
def _set_env(self, env: Environment): def set_env(self, env: Environment):
self.env = env super().set_env(env)
if self.predictor: if self.predictor:
self.predictor._set_env(self.env) self.predictor.set_env(self.env)
def _reverse_dir(self, direction): def _reverse_dir(self, direction):
return int((direction + 2) % 4) return int((direction + 2) % 4)
...@@ -522,9 +522,8 @@ class GlobalObsForRailEnv(ObservationBuilder): ...@@ -522,9 +522,8 @@ class GlobalObsForRailEnv(ObservationBuilder):
self.observation_space = () self.observation_space = ()
super(GlobalObsForRailEnv, self).__init__() super(GlobalObsForRailEnv, self).__init__()
def _set_env(self, env: Environment): def set_env(self, env: Environment):
super()._set_env(env) super().set_env(env)
self.observation_space = [4, self.env.height, self.env.width] self.observation_space = [4, self.env.height, self.env.width]
def reset(self): def reset(self):
......
...@@ -161,7 +161,7 @@ class RailEnv(Environment): ...@@ -161,7 +161,7 @@ class RailEnv(Environment):
self.rewards = [0] * number_of_agents self.rewards = [0] * number_of_agents
self.done = False self.done = False
self.obs_builder = obs_builder_object self.obs_builder = obs_builder_object
self.obs_builder._set_env(self) self.obs_builder.set_env(self)
self._max_episode_steps = max_episode_steps self._max_episode_steps = max_episode_steps
self._elapsed_steps = 0 self._elapsed_steps = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment