From 8f652893c93f23509c661e5ae9f0a94f8c9dfcd5 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Fri, 20 Sep 2019 10:01:16 +0200
Subject: [PATCH] Refactoring: add type hint for PredictionBuilder

---
 flatland/envs/observations.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 2a1c5220..31ad1643 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -6,6 +6,7 @@ import pprint
 import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.env_prediction_builder import PredictionBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
 from flatland.utils.ordered_set import OrderedSet
@@ -22,7 +23,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     For details about the features in the tree observation see the get() function.
     """
 
-    def __init__(self, max_depth, predictor=None):
+    def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
         super().__init__()
         self.max_depth = max_depth
         self.observation_dim = 11
-- 
GitLab