diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py index c1026fe02bfc2cfb30fd57ef570022ca3b15f6f0..756068efffa995968322bfecfefee178717a226b 100644 --- a/flatland/core/env_prediction_builder.py +++ b/flatland/core/env_prediction_builder.py @@ -8,6 +8,7 @@ If predictions are not required in every step or not for all agents, then + `get()` is called whenever an step has to be computed, potentially for each agent independently in \ case of multi-agent environments. """ +from flatland.core.env import Environment class PredictionBuilder: @@ -19,7 +20,7 @@ class PredictionBuilder: def __init__(self, max_depth: int = 20): self.max_depth = max_depth - def _set_env(self, env): + def _set_env(self, env: Environment): self.env = env def reset(self): @@ -28,7 +29,7 @@ class PredictionBuilder: """ pass - def get(self, handle=0): + def get(self, handle: int = 0): """ Called whenever get_many in the observation build is called. diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index ccf4b967c3eabe2ff85dc4084720aa8fc3ca9628..0b390f01c7d8ef876d3b30e58e0fd48f5aceecdf 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -18,7 +18,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): The prediction acts as if no other agent is in the environment and always takes the forward action. """ - def get(self, handle=None): + def get(self, handle: int = None): """ Called whenever get_many in the observation build is called. @@ -91,7 +91,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): # Initialize with depth 20 self.max_depth = max_depth - def get(self, handle=None): + def get(self, handle: int = None): """ Called whenever get_many in the observation build is called. Requires distance_map to extract the shortest path.