Skip to content
Snippets Groups Projects
Commit 44f13814 authored by u229589's avatar u229589
Browse files

make set_env non private

parent 82c81129
No related branches found
No related tags found
No related merge requests found
......@@ -99,10 +99,10 @@ class ObservePredictions(ObservationBuilder):
return observation
def _set_env(self, env: Environment):
self.env = env
def set_env(self, env: Environment):
super().set_env(env)
if self.predictor:
self.predictor._set_env(self.env)
self.predictor.set_env(self.env)
def main(args):
......
......@@ -25,8 +25,9 @@ class ObservationBuilder:
def __init__(self):
self.observation_space = ()
self.env = None
def _set_env(self, env: Environment):
def set_env(self, env: Environment):
self.env = env
def reset(self):
......@@ -91,9 +92,6 @@ class DummyObservationBuilder(ObservationBuilder):
def __init__(self):
super().__init__()
def _set_env(self, env: Environment):
self.env = env
def reset(self):
pass
......
......@@ -19,8 +19,9 @@ class PredictionBuilder:
def __init__(self, max_depth: int = 20):
self.max_depth = max_depth
self.env = None
def _set_env(self, env: Environment):
def set_env(self, env: Environment):
self.env = env
def reset(self):
......
......@@ -491,10 +491,10 @@ class TreeObsForRailEnv(ObservationBuilder):
unfolded[label] = observation_tree
return unfolded
def _set_env(self, env: Environment):
self.env = env
def set_env(self, env: Environment):
super().set_env(env)
if self.predictor:
self.predictor._set_env(self.env)
self.predictor.set_env(self.env)
def _reverse_dir(self, direction):
return int((direction + 2) % 4)
......@@ -522,9 +522,8 @@ class GlobalObsForRailEnv(ObservationBuilder):
self.observation_space = ()
super(GlobalObsForRailEnv, self).__init__()
def _set_env(self, env: Environment):
super()._set_env(env)
def set_env(self, env: Environment):
super().set_env(env)
self.observation_space = [4, self.env.height, self.env.width]
def reset(self):
......
......@@ -161,7 +161,7 @@ class RailEnv(Environment):
self.rewards = [0] * number_of_agents
self.done = False
self.obs_builder = obs_builder_object
self.obs_builder._set_env(self)
self.obs_builder.set_env(self)
self._max_episode_steps = max_episode_steps
self._elapsed_steps = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment