From 9b86c0b3706560af200815fa8d81bf19b70e0c7c Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Fri, 20 Sep 2019 10:39:30 +0200 Subject: [PATCH] Refactoring: add type hints for ObservationBuilder --- ...custom_observation_example_01_SimpleObs.py | 3 ++- ...ion_example_02_SingleAgentNavigationObs.py | 2 +- ...servation_example_03_ObservePredictions.py | 9 +++++++-- examples/debugging_example_DELETE.py | 2 +- flatland/core/env_observation_builder.py | 20 ++++++++++++------- flatland/envs/observations.py | 18 ++++++++++------- flatland/envs/rail_env.py | 3 ++- tests/test_flatland_malfunction.py | 2 +- 8 files changed, 38 insertions(+), 21 deletions(-) diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py index 70a2515b..7618720f 100644 --- a/examples/custom_observation_example_01_SimpleObs.py +++ b/examples/custom_observation_example_01_SimpleObs.py @@ -17,12 +17,13 @@ class SimpleObs(ObservationBuilder): """ def __init__(self): + super().__init__() self.observation_space = [5] def reset(self): return - def get(self, handle): + def get(self, handle: int = 0): observation = handle * np.ones((5,)) return observation diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index 317372da..4977d1f4 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -35,7 +35,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle): + def get(self, handle: int = 0): agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 9238a2af..7a740b19 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -2,6 +2,7 @@ import getopt import random import sys import time +from typing import Optional, List import numpy as np @@ -39,7 +40,7 @@ class ObservePredictions(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get_many(self, handles=None): + def get_many(self, handles: Optional[List[int]] = None): ''' Because we do not want to call the predictor seperately for every agent we implement the get_many function Here we can call the predictor just ones for all the agents and use the predictions to generate our observations @@ -50,6 +51,10 @@ class ObservePredictions(TreeObsForRailEnv): self.predictions = self.predictor.get() self.predicted_pos = {} + + if handles is None: + handles = [] + for t in range(len(self.predictions[0])): pos_list = [] for a in handles: @@ -63,7 +68,7 @@ class ObservePredictions(TreeObsForRailEnv): observations[h] = self.get(h) return observations - def get(self, handle): + def get(self, handle: int = 0): ''' Lets write a simple observation which just indicates whether or not the own predicted path overlaps with other predicted paths at any time. This is useless for the task of navigation but might diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 8aef94c2..56209163 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -33,7 +33,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle): + def get(self, handle: int = 0): agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index ba79e7fc..be97a2ed 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -8,8 +8,12 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and multi-agent environments. """ +from typing import Optional, List + import numpy as np +from flatland.core.env import Environment + class ObservationBuilder: """ @@ -22,7 +26,7 @@ class ObservationBuilder: def __init__(self): self.observation_space = () - def _set_env(self, env): + def _set_env(self, env: Environment): self.env = env def reset(self): @@ -31,7 +35,7 @@ class ObservationBuilder: """ raise NotImplementedError() - def get_many(self, handles=[]): + def get_many(self, handles: Optional[List[int]] = None): """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. @@ -48,11 +52,13 @@ class ObservationBuilder: `handles` as keys. """ observations = {} + if handles is None: + handles = [] for h in handles: observations[h] = self.get(h) return observations - def get(self, handle=0): + def get(self, handle: int = 0): """ Called whenever an observation has to be computed for the `env` environment, possibly for each agent independently (agent id `handle`). @@ -83,16 +89,16 @@ class DummyObservationBuilder(ObservationBuilder): """ def __init__(self): - self.observation_space = () + super().__init__() - def _set_env(self, env): + def _set_env(self, env: Environment): self.env = env def reset(self): pass - def get_many(self, handles=[]): + def get_many(self, handles: Optional[List[int]] = None): return True - def get(self, handle=0): + def get(self, handle: int = 0): return True diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 31ad1643..17985c44 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -2,9 +2,11 @@ Collection of environment-specific ObservationBuilder. """ import pprint +from typing import Optional, List import numpy as np +from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_prediction_builder import PredictionBuilder from flatland.core.grid.grid4_utils import get_new_position @@ -44,7 +46,7 @@ class TreeObsForRailEnv(ObservationBuilder): def reset(self): self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents} - def get_many(self, handles=None): + def get_many(self, handles: Optional[List[int]] = None): """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. @@ -73,7 +75,7 @@ class TreeObsForRailEnv(ObservationBuilder): observations[h] = self.get(h) return observations - def get(self, handle): + def get(self, handle: int = 0): """ Computes the current observation for agent `handle` in env @@ -488,7 +490,7 @@ class TreeObsForRailEnv(ObservationBuilder): unfolded[label] = observation_tree return unfolded - def _set_env(self, env): + def _set_env(self, env: Environment): self.env = env if self.predictor: self.predictor._set_env(self.env) @@ -519,7 +521,7 @@ class GlobalObsForRailEnv(ObservationBuilder): self.observation_space = () super(GlobalObsForRailEnv, self).__init__() - def _set_env(self, env): + def _set_env(self, env: Environment): super()._set_env(env) self.observation_space = [4, self.env.height, self.env.width] @@ -532,7 +534,7 @@ class GlobalObsForRailEnv(ObservationBuilder): bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i, j] = np.array(bitlist) - def get(self, handle): + def get(self, handle: int = 0): obs_targets = np.zeros((self.env.height, self.env.width, 2)) obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) agents = self.env.agents @@ -598,7 +600,7 @@ class LocalObsForRailEnv(ObservationBuilder): bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i, j] = np.array(bitlist) - def get(self, handle): + def get(self, handle: int = 0): agents = self.env.agents agent = agents[handle] @@ -638,13 +640,15 @@ class LocalObsForRailEnv(ObservationBuilder): direction = np.identity(4)[agent.direction] return local_rail_obs, obs_map_state, obs_other_agents_state, direction - def get_many(self, handles=None): + def get_many(self, handles: Optional[List[int]] = None): """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. """ observations = {} + if handles is None: + handles = [] for h in handles: observations[h] = self.get(h) return observations diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 294ffab2..c81ef9dc 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -11,6 +11,7 @@ import msgpack_numpy as m import numpy as np from flatland.core.env import Environment +from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap @@ -114,7 +115,7 @@ class RailEnv(Environment): rail_generator: RailGenerator = random_rail_generator(), schedule_generator: ScheduleGenerator = random_schedule_generator(), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2), + obs_builder_object: ObservationBuilder = TreeObsForRailEnv(max_depth=2), max_episode_steps=None, stochastic_data=None ): diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index fde9df58..3c0fd834 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -32,7 +32,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle): + def get(self, handle: int = 0): agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) -- GitLab