diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index e1b90b8ac4cdaa1784f9e056c5c4c45ad10fb1ff..0d5387dc90dc5d5dceadd50d12e909434bc1c123 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -133,6 +133,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): # Take shortest possible path cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) + new_position = None + new_direction = None if np.sum(cell_transitions) == 1: new_direction = np.argmax(cell_transitions) new_position = get_new_position(agent.position, new_direction) @@ -148,21 +150,21 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): else: raise Exception("No transition possible {}".format(cell_transitions)) - + # which action to take for the transition? action = None for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]: - cell_isFree, new_cell_isValid, new_direction, _new_position, transition_isValid = \ - self.env._check_action_on_agent(action, agent) + _, _, _new_direction, _new_position, _ = self.env._check_action_on_agent(_action, agent) if np.array_equal(_new_position, new_position): action = _action break assert action is not None + + # update the agent's position and direction agent.position = new_position agent.direction = new_direction + + # prediction is ready prediction[index] = [index, *new_position, new_direction, action] - action_done = True - if not action_done: - raise Exception("Cannot move further. Something is wrong") prediction_dict[agent.handle] = prediction # cleanup: reset initial position