diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py index 7618720fcdc798e5801b9ddb11d50214a35b2e19..705169e95c71137e93f92e8026f82a34d29d2182 100644 --- a/examples/custom_observation_example_01_SimpleObs.py +++ b/examples/custom_observation_example_01_SimpleObs.py @@ -23,7 +23,7 @@ class SimpleObs(ObservationBuilder): def reset(self): return - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> np.ndarray: observation = handle * np.ones((5,)) return observation diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index 69ae584ab6b0833c499e7fb95fc1425c39f7adaf..29c0437e3a22c1432a410a9b892155c3a5c7cf99 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -2,6 +2,7 @@ import getopt import random import sys import time +from typing import List import numpy as np @@ -36,7 +37,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> List[int]: agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 77665e287dc7d3f49c38ee54ed1d5b372f1faa35..d7c4475379b26fac01918bb53e37eda066d29e87 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -2,7 +2,7 @@ import getopt import random import sys import time -from typing import Optional, List +from typing import Optional, List, Dict import numpy as np @@ -41,7 +41,7 @@ class ObservePredictions(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get_many(self, handles: Optional[List[int]] = None): + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]: ''' Because we do not want to call the predictor seperately for every agent we implement the get_many function Here we can call the predictor just ones for all the agents and use the predictions to generate our observations @@ -69,7 +69,7 @@ class ObservePredictions(TreeObsForRailEnv): observations[h] = self.get(h) return observations - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> np.ndarray: ''' Lets write a simple observation which just indicates whether or not the own predicted path overlaps with other predicted paths at any time. This is useless for the task of navigation but might diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 562091633725c7f11c87912aa76e21e0bb039a09..9f2ee252012c52647c16e8f0b7e91cf2f9e93fbe 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -1,5 +1,6 @@ import random import time +from typing import List import numpy as np @@ -33,7 +34,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> List[int]: agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index be97a2eddb4f527d898afdcc257d445715ff8136..daae9b7c829f0a420cbee0d041b04aff16eb6a19 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -97,8 +97,8 @@ class DummyObservationBuilder(ObservationBuilder): def reset(self): pass - def get_many(self, handles: Optional[List[int]] = None): + def get_many(self, handles: Optional[List[int]] = None) -> bool: return True - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> bool: return True diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 17985c44c34e5284d33a1c513a45caa93b6efba7..712d24258e042f0c81eb0afdc33ff880387cac95 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -2,7 +2,7 @@ Collection of environment-specific ObservationBuilder. """ import pprint -from typing import Optional, List +from typing import Optional, List, Dict, T, Tuple import numpy as np @@ -46,7 +46,7 @@ class TreeObsForRailEnv(ObservationBuilder): def reset(self): self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents} - def get_many(self, handles: Optional[List[int]] = None): + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, List[int]]: """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. @@ -75,7 +75,7 @@ class TreeObsForRailEnv(ObservationBuilder): observations[h] = self.get(h) return observations - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> List[int]: """ Computes the current observation for agent `handle` in env @@ -534,7 +534,7 @@ class GlobalObsForRailEnv(ObservationBuilder): bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i, j] = np.array(bitlist) - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): obs_targets = np.zeros((self.env.height, self.env.width, 2)) obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) agents = self.env.agents @@ -600,7 +600,7 @@ class LocalObsForRailEnv(ObservationBuilder): bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i, j] = np.array(bitlist) - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray): agents = self.env.agents agent = agents[handle] @@ -640,7 +640,7 @@ class LocalObsForRailEnv(ObservationBuilder): direction = np.identity(4)[agent.direction] return local_rail_obs, obs_map_state, obs_other_agents_state, direction - def get_many(self, handles: Optional[List[int]] = None): + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]: """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 99c83e3b6d87470eb237afd0752231b7a378c758..8f7fe868f1f3bbfc766f0a1112aa9ddd384d6baa 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -1,5 +1,5 @@ import random -from typing import Dict +from typing import Dict, List import numpy as np @@ -31,7 +31,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> List[int]: agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)