diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 0420ab70737503ec135eb05881b4b9a82688f9df..3338e68126ee9a198e0208e17b1986f1c9fde6c8 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder. import numpy as np from flatland.core.env_prediction_builder import PredictionBuilder +from flatland.envs.rail_env import RailEnvActions class DummyPredictorForRailEnv(PredictionBuilder): @@ -41,12 +42,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): prediction_dict = {} for agent in agents: - - # 0: do nothing - # 1: turn left and move to the next cell - # 2: move to the next cell in front of the agent - # 3: turn right and move to the next cell - action_priorities = [2, 1, 3] + action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] _agent_initial_position = agent.position _agent_initial_direction = agent.direction prediction = np.zeros(shape=(self.max_depth, 5))