diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 4385d4da35202a658711127a719923f361552891..977a74a136d1f606db8d4687c47549b41d2e6185 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -22,7 +22,7 @@ class TreeObsForRailEnv(ObservationBuilder): For details about the features in the tree observation see the get() function. """ - observation_dim = 9 + def __init__(self, max_depth, predictor=None): super().__init__() @@ -43,6 +43,9 @@ class TreeObsForRailEnv(ObservationBuilder): self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] self.distance_map = None + # this needs to be updated when new features are added! + self.observation_dim = 9 + def reset(self): agents = self.env.agents nb_agents = len(agents)