diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 2a1c52207630a72f2749ba22ab7c46241839d4ab..31ad16438cbf63a532b8e90bdbd9efb704e755b7 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -6,6 +6,7 @@ import pprint import numpy as np from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.env_prediction_builder import PredictionBuilder from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import coordinate_to_position from flatland.utils.ordered_set import OrderedSet @@ -22,7 +23,7 @@ class TreeObsForRailEnv(ObservationBuilder): For details about the features in the tree observation see the get() function. """ - def __init__(self, max_depth, predictor=None): + def __init__(self, max_depth: int, predictor: PredictionBuilder = None): super().__init__() self.max_depth = max_depth self.observation_dim = 11