Commit a5ea6879 authored by u229589's avatar u229589
Browse files

Refactoring: add type hint for method of PredictionBuilder

parent 8f652893
......@@ -8,6 +8,7 @@ If predictions are not required in every step or not for all agents, then
+ `get()` is called whenever an step has to be computed, potentially for each agent independently in \
case of multi-agent environments.
"""
from flatland.core.env import Environment
class PredictionBuilder:
......@@ -19,7 +20,7 @@ class PredictionBuilder:
def __init__(self, max_depth: int = 20):
self.max_depth = max_depth
def _set_env(self, env):
def _set_env(self, env: Environment):
self.env = env
def reset(self):
......@@ -28,7 +29,7 @@ class PredictionBuilder:
"""
pass
def get(self, handle=0):
def get(self, handle: int = 0):
"""
Called whenever get_many in the observation build is called.
......
......@@ -18,7 +18,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, handle=None):
def get(self, handle: int = None):
"""
Called whenever get_many in the observation build is called.
......@@ -91,7 +91,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
# Initialize with depth 20
self.max_depth = max_depth
def get(self, handle=None):
def get(self, handle: int = None):
"""
Called whenever get_many in the observation build is called.
Requires distance_map to extract the shortest path.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment