From 8216da27739f88b29a5a23f9917345db4986958f Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 24 Sep 2019 15:47:45 +0200 Subject: [PATCH] bugfix get-shortest-path --- flatland/envs/distance_map.py | 22 +++++------ flatland/envs/rail_env_utils.py | 12 ++++-- tests/test_flatland_envs_predictions.py | 4 +- tests/test_shortest_path.py | 52 +++++++++++++++++++++++++ 4 files changed, 73 insertions(+), 17 deletions(-) create mode 100644 tests/test_shortest_path.py diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index 52ab5fad..228d3992 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -18,19 +18,18 @@ class DistanceMap: self.agents: List[EnvAgent] = agents self.rail: Optional[GridTransitionMap] = None - """ - Set the distance map - """ - def set(self, distance_map: np.ndarray): + """ + Set the distance map + """ self.distance_map = distance_map - """ - Get the distance map - """ - def get(self) -> np.ndarray: + def get(self) -> np.ndarray: + """ + Get the distance map + """ if self.reset_was_called: self.reset_was_called = False @@ -53,11 +52,12 @@ class DistanceMap: return self.distance_map - """ - Reset the distance map - """ + def reset(self, agents: List[EnvAgent], rail: GridTransitionMap): + """ + Reset the distance map + """ self.reset_was_called = True self.agents: List[EnvAgent] = agents self.rail = rail diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 200c19dc..8e3ed5b4 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -83,22 +83,26 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, return valid_actions -def get_shorts_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextAction]]: +def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextAction]]: + # TODO: do we need to support unreachable targets? + # TODO refactoring: unify with predictor (support agent.moving and max_depth) shortest_paths = dict() for a in distance_map.agents: position = a.position direction = a.direction shortest_paths[a.handle] = [] - + distance = math.inf while (position != a.target): next_actions = get_valid_move_actions_(direction, position, distance_map.rail) - best = math.inf best_next_action = None for next_action in next_actions: - if distance_map.get()[a.handle, position[0], position[1], direction] < best: + next_action_distance = distance_map.get()[a.handle, next_action.next_position[0], next_action.next_position[1], next_action.next_direction] + if next_action_distance < distance: best_next_action = next_action + distance = next_action_distance position = best_next_action.next_position direction = best_next_action.next_direction shortest_paths[a.handle].append(best_next_action) + return shortest_paths diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index b2a9b612..2df60017 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction -from flatland.envs.rail_env_utils import get_shorts_paths +from flatland.envs.rail_env_utils import get_shortest_paths from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool @@ -143,7 +143,7 @@ def test_shortest_path_predictor(rendering=False): 1], agent.direction] == 5.0, "found {} instead of {}".format( distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0) - paths = get_shorts_paths(env.distance_map)[0] + paths = get_shortest_paths(env.distance_map)[0] assert paths == [ RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), next_direction=0), RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), next_direction=0), diff --git a/tests/test_shortest_path.py b/tests/test_shortest_path.py new file mode 100644 index 00000000..46216436 --- /dev/null +++ b/tests/test_shortest_path.py @@ -0,0 +1,52 @@ +import numpy as np + +from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions +from flatland.envs.rail_env_utils import load_flatland_environment_from_file, get_shortest_paths + + +def test_get_shortest_paths(): + env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') + actual = get_shortest_paths(env.distance_map) + + expected = { + 0: [RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(1, 2), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(1, 3), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 3), next_direction=2), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 4), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 5), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 6), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 7), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 8), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 9), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 10), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 11), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 12), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 13), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 14), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 15), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 16), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 17), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 18), next_direction=1)], + 1: [RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 17), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 16), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 16), next_direction=0), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 15), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 14), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 13), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 12), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 11), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 10), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 9), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 8), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 7), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 6), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 5), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 4), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 3), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 2), next_direction=3), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 1), next_direction=3)] + } + + for agent_handle in expected: + assert np.array_equal(actual[agent_handle], expected[agent_handle]), \ + "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle]) -- GitLab