From 9b86c0b3706560af200815fa8d81bf19b70e0c7c Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Fri, 20 Sep 2019 10:39:30 +0200
Subject: [PATCH] Refactoring: add type hints for ObservationBuilder

---
 ...custom_observation_example_01_SimpleObs.py |  3 ++-
 ...ion_example_02_SingleAgentNavigationObs.py |  2 +-
 ...servation_example_03_ObservePredictions.py |  9 +++++++--
 examples/debugging_example_DELETE.py          |  2 +-
 flatland/core/env_observation_builder.py      | 20 ++++++++++++-------
 flatland/envs/observations.py                 | 18 ++++++++++-------
 flatland/envs/rail_env.py                     |  3 ++-
 tests/test_flatland_malfunction.py            |  2 +-
 8 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py
index 70a2515b..7618720f 100644
--- a/examples/custom_observation_example_01_SimpleObs.py
+++ b/examples/custom_observation_example_01_SimpleObs.py
@@ -17,12 +17,13 @@ class SimpleObs(ObservationBuilder):
     """
 
     def __init__(self):
+        super().__init__()
         self.observation_space = [5]
 
     def reset(self):
         return
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         observation = handle * np.ones((5,))
         return observation
 
diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
index 317372da..4977d1f4 100644
--- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py
+++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
@@ -35,7 +35,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         agent = self.env.agents[handle]
 
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index 9238a2af..7a740b19 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -2,6 +2,7 @@ import getopt
 import random
 import sys
 import time
+from typing import Optional, List
 
 import numpy as np
 
@@ -39,7 +40,7 @@ class ObservePredictions(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get_many(self, handles=None):
+    def get_many(self, handles: Optional[List[int]] = None):
         '''
         Because we do not want to call the predictor seperately for every agent we implement the get_many function
         Here we can call the predictor just ones for all the agents and use the predictions to generate our observations
@@ -50,6 +51,10 @@ class ObservePredictions(TreeObsForRailEnv):
         self.predictions = self.predictor.get()
 
         self.predicted_pos = {}
+
+        if handles is None:
+            handles = []
+
         for t in range(len(self.predictions[0])):
             pos_list = []
             for a in handles:
@@ -63,7 +68,7 @@ class ObservePredictions(TreeObsForRailEnv):
             observations[h] = self.get(h)
         return observations
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         '''
         Lets write a simple observation which just indicates whether or not the own predicted path
         overlaps with other predicted paths at any time. This is useless for the task of navigation but might
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 8aef94c2..56209163 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -33,7 +33,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         agent = self.env.agents[handle]
 
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index ba79e7fc..be97a2ed 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -8,8 +8,12 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and
 multi-agent environments.
 
 """
+from typing import Optional, List
+
 import numpy as np
 
+from flatland.core.env import Environment
+
 
 class ObservationBuilder:
     """
@@ -22,7 +26,7 @@ class ObservationBuilder:
     def __init__(self):
         self.observation_space = ()
 
-    def _set_env(self, env):
+    def _set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
@@ -31,7 +35,7 @@ class ObservationBuilder:
         """
         raise NotImplementedError()
 
-    def get_many(self, handles=[]):
+    def get_many(self, handles: Optional[List[int]] = None):
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
@@ -48,11 +52,13 @@ class ObservationBuilder:
             `handles` as keys.
         """
         observations = {}
+        if handles is None:
+            handles = []
         for h in handles:
             observations[h] = self.get(h)
         return observations
 
-    def get(self, handle=0):
+    def get(self, handle: int = 0):
         """
         Called whenever an observation has to be computed for the `env` environment, possibly
         for each agent independently (agent id `handle`).
@@ -83,16 +89,16 @@ class DummyObservationBuilder(ObservationBuilder):
     """
 
     def __init__(self):
-        self.observation_space = ()
+        super().__init__()
 
-    def _set_env(self, env):
+    def _set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
         pass
 
-    def get_many(self, handles=[]):
+    def get_many(self, handles: Optional[List[int]] = None):
         return True
 
-    def get(self, handle=0):
+    def get(self, handle: int = 0):
         return True
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 31ad1643..17985c44 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -2,9 +2,11 @@
 Collection of environment-specific ObservationBuilder.
 """
 import pprint
+from typing import Optional, List
 
 import numpy as np
 
+from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.env_prediction_builder import PredictionBuilder
 from flatland.core.grid.grid4_utils import get_new_position
@@ -44,7 +46,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     def reset(self):
         self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
 
-    def get_many(self, handles=None):
+    def get_many(self, handles: Optional[List[int]] = None):
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
@@ -73,7 +75,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             observations[h] = self.get(h)
         return observations
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         """
         Computes the current observation for agent `handle` in env
 
@@ -488,7 +490,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 unfolded[label] = observation_tree
         return unfolded
 
-    def _set_env(self, env):
+    def _set_env(self, env: Environment):
         self.env = env
         if self.predictor:
             self.predictor._set_env(self.env)
@@ -519,7 +521,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
         self.observation_space = ()
         super(GlobalObsForRailEnv, self).__init__()
 
-    def _set_env(self, env):
+    def _set_env(self, env: Environment):
         super()._set_env(env)
 
         self.observation_space = [4, self.env.height, self.env.width]
@@ -532,7 +534,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i, j] = np.array(bitlist)
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         obs_targets = np.zeros((self.env.height, self.env.width, 2))
         obs_agents_state = np.zeros((self.env.height, self.env.width, 4))
         agents = self.env.agents
@@ -598,7 +600,7 @@ class LocalObsForRailEnv(ObservationBuilder):
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i, j] = np.array(bitlist)
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         agents = self.env.agents
         agent = agents[handle]
 
@@ -638,13 +640,15 @@ class LocalObsForRailEnv(ObservationBuilder):
         direction = np.identity(4)[agent.direction]
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction
 
-    def get_many(self, handles=None):
+    def get_many(self, handles: Optional[List[int]] = None):
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
         """
 
         observations = {}
+        if handles is None:
+            handles = []
         for h in handles:
             observations[h] = self.get(h)
         return observations
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 294ffab2..c81ef9dc 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -11,6 +11,7 @@ import msgpack_numpy as m
 import numpy as np
 
 from flatland.core.env import Environment
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
@@ -114,7 +115,7 @@ class RailEnv(Environment):
                  rail_generator: RailGenerator = random_rail_generator(),
                  schedule_generator: ScheduleGenerator = random_schedule_generator(),
                  number_of_agents=1,
-                 obs_builder_object=TreeObsForRailEnv(max_depth=2),
+                 obs_builder_object: ObservationBuilder = TreeObsForRailEnv(max_depth=2),
                  max_episode_steps=None,
                  stochastic_data=None
                  ):
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index fde9df58..3c0fd834 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -32,7 +32,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         agent = self.env.agents[handle]
 
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
-- 
GitLab