diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 3cb29a3f2fac6755b59f0b12fe2906e3ea2efa60..712d24258e042f0c81eb0afdc33ff880387cac95 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -2,7 +2,7 @@ Collection of environment-specific ObservationBuilder. """ import pprint -from typing import Optional, List, Dict +from typing import Optional, List, Dict, T, Tuple import numpy as np @@ -640,7 +640,7 @@ class LocalObsForRailEnv(ObservationBuilder): direction = np.identity(4)[agent.direction] return local_rail_obs, obs_map_state, obs_other_agents_state, direction - def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, tuple]: + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]: """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list.