From ff83fdeb11ab325a2324e605cb57d00070bc3f4e Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Mon, 23 Sep 2019 10:43:17 +0200 Subject: [PATCH] Refactoring: add return types for ObservationBuilder.get(self, handle: int = 0) and ObservationBuilder.get_many(self, handles: Optional[List[int]] = None) --- examples/custom_observation_example_01_SimpleObs.py | 2 +- ...bservation_example_02_SingleAgentNavigationObs.py | 3 ++- ...stom_observation_example_03_ObservePredictions.py | 6 +++--- examples/debugging_example_DELETE.py | 3 ++- flatland/core/env_observation_builder.py | 4 ++-- flatland/envs/observations.py | 12 ++++++------ tests/test_flatland_malfunction.py | 4 ++-- 7 files changed, 18 insertions(+), 16 deletions(-) diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py index 7618720f..705169e9 100644 --- a/examples/custom_observation_example_01_SimpleObs.py +++ b/examples/custom_observation_example_01_SimpleObs.py @@ -23,7 +23,7 @@ class SimpleObs(ObservationBuilder): def reset(self): return - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> np.ndarray: observation = handle * np.ones((5,)) return observation diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index 69ae584a..29c0437e 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -2,6 +2,7 @@ import getopt import random import sys import time +from typing import List import numpy as np @@ -36,7 +37,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> List[int]: agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 77665e28..d7c44753 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -2,7 +2,7 @@ import getopt import random import sys import time -from typing import Optional, List +from typing import Optional, List, Dict import numpy as np @@ -41,7 +41,7 @@ class ObservePredictions(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get_many(self, handles: Optional[List[int]] = None): + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]: ''' Because we do not want to call the predictor seperately for every agent we implement the get_many function Here we can call the predictor just ones for all the agents and use the predictions to generate our observations @@ -69,7 +69,7 @@ class ObservePredictions(TreeObsForRailEnv): observations[h] = self.get(h) return observations - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> np.ndarray: ''' Lets write a simple observation which just indicates whether or not the own predicted path overlaps with other predicted paths at any time. This is useless for the task of navigation but might diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 56209163..9f2ee252 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -1,5 +1,6 @@ import random import time +from typing import List import numpy as np @@ -33,7 +34,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> List[int]: agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index be97a2ed..daae9b7c 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -97,8 +97,8 @@ class DummyObservationBuilder(ObservationBuilder): def reset(self): pass - def get_many(self, handles: Optional[List[int]] = None): + def get_many(self, handles: Optional[List[int]] = None) -> bool: return True - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> bool: return True diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 17985c44..a9a9b9dc 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -2,7 +2,7 @@ Collection of environment-specific ObservationBuilder. """ import pprint -from typing import Optional, List +from typing import Optional, List, Dict import numpy as np @@ -46,7 +46,7 @@ class TreeObsForRailEnv(ObservationBuilder): def reset(self): self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents} - def get_many(self, handles: Optional[List[int]] = None): + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, List[int]]: """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. @@ -75,7 +75,7 @@ class TreeObsForRailEnv(ObservationBuilder): observations[h] = self.get(h) return observations - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> List[int]: """ Computes the current observation for agent `handle` in env @@ -534,7 +534,7 @@ class GlobalObsForRailEnv(ObservationBuilder): bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i, j] = np.array(bitlist) - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): obs_targets = np.zeros((self.env.height, self.env.width, 2)) obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) agents = self.env.agents @@ -600,7 +600,7 @@ class LocalObsForRailEnv(ObservationBuilder): bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i, j] = np.array(bitlist) - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray): agents = self.env.agents agent = agents[handle] @@ -640,7 +640,7 @@ class LocalObsForRailEnv(ObservationBuilder): direction = np.identity(4)[agent.direction] return local_rail_obs, obs_map_state, obs_other_agents_state, direction - def get_many(self, handles: Optional[List[int]] = None): + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, (np.ndarray, np.ndarray, np.ndarray, np.ndarray)]: """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 1b3c6ade..33e5bb40 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -1,5 +1,5 @@ import random -from typing import Dict +from typing import Dict, List import numpy as np @@ -31,7 +31,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle: int = 0): + def get(self, handle: int = 0) -> List[int]: agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) -- GitLab