diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 0420ab70737503ec135eb05881b4b9a82688f9df..ecc0f8f18d3a2d9cb730c9fcadf237e7ce133a6d 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -4,6 +4,7 @@ Collection of environment-specific PredictionBuilder. import numpy as np +from envs.rail_env import RailEnvActions from flatland.core.env_prediction_builder import PredictionBuilder @@ -41,12 +42,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): prediction_dict = {} for agent in agents: - - # 0: do nothing - # 1: turn left and move to the next cell - # 2: move to the next cell in front of the agent - # 3: turn right and move to the next cell - action_priorities = [2, 1, 3] + action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] _agent_initial_position = agent.position _agent_initial_direction = agent.direction prediction = np.zeros(shape=(self.max_depth, 5))