diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py index 70a2515b789031bf7aec4eeb4a5e9fc495a045bf..7618720fcdc798e5801b9ddb11d50214a35b2e19 100644 --- a/examples/custom_observation_example_01_SimpleObs.py +++ b/examples/custom_observation_example_01_SimpleObs.py @@ -17,12 +17,13 @@ class SimpleObs(ObservationBuilder): """ def __init__(self): + super().__init__() self.observation_space = [5] def reset(self): return - def get(self, handle): + def get(self, handle: int = 0): observation = handle * np.ones((5,)) return observation diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index 317372da390693dd51b53d411c4d5615582183b0..4977d1f48d55ad0d2659ddcc02329a7f9ad0d47b 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -35,7 +35,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle): + def get(self, handle: int = 0): agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 9238a2af4137e37e9d79bc3c1aaade2bb987403e..7a740b19efc70c54940df31a41401fb9d0eeee3e 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -2,6 +2,7 @@ import getopt import random import sys import time +from typing import Optional, List import numpy as np @@ -39,7 +40,7 @@ class ObservePredictions(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get_many(self, handles=None): + def get_many(self, handles: Optional[List[int]] = None): ''' Because we do not want to call the predictor seperately for every agent we implement the get_many function Here we can call the predictor just ones for all the agents and use the predictions to generate our observations @@ -50,6 +51,10 @@ class ObservePredictions(TreeObsForRailEnv): self.predictions = self.predictor.get() self.predicted_pos = {} + + if handles is None: + handles = [] + for t in range(len(self.predictions[0])): pos_list = [] for a in handles: @@ -63,7 +68,7 @@ class ObservePredictions(TreeObsForRailEnv): observations[h] = self.get(h) return observations - def get(self, handle): + def get(self, handle: int = 0): ''' Lets write a simple observation which just indicates whether or not the own predicted path overlaps with other predicted paths at any time. This is useless for the task of navigation but might diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 8aef94c23311fc693c229924953164afb5fec8ab..562091633725c7f11c87912aa76e21e0bb039a09 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -33,7 +33,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle): + def get(self, handle: int = 0): agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index ba79e7fc0e46d8182951b5a8e2520d7ffc9eacb0..be97a2eddb4f527d898afdcc257d445715ff8136 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -8,8 +8,12 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and multi-agent environments. """ +from typing import Optional, List + import numpy as np +from flatland.core.env import Environment + class ObservationBuilder: """ @@ -22,7 +26,7 @@ class ObservationBuilder: def __init__(self): self.observation_space = () - def _set_env(self, env): + def _set_env(self, env: Environment): self.env = env def reset(self): @@ -31,7 +35,7 @@ class ObservationBuilder: """ raise NotImplementedError() - def get_many(self, handles=[]): + def get_many(self, handles: Optional[List[int]] = None): """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. @@ -48,11 +52,13 @@ class ObservationBuilder: `handles` as keys. """ observations = {} + if handles is None: + handles = [] for h in handles: observations[h] = self.get(h) return observations - def get(self, handle=0): + def get(self, handle: int = 0): """ Called whenever an observation has to be computed for the `env` environment, possibly for each agent independently (agent id `handle`). @@ -83,16 +89,16 @@ class DummyObservationBuilder(ObservationBuilder): """ def __init__(self): - self.observation_space = () + super().__init__() - def _set_env(self, env): + def _set_env(self, env: Environment): self.env = env def reset(self): pass - def get_many(self, handles=[]): + def get_many(self, handles: Optional[List[int]] = None): return True - def get(self, handle=0): + def get(self, handle: int = 0): return True diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 31ad16438cbf63a532b8e90bdbd9efb704e755b7..17985c44c34e5284d33a1c513a45caa93b6efba7 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -2,9 +2,11 @@ Collection of environment-specific ObservationBuilder. """ import pprint +from typing import Optional, List import numpy as np +from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_prediction_builder import PredictionBuilder from flatland.core.grid.grid4_utils import get_new_position @@ -44,7 +46,7 @@ class TreeObsForRailEnv(ObservationBuilder): def reset(self): self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents} - def get_many(self, handles=None): + def get_many(self, handles: Optional[List[int]] = None): """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. @@ -73,7 +75,7 @@ class TreeObsForRailEnv(ObservationBuilder): observations[h] = self.get(h) return observations - def get(self, handle): + def get(self, handle: int = 0): """ Computes the current observation for agent `handle` in env @@ -488,7 +490,7 @@ class TreeObsForRailEnv(ObservationBuilder): unfolded[label] = observation_tree return unfolded - def _set_env(self, env): + def _set_env(self, env: Environment): self.env = env if self.predictor: self.predictor._set_env(self.env) @@ -519,7 +521,7 @@ class GlobalObsForRailEnv(ObservationBuilder): self.observation_space = () super(GlobalObsForRailEnv, self).__init__() - def _set_env(self, env): + def _set_env(self, env: Environment): super()._set_env(env) self.observation_space = [4, self.env.height, self.env.width] @@ -532,7 +534,7 @@ class GlobalObsForRailEnv(ObservationBuilder): bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i, j] = np.array(bitlist) - def get(self, handle): + def get(self, handle: int = 0): obs_targets = np.zeros((self.env.height, self.env.width, 2)) obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) agents = self.env.agents @@ -598,7 +600,7 @@ class LocalObsForRailEnv(ObservationBuilder): bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i, j] = np.array(bitlist) - def get(self, handle): + def get(self, handle: int = 0): agents = self.env.agents agent = agents[handle] @@ -638,13 +640,15 @@ class LocalObsForRailEnv(ObservationBuilder): direction = np.identity(4)[agent.direction] return local_rail_obs, obs_map_state, obs_other_agents_state, direction - def get_many(self, handles=None): + def get_many(self, handles: Optional[List[int]] = None): """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. """ observations = {} + if handles is None: + handles = [] for h in handles: observations[h] = self.get(h) return observations diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 294ffab233458f1f3b98c18be50743ba65bd2d73..c81ef9dc82df0817f3f3fc42798392d7ffdbcf5e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -11,6 +11,7 @@ import msgpack_numpy as m import numpy as np from flatland.core.env import Environment +from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap @@ -114,7 +115,7 @@ class RailEnv(Environment): rail_generator: RailGenerator = random_rail_generator(), schedule_generator: ScheduleGenerator = random_schedule_generator(), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2), + obs_builder_object: ObservationBuilder = TreeObsForRailEnv(max_depth=2), max_episode_steps=None, stochastic_data=None ): diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index fde9df58663993ae170c4c1e3fea55637feb4282..3c0fd834d79be1539f3403a8797beec2fa375b18 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -32,7 +32,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): # Recompute the distance map, if the environment has changed. super().reset() - def get(self, handle): + def get(self, handle: int = 0): agent = self.env.agents[handle] possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)