diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 30c0fabeae0d7f9dfd49ba05983ad90f457cd5f2..a08a5f73cee9dd0c389597cc4abb4dfaefd52cc0 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -58,7 +58,7 @@ class TreeObsForRailEnv(ObservationBuilder): def reset(self): self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents} - def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, List[int]]: + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]: """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. @@ -87,7 +87,7 @@ class TreeObsForRailEnv(ObservationBuilder): observations[h] = self.get(h) return observations - def get(self, handle: int = 0) -> List[int]: + def get(self, handle: int = 0) -> Node: """ Computes the current observation for agent `handle` in env