From 8ab30c41dd253631d2dfb6477e37da6bfa7c90af Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Thu, 26 Sep 2019 12:45:58 +0200 Subject: [PATCH] refactoring get shortest path in predictor; support max_depth and disconnected grids --- flatland/envs/predictions.py | 72 ++++----- flatland/envs/rail_env_shortest_paths.py | 140 ++++++++++++++++++ flatland/envs/rail_env_utils.py | 112 +------------- flatland/utils/simple_rail.py | 40 +++++ tests/test_flatland_envs_predictions.py | 7 +- ...atland_envs_rail_env_shortest_paths.py.py} | 66 ++++++++- 6 files changed, 279 insertions(+), 158 deletions(-) create mode 100644 flatland/envs/rail_env_shortest_paths.py rename tests/{test_shortest_path.py => test_flatland_envs_rail_env_shortest_paths.py.py} (78%) diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 77707b9f..76095a2a 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -5,8 +5,9 @@ Collection of environment-specific PredictionBuilder. import numpy as np from flatland.core.env_prediction_builder import PredictionBuilder -from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env import RailEnvActions +from flatland.envs.rail_env_shortest_paths import get_shortest_paths from flatland.utils.ordered_set import OrderedSet @@ -59,7 +60,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): continue for action in action_priorities: - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \ self.env._check_action_on_agent(action, agent) if all([new_cell_isValid, transition_isValid]): # move and change direction to face the new_direction that was @@ -92,6 +93,9 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): """ Called whenever get_many in the observation build is called. 
Requires distance_map to extract the shortest path. + Does not take into account future positions of other agents! + + If there is no shortest path, the agent just stands still and stops moving. Parameters ---------- @@ -106,14 +110,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): - position axis 0 - position axis 1 - direction - - action taken to come here + - action taken to come here (not implemented yet) The prediction at 0 is the current position, direction etc. """ agents = self.env.agents if handle: agents = [self.env.agents[handle]] - distance_map = self.env.distance_map - assert distance_map is not None + distance_map: DistanceMap = self.env.distance_map + + shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth) prediction_dict = {} for agent in agents: @@ -123,52 +128,35 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): times_per_cell = int(np.reciprocal(agent_speed)) prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] + + shortest_path = shortest_paths[agent.handle] + + # if there is a shortest path, remove the initial position + if shortest_path: + shortest_path = shortest_path[1:] + new_direction = _agent_initial_direction new_position = _agent_initial_position visited = OrderedSet() for index in range(1, self.max_depth + 1): - # if we're at the target, stop moving... 
- if agent.position == agent.target: - prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] - visited.add((agent.position[0], agent.position[1], agent.direction)) - continue - if not agent.moving: - prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING] - visited.add((agent.position[0], agent.position[1], agent.direction)) + # if we're at the target or not moving, stop moving until max_depth is reached + if new_position == agent.target or not agent.moving or not shortest_path: + prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING] + visited.add((*new_position, agent.direction)) continue - # Take shortest possible path - cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) - - if np.sum(cell_transitions) == 1 and index % times_per_cell == 0: - new_direction = np.argmax(cell_transitions) - new_position = get_new_position(agent.position, new_direction) - elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0: - min_dist = np.inf - no_dist_found = True - for direction in range(4): - if cell_transitions[direction] == 1: - neighbour_cell = get_new_position(agent.position, direction) - target_dist = distance_map.get()[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] - if target_dist < min_dist or no_dist_found: - min_dist = target_dist - new_direction = direction - no_dist_found = False - new_position = get_new_position(agent.position, new_direction) - elif index % times_per_cell == 0: - raise Exception("No transition possible {}".format(cell_transitions)) - - # update the agent's position and direction - agent.position = new_position - agent.direction = new_direction + + if index % times_per_cell == 0: + new_position = shortest_path[0].position + new_direction = shortest_path[0].direction + + shortest_path = shortest_path[1:] # prediction is ready prediction[index] = [index, *new_position, new_direction, 0] - 
visited.add((new_position[0], new_position[1], new_direction)) + visited.add((*new_position, new_direction)) + + # TODO: very bad side effects for visualization only: hand the dev_pred_dict back instead of setting on env! self.env.dev_pred_dict[agent.handle] = visited prediction_dict[agent.handle] = prediction - # cleanup: reset initial position - agent.position = _agent_initial_position - agent.direction = _agent_initial_direction - return prediction_dict diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py new file mode 100644 index 00000000..793601d4 --- /dev/null +++ b/flatland/envs/rail_env_shortest_paths.py @@ -0,0 +1,140 @@ +import math +from typing import Dict, List, Optional, NamedTuple, Tuple, Set + +import matplotlib.pyplot as plt +import numpy as np + +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.distance_map import DistanceMap +from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions +from flatland.utils.ordered_set import OrderedSet + +WalkingElement = \ + NamedTuple('WalkingElement', + [('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)]) + + +def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, + agent_position: Tuple[int, int], + rail: GridTransitionMap) -> Set[RailEnvNextAction]: + """ + Get the valid move actions (forward, left, right) for an agent. + + Parameters + ---------- + agent_direction : Grid4TransitionsEnum + agent_position: Tuple[int,int] + rail : GridTransitionMap + + + Returns + ------- + Set of `RailEnvNextAction` (tuples of (action,position,direction)) + Possible move actions (forward,left,right) and the next position/direction they lead to. + It is not checked that the next cell is free. 
+ """ + valid_actions: Set[RailEnvNextAction] = OrderedSet() + possible_transitions = rail.get_transitions(*agent_position, agent_direction) + num_transitions = np.count_nonzero(possible_transitions) + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right], relative to the current orientation + # If only one transition is possible, the forward branch is aligned with it. + if rail.is_dead_end(agent_position): + action = RailEnvActions.MOVE_FORWARD + exit_direction = (agent_direction + 2) % 4 + if possible_transitions[exit_direction]: + new_position = get_new_position(agent_position, exit_direction) + valid_actions.add(RailEnvNextAction(action, new_position, exit_direction)) + elif num_transitions == 1: + action = RailEnvActions.MOVE_FORWARD + for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[new_direction]: + new_position = get_new_position(agent_position, new_direction) + valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) + else: + for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[new_direction]: + if new_direction == agent_direction: + action = RailEnvActions.MOVE_FORWARD + elif new_direction == (agent_direction + 1) % 4: + action = RailEnvActions.MOVE_RIGHT + elif new_direction == (agent_direction - 1) % 4: + action = RailEnvActions.MOVE_LEFT + else: + raise Exception("Illegal state") + + new_position = get_new_position(agent_position, new_direction) + valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) + return valid_actions + + +# N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!) 
+def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = None) \ + -> Dict[int, Optional[List[WalkingElement]]]: + """ + Computes the shortest path for each agent to its target and the action to be taken to do so. + The paths are derived from a `DistanceMap`. + + If there is no path (rail disconnected), the path is given as None. + The agent state (moving or not) and its speed are not taken into account + + Parameters + ---------- + distance_map + + Returns + ------- + Dict[int, Optional[List[WalkingElement]]] + + """ + shortest_paths = dict() + + def _shortest_path_for_agent(agent): + position = agent.position + direction = agent.direction + shortest_paths[agent.handle] = [] + distance = math.inf + depth = 0 + while (position != agent.target and (max_depth is None or depth < max_depth)): + next_actions = get_valid_move_actions_(direction, position, distance_map.rail) + best_next_action = None + for next_action in next_actions: + next_action_distance = distance_map.get()[ + agent.handle, next_action.next_position[0], next_action.next_position[ + 1], next_action.next_direction] + if next_action_distance < distance: + best_next_action = next_action + distance = next_action_distance + + shortest_paths[agent.handle].append(WalkingElement(position, direction, best_next_action)) + depth += 1 + + # if there is no way to continue, the rail must be disconnected! 
+ # (or distance map is incorrect) + if best_next_action is None: + shortest_paths[agent.handle] = None + return + + position = best_next_action.next_position + direction = best_next_action.next_direction + if max_depth is None or depth < max_depth: + shortest_paths[agent.handle].append( + WalkingElement(position, direction, + RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction))) + + for agent in distance_map.agents: + _shortest_path_for_agent(agent) + + return shortest_paths + + +def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0): + if agent_handle >= distance_map.get().shape[0]: + print("Error: agent_handle cannot be larger than actual number of agents") + return + # take min value of all 4 directions + min_distance_map = np.min(distance_map.get(), axis=3) + plt.imshow(min_distance_map[agent_handle][:][:]) + plt.show() diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 0e305f48..dc1cff12 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -1,23 +1,8 @@ -import math -from typing import Tuple, Set, Dict, List, NamedTuple - -import matplotlib.pyplot as plt -import numpy as np - -from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.core.grid.grid4_utils import get_new_position -from flatland.core.transition_map import GridTransitionMap -from flatland.envs.distance_map import DistanceMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv -from flatland.envs.rail_env import RailEnv, RailEnvNextAction, RailEnvActions +from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_file from flatland.envs.schedule_generators import schedule_from_file -from flatland.utils.ordered_set import OrderedSet - -WalkingElement = \ - NamedTuple('WalkingElement', - [('position', Tuple[int, int]), ('direction', int), ('next_action_element', 
RailEnvNextAction)]) def load_flatland_environment_from_file(file_name, load_from_package=None, obs_builder_object=None): @@ -32,98 +17,3 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b schedule_generator=schedule_from_file(file_name, load_from_package), obs_builder_object=obs_builder_object) return environment - - -def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, - agent_position: Tuple[int, int], - rail: GridTransitionMap) -> Set[RailEnvNextAction]: - """ - Get the valid move actions (forward, left, right) for an agent. - - Parameters - ---------- - agent_direction : Grid4TransitionsEnum - agent_position: Tuple[int,int] - rail : GridTransitionMap - - - Returns - ------- - Set of `RailEnvNextAction` (tuples of (action,position,direction)) - Possible move actions (forward,left,right) and the next position/direction they lead to. - It is not checked that the next cell is free. - """ - valid_actions: Set[RailEnvNextAction] = OrderedSet() - possible_transitions = rail.get_transitions(*agent_position, agent_direction) - num_transitions = np.count_nonzero(possible_transitions) - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right], relative to the current orientation - # If only one transition is possible, the forward branch is aligned with it. 
- if rail.is_dead_end(agent_position): - action = RailEnvActions.MOVE_FORWARD - exit_direction = (agent_direction + 2) % 4 - if possible_transitions[exit_direction]: - new_position = get_new_position(agent_position, exit_direction) - valid_actions.add(RailEnvNextAction(action, new_position, exit_direction)) - elif num_transitions == 1: - action = RailEnvActions.MOVE_FORWARD - for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: - if possible_transitions[new_direction]: - new_position = get_new_position(agent_position, new_direction) - valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) - else: - for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: - if possible_transitions[new_direction]: - if new_direction == agent_direction: - action = RailEnvActions.MOVE_FORWARD - elif new_direction == (agent_direction + 1) % 4: - action = RailEnvActions.MOVE_RIGHT - elif new_direction == (agent_direction - 1) % 4: - action = RailEnvActions.MOVE_LEFT - else: - raise Exception("Illegal state") - - new_position = get_new_position(agent_position, new_direction) - valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) - return valid_actions - - -def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[WalkingElement]]: - # TODO: do we need to support unreachable targets? 
- # TODO refactoring: unify with predictor (support agent.moving and max_depth) - shortest_paths = dict() - for a in distance_map.agents: - position = a.position - direction = a.direction - shortest_paths[a.handle] = [] - distance = math.inf - while (position != a.target): - next_actions = get_valid_move_actions_(direction, position, distance_map.rail) - best_next_action = None - for next_action in next_actions: - next_action_distance = distance_map.get()[ - a.handle, next_action.next_position[0], next_action.next_position[1], next_action.next_direction] - if next_action_distance < distance: - best_next_action = next_action - distance = next_action_distance - - shortest_paths[a.handle].append(WalkingElement(position, direction, best_next_action)) - - position = best_next_action.next_position - direction = best_next_action.next_direction - - shortest_paths[a.handle].append( - WalkingElement(position, direction, - RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction))) - - return shortest_paths - - -def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0): - if agent_handle >= distance_map.get().shape[0]: - print("Error: agent_handle cannot be larger than actual number of agents") - return - # take min value of all 4 directions - min_distance_map = np.min(distance_map.get(), axis=3) - plt.imshow(min_distance_map[agent_handle][:][:]) - plt.show() diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index 6da29d7f..a12c26e6 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -45,6 +45,46 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]: return rail, rail_map +def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]: + # We instantiate a very simple rail network on a 7x10 grid: + # Note that cells have invalid RailEnvTransitions! 
+ # | + # | + # | + # _ _ _ _\ _ _ _ _ _ + # / + # | + # | + # | + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + simple_switch_north_left = cells[2] + simple_switch_north_right = cells[10] + simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [simple_switch_east_west_north] + + [dead_end_from_west] + [dead_end_from_east] + [simple_switch_east_west_south] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map + + + def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]: # We instantiate a very simple rail network on a 7x10 grid: # | diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 569cd3ad..7ee0fd4a 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, 
ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction -from flatland.envs.rail_env_utils import get_shortest_paths, WalkingElement +from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool @@ -236,12 +236,13 @@ def test_shortest_path_predictor(rendering=False): [20.], ]) + assert np.array_equal(time_offsets, expected_time_offsets), \ + "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets) + assert np.array_equal(positions, expected_positions), \ "positions {}, expected {}".format(positions, expected_positions) assert np.array_equal(directions, expected_directions), \ "directions {}, expected {}".format(directions, expected_directions) - assert np.array_equal(time_offsets, expected_time_offsets), \ - "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets) def test_shortest_path_predictor_conflicts(rendering=False): diff --git a/tests/test_shortest_path.py b/tests/test_flatland_envs_rail_env_shortest_paths.py.py similarity index 78% rename from tests/test_shortest_path.py rename to tests/test_flatland_envs_rail_env_shortest_paths.py.py index 094b5603..4600c4a3 100644 --- a/tests/test_shortest_path.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py.py @@ -1,7 +1,41 @@ import numpy as np -from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions -from flatland.envs.rail_env_utils import load_flatland_environment_from_file, get_shortest_paths, WalkingElement +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import DummyPredictorForRailEnv +from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions, RailEnv +from 
flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement +from flatland.envs.rail_env_utils import load_flatland_environment_from_file +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.simple_rail import make_disconnected_simple_rail + + +def test_get_shortest_paths_unreachable(): + rail, rail_map = make_disconnected_simple_rail() + + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), + ) + + # set the initial position + agent = env.agents_static[0] + agent.position = (3, 1) # west dead-end + agent.direction = Grid4TransitionsEnum.WEST + agent.target = (3, 9) # east dead-end + agent.moving = True + + # reset to set agents from agents_static + env.reset(False, False) + + actual = get_shortest_paths(env.distance_map) + expected = {0: None} + + assert actual == expected, "actual={},expected={}".format(actual, expected) def test_get_shortest_paths(): @@ -130,3 +164,31 @@ def test_get_shortest_paths(): for agent_handle in expected: assert np.array_equal(actual[agent_handle], expected[agent_handle]), \ "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle]) + + +def test_get_shortest_paths_max_depth(): + env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') + actual = get_shortest_paths(env.distance_map, max_depth=2) + + expected = { + 0: [ + WalkingElement(position=(1, 1), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(1, 2), next_direction=1)), + WalkingElement(position=(1, 2), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + 
next_position=(1, 3), next_direction=1)) + ], + 1: [ + WalkingElement(position=(3, 18), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(3, 17), next_direction=3)), + WalkingElement(position=(3, 17), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(3, 16), next_direction=3)), + ] + } + + for agent_handle in expected: + assert np.array_equal(actual[agent_handle], expected[agent_handle]), \ + "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle]) -- GitLab