From a5ea6879bae58540f84cc26613be903b7aa26906 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Fri, 20 Sep 2019 10:03:05 +0200
Subject: [PATCH] Refactoring: add type hint for method of PredictionBuilder

---
 flatland/core/env_prediction_builder.py | 5 +++--
 flatland/envs/predictions.py            | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index c1026fe0..756068ef 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -8,6 +8,7 @@ If predictions are not required in every step or not for all agents, then
 + `get()` is called whenever an step has to be computed, potentially for each agent independently in \
 case of multi-agent environments.
 """
+from flatland.core.env import Environment
 
 
 class PredictionBuilder:
@@ -19,7 +20,7 @@ class PredictionBuilder:
     def __init__(self, max_depth: int = 20):
         self.max_depth = max_depth
 
-    def _set_env(self, env):
+    def _set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
@@ -28,7 +29,7 @@ class PredictionBuilder:
         """
         pass
 
-    def get(self, handle=0):
+    def get(self, handle: int = 0):
         """
         Called whenever get_many in the observation build is called.
 
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index ccf4b967..0b390f01 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -18,7 +18,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, handle=None):
+    def get(self, handle: int = None):
         """
         Called whenever get_many in the observation build is called.
 
@@ -91,7 +91,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         # Initialize with depth 20
         self.max_depth = max_depth
 
-    def get(self, handle=None):
+    def get(self, handle: int = None):
         """
         Called whenever get_many in the observation build is called.
         Requires distance_map to extract the shortest path.
-- 
GitLab