diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index c9595b7693497daa7db110b8fc8b4ae040d39cc9..ee2c263711906370900c14a6030d57ada573c2ea 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -7,6 +7,8 @@ a GridTransitionMap object. import numpy as np +from flatland.core.transitions import Grid4TransitionsEnum + def get_direction(pos1, pos2): """ @@ -253,13 +255,14 @@ def distance_on_rail(pos1, pos2): def get_new_position(position, movement): - if movement == 0: # NORTH + """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ + if movement == Grid4TransitionsEnum.NORTH: return (position[0] - 1, position[1]) - elif movement == 1: # EAST + elif movement == Grid4TransitionsEnum.EAST: return (position[0], position[1] + 1) - elif movement == 2: # SOUTH + elif movement == Grid4TransitionsEnum.SOUTH: return (position[0] + 1, position[1]) - elif movement == 3: # WEST + elif movement == Grid4TransitionsEnum.WEST: return (position[0], position[1] - 1) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index d7fdcee7cf0f1183f9430c08684e4ede468d188a..a7f91f1439f98bc2627700903b0175486f619749 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -6,6 +6,7 @@ from collections import deque import numpy as np from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.transitions import Grid4TransitionsEnum from flatland.envs.env_utils import coordinate_to_position @@ -162,13 +163,13 @@ class TreeObsForRailEnv(ObservationBuilder): """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ - if movement == 0: # NORTH + if movement == Grid4TransitionsEnum.NORTH: return (position[0] - 1, position[1]) - elif movement == 1: # EAST + elif movement == Grid4TransitionsEnum.EAST: return (position[0], position[1] + 1) - elif movement == 2: # SOUTH + elif movement == Grid4TransitionsEnum.SOUTH: return (position[0] + 1, position[1]) - elif movement == 3: # WEST + elif movement == Grid4TransitionsEnum.WEST: return (position[0], position[1] - 1) def get_many(self, handles=[]): diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 3910fa1b6a3deb5055841af4381957a0a474c3a7..e1b90b8ac4cdaa1784f9e056c5c4c45ad10fb1ff 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder. import numpy as np from flatland.core.env_prediction_builder import PredictionBuilder +from flatland.envs.env_utils import get_new_position from flatland.envs.rail_env import RailEnvActions @@ -55,8 +56,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): action_done = False # if we're at the target, stop moving... if agent.position == agent.target: - prediction[index] = [index, *agent.target, agent.direction, - RailEnvActions.STOP_MOVING] + prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] continue for action in action_priorities: @@ -135,7 +135,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): if np.sum(cell_transitions) == 1: new_direction = np.argmax(cell_transitions) - new_position = self._new_position(agent.position, new_direction) + new_position = get_new_position(agent.position, new_direction) elif np.sum(cell_transitions) > 1: min_dist = np.inf for direction in range(4): @@ -144,11 +144,22 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): if target_dist < min_dist: min_dist = target_dist new_direction = direction - new_position = self._new_position(agent.position, new_direction) + new_position = get_new_position(agent.position, new_direction) + else: + raise Exception("No transition possible {}".format(cell_transitions)) + + action = None + for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]: + cell_isFree, new_cell_isValid, new_direction, _new_position, transition_isValid = \ + self.env._check_action_on_agent(action, agent) + if np.array_equal(_new_position, new_position): + action = _action + break + assert action is not None agent.position = new_position agent.direction = new_direction - prediction[index] = [index, *new_position, new_direction, RailEnvActions.MOVE_FORWARD] + prediction[index] = [index, *new_position, new_direction, action] action_done = True if not action_done: raise Exception("Cannot move further. Something is wrong") @@ -159,16 +170,3 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): agent.direction = _agent_initial_direction return prediction_dict - - def _new_position(self, position, movement): - """ - Utility function that converts a compass movement over a 2D grid to new positions (r, c). - """ - if movement == 0: # NORTH - return (position[0] - 1, position[1]) - elif movement == 1: # EAST - return (position[0], position[1] + 1) - elif movement == 2: # SOUTH - return (position[0] + 1, position[1]) - elif movement == 3: # WEST - return (position[0], position[1] - 1)