From ff83fdeb11ab325a2324e605cb57d00070bc3f4e Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Mon, 23 Sep 2019 10:43:17 +0200
Subject: [PATCH] Refactoring: add return types for
 ObservationBuilder.get(self, handle: int = 0) and
 ObservationBuilder.get_many(self, handles: Optional[List[int]] = None)

---
 examples/custom_observation_example_01_SimpleObs.py  |  2 +-
 ...bservation_example_02_SingleAgentNavigationObs.py |  3 ++-
 ...stom_observation_example_03_ObservePredictions.py |  6 +++---
 examples/debugging_example_DELETE.py                 |  3 ++-
 flatland/core/env_observation_builder.py             |  4 ++--
 flatland/envs/observations.py                        | 12 ++++++------
 tests/test_flatland_malfunction.py                   |  4 ++--
 7 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py
index 7618720f..705169e9 100644
--- a/examples/custom_observation_example_01_SimpleObs.py
+++ b/examples/custom_observation_example_01_SimpleObs.py
@@ -23,7 +23,7 @@ class SimpleObs(ObservationBuilder):
     def reset(self):
         return
 
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> np.ndarray:
         observation = handle * np.ones((5,))
         return observation
 
diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
index 69ae584a..29c0437e 100644
--- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py
+++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
@@ -2,6 +2,7 @@ import getopt
 import random
 import sys
 import time
+from typing import List
 
 import numpy as np
 
@@ -36,7 +37,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
 
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index 77665e28..d7c44753 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -2,7 +2,7 @@ import getopt
 import random
 import sys
 import time
-from typing import Optional, List
+from typing import Optional, List, Dict
 
 import numpy as np
 
@@ -41,7 +41,7 @@ class ObservePredictions(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get_many(self, handles: Optional[List[int]] = None):
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
         '''
         Because we do not want to call the predictor seperately for every agent we implement the get_many function
         Here we can call the predictor just ones for all the agents and use the predictions to generate our observations
@@ -69,7 +69,7 @@ class ObservePredictions(TreeObsForRailEnv):
             observations[h] = self.get(h)
         return observations
 
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> np.ndarray:
         '''
         Lets write a simple observation which just indicates whether or not the own predicted path
         overlaps with other predicted paths at any time. This is useless for the task of navigation but might
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 56209163..9f2ee252 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -1,5 +1,6 @@
 import random
 import time
+from typing import List
 
 import numpy as np
 
@@ -33,7 +34,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
 
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index be97a2ed..daae9b7c 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -97,8 +97,8 @@ class DummyObservationBuilder(ObservationBuilder):
     def reset(self):
         pass
 
-    def get_many(self, handles: Optional[List[int]] = None):
+    def get_many(self, handles: Optional[List[int]] = None) -> bool:
         return True
 
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> bool:
         return True
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 17985c44..a9a9b9dc 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -2,7 +2,7 @@
 Collection of environment-specific ObservationBuilder.
 """
 import pprint
-from typing import Optional, List
+from typing import Optional, List, Dict
 
 import numpy as np
 
@@ -46,7 +46,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     def reset(self):
         self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
 
-    def get_many(self, handles: Optional[List[int]] = None):
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, List[int]]:
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
@@ -75,7 +75,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             observations[h] = self.get(h)
         return observations
 
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> List[int]:
         """
         Computes the current observation for agent `handle` in env
 
@@ -534,7 +534,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i, j] = np.array(bitlist)
 
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
         obs_targets = np.zeros((self.env.height, self.env.width, 2))
         obs_agents_state = np.zeros((self.env.height, self.env.width, 4))
         agents = self.env.agents
@@ -600,7 +600,7 @@ class LocalObsForRailEnv(ObservationBuilder):
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i, j] = np.array(bitlist)
 
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
         agents = self.env.agents
         agent = agents[handle]
 
@@ -640,7 +640,7 @@ class LocalObsForRailEnv(ObservationBuilder):
         direction = np.identity(4)[agent.direction]
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction
 
-    def get_many(self, handles: Optional[List[int]] = None):
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, (np.ndarray, np.ndarray, np.ndarray, np.ndarray)]:
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 1b3c6ade..33e5bb40 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -1,5 +1,5 @@
 import random
-from typing import Dict
+from typing import Dict, List
 
 import numpy as np
 
@@ -31,7 +31,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
 
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
-- 
GitLab