diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 4385d4da35202a658711127a719923f361552891..15d42d793b50dabe3627f6a4e601ca614f0e68b0 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -22,7 +22,7 @@ class TreeObsForRailEnv(ObservationBuilder): For details about the features in the tree observation see the get() function. """ - observation_dim = 9 + def __init__(self, max_depth, predictor=None): super().__init__() @@ -34,6 +34,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth + 1): size += pow4 pow4 *= 4 + self.observation_dim = 9 self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {}