diff --git a/env_data/tests/test_002.pkl b/env_data/tests/test_002.pkl new file mode 100644 index 0000000000000000000000000000000000000000..e46a5c088ef555109c352b7a4d5eaf8dfbaf8700 Binary files /dev/null and b/env_data/tests/test_002.pkl differ diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index c9c6b00375ef4577880e2b8c98c2ff9dc946a7fa..22721407f059ff2e02d907110cb9f982aa9d599e 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -18,27 +18,21 @@ class DistanceMap: self.agents: List[EnvAgent] = agents self.rail: Optional[GridTransitionMap] = None - """ - Set the distance map - """ def set(self, distance_map: np.ndarray): + """ + Set the distance map + """ self.distance_map = distance_map - """ - Get the distance map - """ def get(self) -> np.ndarray: - + """ + Get the distance map + """ if self.reset_was_called: self.reset_was_called = False nb_agents = len(self.agents) compute_distance_map = True - if self.agents_previous_computation is not None and nb_agents == len(self.agents_previous_computation): - compute_distance_map = False - for i in range(nb_agents): - if self.agents[i].target != self.agents_previous_computation[i].target: - compute_distance_map = True # Don't compute the distance map if it was loaded if self.agents_previous_computation is None and self.distance_map is not None: compute_distance_map = False @@ -51,12 +45,12 @@ class DistanceMap: return self.distance_map - """ - Reset the distance map - """ def reset(self, agents: List[EnvAgent], rail: GridTransitionMap): + """ + Reset the distance map + """ self.reset_was_called = True - self.agents = agents + self.agents: List[EnvAgent] = agents self.rail = rail self.env_height = rail.height self.env_width = rail.width @@ -110,7 +104,8 @@ class DistanceMap: return max_distance - def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance, enforce_target_direction=-1): + def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance, + enforce_target_direction=-1): """ Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the minimum distances from each target cell. @@ -134,8 +129,7 @@ class DistanceMap: for agent_orientation in range(4): # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation), - desired_movement_from_new_cell) - # is_valid = True + desired_movement_from_new_cell) if is_valid: """ diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 77707b9f110376ddf2638b830830ff1a1c1edbf6..76095a2a2e1d9532951600118c6a777612641101 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -5,8 +5,9 @@ Collection of environment-specific PredictionBuilder. import numpy as np from flatland.core.env_prediction_builder import PredictionBuilder -from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env import RailEnvActions +from flatland.envs.rail_env_shortest_paths import get_shortest_paths from flatland.utils.ordered_set import OrderedSet @@ -59,7 +60,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): continue for action in action_priorities: - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \ self.env._check_action_on_agent(action, agent) if all([new_cell_isValid, transition_isValid]): # move and change direction to face the new_direction that was @@ -92,6 +93,9 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): """ Called whenever get_many in the observation build is called. Requires distance_map to extract the shortest path. + Does not take into account future positions of other agents! + + If there is no shortest path, the agent just stands still and stops moving. Parameters ---------- @@ -106,14 +110,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): - position axis 0 - position axis 1 - direction - - action taken to come here + - action taken to come here (not implemented yet) The prediction at 0 is the current position, direction etc. """ agents = self.env.agents if handle: agents = [self.env.agents[handle]] - distance_map = self.env.distance_map - assert distance_map is not None + distance_map: DistanceMap = self.env.distance_map + + shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth) prediction_dict = {} for agent in agents: @@ -123,52 +128,35 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): times_per_cell = int(np.reciprocal(agent_speed)) prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] + + shortest_path = shortest_paths[agent.handle] + + # if there is a shortest path, remove the initial position + if shortest_path: + shortest_path = shortest_path[1:] + new_direction = _agent_initial_direction new_position = _agent_initial_position visited = OrderedSet() for index in range(1, self.max_depth + 1): - # if we're at the target, stop moving... - if agent.position == agent.target: - prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] - visited.add((agent.position[0], agent.position[1], agent.direction)) - continue - if not agent.moving: - prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING] - visited.add((agent.position[0], agent.position[1], agent.direction)) + # if we're at the target or not moving, stop moving until max_depth is reached + if new_position == agent.target or not agent.moving or not shortest_path: + prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING] + visited.add((*new_position, agent.direction)) continue - # Take shortest possible path - cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) - - if np.sum(cell_transitions) == 1 and index % times_per_cell == 0: - new_direction = np.argmax(cell_transitions) - new_position = get_new_position(agent.position, new_direction) - elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0: - min_dist = np.inf - no_dist_found = True - for direction in range(4): - if cell_transitions[direction] == 1: - neighbour_cell = get_new_position(agent.position, direction) - target_dist = distance_map.get()[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] - if target_dist < min_dist or no_dist_found: - min_dist = target_dist - new_direction = direction - no_dist_found = False - new_position = get_new_position(agent.position, new_direction) - elif index % times_per_cell == 0: - raise Exception("No transition possible {}".format(cell_transitions)) - - # update the agent's position and direction - agent.position = new_position - agent.direction = new_direction + + if index % times_per_cell == 0: + new_position = shortest_path[0].position + new_direction = shortest_path[0].direction + + shortest_path = shortest_path[1:] # prediction is ready prediction[index] = [index, *new_position, new_direction, 0] - visited.add((new_position[0], new_position[1], new_direction)) + visited.add((*new_position, new_direction)) + + # TODO: very bady side effects for visualization only: hand the dev_pred_dict back instead of setting on env! self.env.dev_pred_dict[agent.handle] = visited prediction_dict[agent.handle] = prediction - # cleanup: reset initial position - agent.position = _agent_initial_position - agent.direction = _agent_initial_direction - return prediction_dict diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index d0add3086014c7ad07c29e01588cba380c26cbd7..b5bc44f2e698a4ffa23ca31d34ac14f613340d04 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -4,7 +4,7 @@ Definition of the RailEnv environment. # TODO: _ this is a global method --> utils or remove later import warnings from enum import IntEnum -from typing import List, Set, NamedTuple, Optional, Tuple, Dict +from typing import List, NamedTuple, Optional, Tuple, Dict import msgpack import msgpack_numpy as m @@ -20,7 +20,6 @@ from flatland.envs.distance_map import DistanceMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator -from flatland.utils.ordered_set import OrderedSet m.patch() @@ -241,9 +240,6 @@ class RailEnv(Environment): # can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition? rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) - if optionals and 'distance_map' in optionals: - self.distance_map.set(optionals['distance_map']) - if regen_rail or self.rail is None: self.rail = rail self.height, self.width = self.rail.grid.shape @@ -253,6 +249,11 @@ class RailEnv(Environment): check = self.rail.cell_neighbours_valid(rc_pos, True) if not check: warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check)) + # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172 + # hacky: we must re-compute the distance map and not use the initial distance_map loaded from file by + # rail_from_file!!! + elif optionals and 'distance_map' in optionals: + self.distance_map.set(optionals['distance_map']) if replace_agents: agents_hints = None @@ -587,60 +588,6 @@ class RailEnv(Environment): transition_valid = True return new_direction, transition_valid - @staticmethod - def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, - agent_position: Tuple[int, int], - rail: GridTransitionMap) -> Set[RailEnvNextAction]: - """ - Get the valid move actions (forward, left, right) for an agent. - - Parameters - ---------- - agent_direction : Grid4TransitionsEnum - agent_position: Tuple[int,int] - rail : GridTransitionMap - - - Returns - ------- - Set of `RailEnvNextAction` (tuples of (action,position,direction)) - Possible move actions (forward,left,right) and the next position/direction they lead to. - It is not checked that the next cell is free. - """ - valid_actions: Set[RailEnvNextAction] = OrderedSet() - possible_transitions = rail.get_transitions(*agent_position, agent_direction) - num_transitions = np.count_nonzero(possible_transitions) - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right], relative to the current orientation - # If only one transition is possible, the forward branch is aligned with it. - if rail.is_dead_end(agent_position): - action = RailEnvActions.MOVE_FORWARD - exit_direction = (agent_direction + 2) % 4 - if possible_transitions[exit_direction]: - new_position = get_new_position(agent_position, exit_direction) - valid_actions.add(RailEnvNextAction(action, new_position, exit_direction)) - elif num_transitions == 1: - action = RailEnvActions.MOVE_FORWARD - for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: - if possible_transitions[new_direction]: - new_position = get_new_position(agent_position, new_direction) - valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) - else: - for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: - if possible_transitions[new_direction]: - if new_direction == agent_direction: - action = RailEnvActions.MOVE_FORWARD - elif new_direction == (agent_direction + 1) % 4: - action = RailEnvActions.MOVE_RIGHT - elif new_direction == (agent_direction - 1) % 4: - action = RailEnvActions.MOVE_LEFT - else: - raise Exception("Illegal state") - - new_position = get_new_position(agent_position, new_direction) - valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) - return valid_actions - def _get_observations(self): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py new file mode 100644 index 0000000000000000000000000000000000000000..793601d4d18ac38b729d15883089d5acbfc41ed3 --- /dev/null +++ b/flatland/envs/rail_env_shortest_paths.py @@ -0,0 +1,140 @@ +import math +from typing import Dict, List, Optional, NamedTuple, Tuple, Set + +import matplotlib.pyplot as plt +import numpy as np + +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.distance_map import DistanceMap +from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions +from flatland.utils.ordered_set import OrderedSet + +WalkingElement = \ + NamedTuple('WalkingElement', + [('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)]) + + +def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, + agent_position: Tuple[int, int], + rail: GridTransitionMap) -> Set[RailEnvNextAction]: + """ + Get the valid move actions (forward, left, right) for an agent. + + Parameters + ---------- + agent_direction : Grid4TransitionsEnum + agent_position: Tuple[int,int] + rail : GridTransitionMap + + + Returns + ------- + Set of `RailEnvNextAction` (tuples of (action,position,direction)) + Possible move actions (forward,left,right) and the next position/direction they lead to. + It is not checked that the next cell is free. + """ + valid_actions: Set[RailEnvNextAction] = OrderedSet() + possible_transitions = rail.get_transitions(*agent_position, agent_direction) + num_transitions = np.count_nonzero(possible_transitions) + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right], relative to the current orientation + # If only one transition is possible, the forward branch is aligned with it. + if rail.is_dead_end(agent_position): + action = RailEnvActions.MOVE_FORWARD + exit_direction = (agent_direction + 2) % 4 + if possible_transitions[exit_direction]: + new_position = get_new_position(agent_position, exit_direction) + valid_actions.add(RailEnvNextAction(action, new_position, exit_direction)) + elif num_transitions == 1: + action = RailEnvActions.MOVE_FORWARD + for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[new_direction]: + new_position = get_new_position(agent_position, new_direction) + valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) + else: + for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[new_direction]: + if new_direction == agent_direction: + action = RailEnvActions.MOVE_FORWARD + elif new_direction == (agent_direction + 1) % 4: + action = RailEnvActions.MOVE_RIGHT + elif new_direction == (agent_direction - 1) % 4: + action = RailEnvActions.MOVE_LEFT + else: + raise Exception("Illegal state") + + new_position = get_new_position(agent_position, new_direction) + valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) + return valid_actions + + +# N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!) +def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = None) \ + -> Dict[int, Optional[List[WalkingElement]]]: + """ + Computes the shortest path for each agent to its target and the action to be taken to do so. + The paths are derived from a `DistanceMap`. + + If there is no path (rail disconnected), the path is given as None. + The agent state (moving or not) and its speed are not taken into account + + Parameters + ---------- + distance_map + + Returns + ------- + Dict[int, Optional[List[WalkingElement]]] + + """ + shortest_paths = dict() + + def _shortest_path_for_agent(agent): + position = agent.position + direction = agent.direction + shortest_paths[agent.handle] = [] + distance = math.inf + depth = 0 + while (position != agent.target and (max_depth is None or depth < max_depth)): + next_actions = get_valid_move_actions_(direction, position, distance_map.rail) + best_next_action = None + for next_action in next_actions: + next_action_distance = distance_map.get()[ + agent.handle, next_action.next_position[0], next_action.next_position[ + 1], next_action.next_direction] + if next_action_distance < distance: + best_next_action = next_action + distance = next_action_distance + + shortest_paths[agent.handle].append(WalkingElement(position, direction, best_next_action)) + depth += 1 + + # if there is no way to continue, the rail must be disconnected! + # (or distance map is incorrect) + if best_next_action is None: + shortest_paths[agent.handle] = None + return + + position = best_next_action.next_position + direction = best_next_action.next_direction + if max_depth is None or depth < max_depth: + shortest_paths[agent.handle].append( + WalkingElement(position, direction, + RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction))) + + for agent in distance_map.agents: + _shortest_path_for_agent(agent) + + return shortest_paths + + +def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0): + if agent_handle >= distance_map.get().shape[0]: + print("Error: agent_handle cannot be larger than actual number of agents") + return + # take min value of all 4 directions + min_distance_map = np.min(distance_map.get(), axis=3) + plt.imshow(min_distance_map[agent_handle][:][:]) + plt.show() diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 69cfce764fe124d9e3eb05019e1c734f2285bdb5..dc1cff12c0c8b1860859208a13d6403734a2d2ad 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -1,7 +1,3 @@ -import numpy as np -import matplotlib.pyplot as plt - -from flatland.envs.distance_map import DistanceMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -21,13 +17,3 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b schedule_generator=schedule_from_file(file_name, load_from_package), obs_builder_object=obs_builder_object) return environment - - -def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0): - if agent_handle >= distance_map.get().shape[0]: - print("Error: agent_handle cannot be larger than actual number of agents") - return - # take min value of all 4 directions - min_distance_map = np.min(distance_map.get(), axis=3) - plt.imshow(min_distance_map[agent_handle][:][:]) - plt.show() diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 92a0f84f35fa942b03236c6add6e722475a2d842..4dad2ca872725517ffd47308f088f098b6abe1aa 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -174,7 +174,6 @@ class PILGL(GraphicsLayer): self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255)) def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER): - print("Text:", "rc:", rcTopLeft, "text:", strText, "layer:", layer) xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]]) self.text(*xyPixLeftTop, strText, layer) @@ -606,7 +605,6 @@ class PILSVG(PILGL): self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER) if show_debug: - print("Call text:") self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx)) def set_cell_occupied(self, agent_idx, row, col): diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index 6da29d7f6d1a52c42dd006b84f94a959990e0932..a12c26e66fdf5a9ff102bb79440bd4f4b805e819 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -45,6 +45,46 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]: return rail, rail_map +def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]: + # We instantiate a very simple rail network on a 7x10 grid: + # Note that that cells have invalid RailEnvTransitions! + # | + # | + # | + # _ _ _ _\ _ _ _ _ _ + # / + # | + # | + # | + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + simple_switch_north_left = cells[2] + simple_switch_north_right = cells[10] + simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [simple_switch_east_west_north] + + [dead_end_from_west] + [dead_end_from_east] + [simple_switch_east_west_south] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map + + + def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]: # We instantiate a very simple rail network on a 7x10 grid: # | diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index c31494673e63a17dc07eb6d89eeb581c640b1e13..7ee0fd4aadf72b12b591259a71af8b408145418f 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -7,7 +7,8 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv -from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction +from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool @@ -142,6 +143,21 @@ def test_shortest_path_predictor(rendering=False): 1], agent.direction] == 5.0, "found {} instead of {}".format( distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0) + paths = get_shortest_paths(env.distance_map)[0] + assert paths == [ + WalkingElement((5, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), + next_direction=0)), + WalkingElement((4, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), + next_direction=0)), + WalkingElement((3, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7), + next_direction=1)), + WalkingElement((3, 7), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8), + next_direction=1)), + WalkingElement((3, 8), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9), + next_direction=1)), + WalkingElement((3, 9), 1, RailEnvNextAction(action=RailEnvActions.STOP_MOVING, next_position=(3, 9), + next_direction=1))] + # extract the data predictions = env.obs_builder.predictions positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0]))) @@ -220,12 +236,13 @@ def test_shortest_path_predictor(rendering=False): [20.], ]) + assert np.array_equal(time_offsets, expected_time_offsets), \ + "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets) + assert np.array_equal(positions, expected_positions), \ "positions {}, expected {}".format(positions, expected_positions) assert np.array_equal(directions, expected_directions), \ "directions {}, expected {}".format(directions, expected_directions) - assert np.array_equal(time_offsets, expected_time_offsets), \ - "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets) def test_shortest_path_predictor_conflicts(rendering=False): diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py new file mode 100644 index 0000000000000000000000000000000000000000..4600c4a3002995e1238a0ccbda762501ac985408 --- /dev/null +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -0,0 +1,194 @@ +import numpy as np + +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import DummyPredictorForRailEnv +from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions, RailEnv +from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement +from flatland.envs.rail_env_utils import load_flatland_environment_from_file +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.simple_rail import make_disconnected_simple_rail + + +def test_get_shortest_paths_unreachable(): + rail, rail_map = make_disconnected_simple_rail() + + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), + ) + + # set the initial position + agent = env.agents_static[0] + agent.position = (3, 1) # west dead-end + agent.direction = Grid4TransitionsEnum.WEST + agent.target = (3, 9) # east dead-end + agent.moving = True + + # reset to set agents from agents_static + env.reset(False, False) + + actual = get_shortest_paths(env.distance_map) + expected = {0: None} + + assert actual == expected, "actual={},expected={}".format(actual, expected) + + +def test_get_shortest_paths(): + env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') + actual = get_shortest_paths(env.distance_map) + + expected = { + 0: [ + WalkingElement(position=(1, 1), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(1, 2), next_direction=1)), + WalkingElement(position=(1, 2), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(1, 3), next_direction=1)), + WalkingElement(position=(1, 3), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 3), next_direction=2)), + WalkingElement(position=(2, 3), direction=2, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 4), next_direction=1)), + WalkingElement(position=(2, 4), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 5), next_direction=1)), + WalkingElement(position=(2, 5), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 6), next_direction=1)), + WalkingElement(position=(2, 6), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 7), next_direction=1)), + WalkingElement(position=(2, 7), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 8), next_direction=1)), + WalkingElement(position=(2, 8), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 9), next_direction=1)), + WalkingElement(position=(2, 9), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 10), next_direction=1)), + WalkingElement(position=(2, 10), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 11), next_direction=1)), + WalkingElement(position=(2, 11), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 12), next_direction=1)), + WalkingElement(position=(2, 12), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 13), next_direction=1)), + WalkingElement(position=(2, 13), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 14), next_direction=1)), + WalkingElement(position=(2, 14), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 15), next_direction=1)), + WalkingElement(position=(2, 15), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 16), next_direction=1)), + WalkingElement(position=(2, 16), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 17), next_direction=1)), + WalkingElement(position=(2, 17), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 18), next_direction=1)), + WalkingElement(position=(2, 18), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.STOP_MOVING, + next_position=(2, 18), next_direction=1))], + 1: [ + WalkingElement(position=(3, 18), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(3, 17), next_direction=3)), + WalkingElement(position=(3, 17), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(3, 16), next_direction=3)), + WalkingElement(position=(3, 16), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 16), next_direction=0)), + WalkingElement(position=(2, 16), direction=0, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 15), next_direction=3)), + WalkingElement(position=(2, 15), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 14), next_direction=3)), + WalkingElement(position=(2, 14), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 13), next_direction=3)), + WalkingElement(position=(2, 13), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 12), next_direction=3)), + WalkingElement(position=(2, 12), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 11), next_direction=3)), + WalkingElement(position=(2, 11), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 10), next_direction=3)), + WalkingElement(position=(2, 10), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 9), next_direction=3)), + WalkingElement(position=(2, 9), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 8), next_direction=3)), + WalkingElement(position=(2, 8), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 7), next_direction=3)), + WalkingElement(position=(2, 7), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 6), next_direction=3)), + WalkingElement(position=(2, 6), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 5), next_direction=3)), + WalkingElement(position=(2, 5), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 4), next_direction=3)), + WalkingElement(position=(2, 4), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 3), next_direction=3)), + WalkingElement(position=(2, 3), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 2), next_direction=3)), + WalkingElement(position=(2, 2), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(2, 1), next_direction=3)), + WalkingElement(position=(2, 1), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.STOP_MOVING, + next_position=(2, 1), next_direction=3))] + } + + for agent_handle in expected: + assert np.array_equal(actual[agent_handle], expected[agent_handle]), \ + "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle]) + + +def test_get_shortest_paths_max_depth(): + env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') + actual = get_shortest_paths(env.distance_map, max_depth=2) + + expected = { + 0: [ + WalkingElement(position=(1, 1), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(1, 2), next_direction=1)), + WalkingElement(position=(1, 2), direction=1, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(1, 3), next_direction=1)) + ], + 1: [ + WalkingElement(position=(3, 18), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(3, 17), next_direction=3)), + WalkingElement(position=(3, 17), direction=3, + next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, + next_position=(3, 16), next_direction=3)), + ] + } + + for agent_handle in expected: + assert np.array_equal(actual[agent_handle], expected[agent_handle]), \ + "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])