Skip to content
Snippets Groups Projects
Commit 8f652893 authored by u229589's avatar u229589
Browse files

Refactoring: add type hint for PredictionBuilder

parent c5bb4230
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ import pprint ...@@ -6,6 +6,7 @@ import pprint
import numpy as np import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.utils.ordered_set import OrderedSet from flatland.utils.ordered_set import OrderedSet
...@@ -22,7 +23,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -22,7 +23,7 @@ class TreeObsForRailEnv(ObservationBuilder):
For details about the features in the tree observation see the get() function. For details about the features in the tree observation see the get() function.
""" """
def __init__(self, max_depth, predictor=None): def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
super().__init__() super().__init__()
self.max_depth = max_depth self.max_depth = max_depth
self.observation_dim = 11 self.observation_dim = 11
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment