From c7ca42e0328db238bc00e594e83068f4521e8033 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Mon, 17 Jun 2019 15:00:57 +0200 Subject: [PATCH] 66 shortest-path-predictor: cleanup and unit test; not working yet --- flatland/envs/env_utils.py | 11 +++++++---- flatland/envs/observations.py | 9 +++++---- flatland/envs/predictions.py | 34 ++++++++++++++++------------------ 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index c9595b76..ee2c2637 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -7,6 +7,8 @@ a GridTransitionMap object. import numpy as np +from flatland.core.transitions import Grid4TransitionsEnum + def get_direction(pos1, pos2): """ @@ -253,13 +255,14 @@ def distance_on_rail(pos1, pos2): def get_new_position(position, movement): - if movement == 0: # NORTH + """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ + if movement == Grid4TransitionsEnum.NORTH: return (position[0] - 1, position[1]) - elif movement == 1: # EAST + elif movement == Grid4TransitionsEnum.EAST: return (position[0], position[1] + 1) - elif movement == 2: # SOUTH + elif movement == Grid4TransitionsEnum.SOUTH: return (position[0] + 1, position[1]) - elif movement == 3: # WEST + elif movement == Grid4TransitionsEnum.WEST: return (position[0], position[1] - 1) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index d7fdcee7..a7f91f14 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -6,6 +6,7 @@ from collections import deque import numpy as np from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.transitions import Grid4TransitionsEnum from flatland.envs.env_utils import coordinate_to_position @@ -162,13 +163,13 @@ class TreeObsForRailEnv(ObservationBuilder): """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ - if movement == 0: # NORTH + if movement == Grid4TransitionsEnum.NORTH: return (position[0] - 1, position[1]) - elif movement == 1: # EAST + elif movement == Grid4TransitionsEnum.EAST: return (position[0], position[1] + 1) - elif movement == 2: # SOUTH + elif movement == Grid4TransitionsEnum.SOUTH: return (position[0] + 1, position[1]) - elif movement == 3: # WEST + elif movement == Grid4TransitionsEnum.WEST: return (position[0], position[1] - 1) def get_many(self, handles=[]): diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 3910fa1b..e1b90b8a 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder. import numpy as np from flatland.core.env_prediction_builder import PredictionBuilder +from flatland.envs.env_utils import get_new_position from flatland.envs.rail_env import RailEnvActions @@ -55,8 +56,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): action_done = False # if we're at the target, stop moving... if agent.position == agent.target: - prediction[index] = [index, *agent.target, agent.direction, - RailEnvActions.STOP_MOVING] + prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] continue for action in action_priorities: @@ -135,7 +135,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): if np.sum(cell_transitions) == 1: new_direction = np.argmax(cell_transitions) - new_position = self._new_position(agent.position, new_direction) + new_position = get_new_position(agent.position, new_direction) elif np.sum(cell_transitions) > 1: min_dist = np.inf for direction in range(4): @@ -144,11 +144,22 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): if target_dist < min_dist: min_dist = target_dist new_direction = direction - new_position = self._new_position(agent.position, new_direction) + new_position = get_new_position(agent.position, new_direction) + else: + raise Exception("No transition possible {}".format(cell_transitions)) + + action = None + for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]: + cell_isFree, new_cell_isValid, new_direction, _new_position, transition_isValid = \ + self.env._check_action_on_agent(action, agent) + if np.array_equal(_new_position, new_position): + action = _action + break + assert action is not None agent.position = new_position agent.direction = new_direction - prediction[index] = [index, *new_position, new_direction, RailEnvActions.MOVE_FORWARD] + prediction[index] = [index, *new_position, new_direction, action] action_done = True if not action_done: raise Exception("Cannot move further. Something is wrong") @@ -159,16 +170,3 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): agent.direction = _agent_initial_direction return prediction_dict - - def _new_position(self, position, movement): - """ - Utility function that converts a compass movement over a 2D grid to new positions (r, c). - """ - if movement == 0: # NORTH - return (position[0] - 1, position[1]) - elif movement == 1: # EAST - return (position[0], position[1] + 1) - elif movement == 2: # SOUTH - return (position[0] + 1, position[1]) - elif movement == 3: # WEST - return (position[0], position[1] - 1) -- GitLab