def _get_valid_moves(
        agent_position: Tuple[int, int],
        agent_direction: Grid4TransitionsEnum,
        rail: GridTransitionMap) -> List[Tuple[RailEnvActions, int]]:
    """
    List the moves available from a cell as (action, new_direction) pairs.

    Shared by `get_new_position_for_action` and `get_action_for_move` so that
    both use exactly the same action <-> direction mapping.

    Parameters
    ----------
    agent_position
        cell the agent is in, unpacked into `rail.get_transitions`
    agent_direction
        direction the agent is currently facing
    rail
        transition map queried via `get_transitions` and `is_dead_end`

    Returns
    -------
    List[Tuple[RailEnvActions, int]]
        one entry per usable transition; empty if there is none.
    """
    possible_transitions = rail.get_transitions(*agent_position, agent_direction)

    # Dead-end: the only way out is the opposite direction, and it is
    # triggered by MOVE_FORWARD.
    if rail.is_dead_end(agent_position):
        exit_direction = (agent_direction + 2) % 4
        if possible_transitions[exit_direction]:
            return [(RailEnvActions.MOVE_FORWARD, exit_direction)]
        return []

    # If only one transition is possible, the forward branch is aligned with it,
    # whatever its orientation relative to the agent.
    single_transition = np.count_nonzero(possible_transitions) == 1

    # Candidates are organized as [left, forward, right], relative to the
    # current orientation.
    relative_moves = (
        (-1, RailEnvActions.MOVE_LEFT),
        (0, RailEnvActions.MOVE_FORWARD),
        (1, RailEnvActions.MOVE_RIGHT),
    )
    moves = []
    for offset, turn_action in relative_moves:
        new_direction = (agent_direction + offset) % 4
        if not possible_transitions[new_direction]:
            continue
        valid_action = RailEnvActions.MOVE_FORWARD if single_transition else turn_action
        moves.append((valid_action, new_direction))
    return moves


def get_new_position_for_action(
        agent_position: Tuple[int, int],
        agent_direction: Grid4TransitionsEnum,
        action: RailEnvActions,
        rail: GridTransitionMap) -> Optional[Tuple[Tuple[int, int], int]]:
    """
    Get the next position and direction for this action.

    Parameters
    ----------
    agent_position
        cell the agent is in
    agent_direction
        direction the agent is currently facing
    action
        the move action to apply; it is compared against `RailEnvActions` members
    rail
        transition map of the environment


    Returns
    -------
    Optional[Tuple[Tuple[int, int], int]]
        (new_position, new_direction) where new_position is the result of
        `get_new_position`, or None if the action does not correspond to any
        transition available from this cell and direction.
    """
    for valid_action, new_direction in _get_valid_moves(agent_position, agent_direction, rail):
        if valid_action == action:
            return get_new_position(agent_position, new_direction), new_direction
    # No available transition is driven by this action.
    return None


def get_action_for_move(
        agent_position: Tuple[int, int],
        agent_direction: Grid4TransitionsEnum,
        next_agent_position: Tuple[int, int],
        next_agent_direction: int,
        rail: GridTransitionMap) -> Optional[RailEnvActions]:
    """
    Get the action (if any) to move from a position and direction to another.

    Parameters
    ----------
    agent_position
        cell the agent is in
    agent_direction
        direction the agent is currently facing
    next_agent_position
        cell the agent should end up in
    next_agent_direction
        direction the agent should be facing afterwards
    rail
        transition map of the environment


    Returns
    -------
    Optional[RailEnvActions]
        the action (if direct transition possible) or None.
    """
    for valid_action, new_direction in _get_valid_moves(agent_position, agent_direction, rail):
        new_position = get_new_position(agent_position, new_direction)
        if new_position == next_agent_position and new_direction == next_agent_direction:
            return valid_action
    # The requested move is not a direct transition from this cell and direction.
    return None
# N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!) def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = None, agent_handle: Optional[int] = None) \ -> Dict[int, Optional[List[WalkingElement]]]: @@ -157,6 +279,7 @@ def get_k_shortest_paths(env: RailEnv, Computes the k shortest paths using modified Dijkstra following pseudo-code https://en.wikipedia.org/wiki/K_shortest_path_routing In contrast to the pseudo-code in wikipedia, we do not a allow for loopy paths. + We add the next_action_element Parameters ---------- @@ -173,7 +296,6 @@ def get_k_shortest_paths(env: RailEnv, ------- List[Tuple[WalkingElement]] We use tuples since we need the path elements to be hashable. - The walking elements do not contain any actions. We use a list of paths in order to keep the order of length. """ @@ -215,7 +337,8 @@ def get_k_shortest_paths(env: RailEnv, # – if u = t then P = P U {Pu} if u.position == target_position: - print(" found of length {} {}".format(len(pu), pu)) + if debug: + print(" found of length {} {}".format(len(pu), pu)) shortest_paths.append(pu) # – if countu ≤ K then @@ -242,9 +365,23 @@ def get_k_shortest_paths(env: RailEnv, pv = pu + (v,) # – insert Pv into B heap.add(pv) + # add actions to shortest paths + shortest_paths_with_action = [] + for p in shortest_paths: + p_with_action = tuple(WalkingElement(position=el.position, + direction=el.direction, + next_action_element=RailEnvNextAction( + action=int(get_action_for_move(el.position, + el.direction, + p[i + 1].position, + p[i + 1].direction, + env.rail)), + next_position=p[i + 1].position, + next_direction=p[i + 1].direction)) for i, el in enumerate(p[:-1])) + shortest_paths_with_action.append(p_with_action) # return P - return shortest_paths + return shortest_paths_with_action def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0): diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py 
b/tests/test_flatland_envs_rail_env_shortest_paths.py index 395f05e885cdb80198df47d2d27e241099a0e4a8..d69e203e94e4ef8a7b68c0c7fe44a4c1f9fde42b 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -395,49 +395,83 @@ def test_get_k_shortest_paths(rendering=False): renderer.render_env(show=True, show_observations=False) input() - actual = get_k_shortest_paths( + actual = set(get_k_shortest_paths( env=env, source_position=initial_position, # west dead-end source_direction=int(initial_direction), # east target_position=target_position, k=10 - ) + )) - expected = [ - (WalkingElement(position=(3, 1), direction=3, next_action_element=None), - WalkingElement(position=(3, 0), direction=3, next_action_element=None), - WalkingElement(position=(3, 1), direction=1, next_action_element=None), - WalkingElement(position=(3, 2), direction=1, next_action_element=None), - WalkingElement(position=(3, 3), direction=1, next_action_element=None), - WalkingElement(position=(3, 4), direction=1, next_action_element=None), - WalkingElement(position=(3, 5), direction=1, next_action_element=None), - WalkingElement(position=(3, 6), direction=1, next_action_element=None), - WalkingElement(position=(4, 6), direction=2, next_action_element=None), - WalkingElement(position=(5, 6), direction=2, next_action_element=None), - WalkingElement(position=(6, 6), direction=2, next_action_element=None), - WalkingElement(position=(5, 6), direction=0, next_action_element=None), - WalkingElement(position=(4, 6), direction=0, next_action_element=None), - WalkingElement(position=(4, 7), direction=1, next_action_element=None), - WalkingElement(position=(4, 8), direction=1, next_action_element=None), - WalkingElement(position=(4, 9), direction=1, next_action_element=None), - WalkingElement(position=(3, 9), direction=0, next_action_element=None)), - (WalkingElement(position=(3, 1), direction=3, next_action_element=None), - 
WalkingElement(position=(3, 0), direction=3, next_action_element=None), - WalkingElement(position=(3, 1), direction=1, next_action_element=None), - WalkingElement(position=(3, 2), direction=1, next_action_element=None), - WalkingElement(position=(3, 3), direction=1, next_action_element=None), - WalkingElement(position=(2, 3), direction=0, next_action_element=None), - WalkingElement(position=(1, 3), direction=0, next_action_element=None), - WalkingElement(position=(0, 3), direction=0, next_action_element=None), - WalkingElement(position=(0, 4), direction=1, next_action_element=None), - WalkingElement(position=(0, 5), direction=1, next_action_element=None), - WalkingElement(position=(0, 6), direction=1, next_action_element=None), - WalkingElement(position=(0, 7), direction=1, next_action_element=None), - WalkingElement(position=(0, 8), direction=1, next_action_element=None), - WalkingElement(position=(0, 9), direction=1, next_action_element=None), - WalkingElement(position=(1, 9), direction=2, next_action_element=None), - WalkingElement(position=(2, 9), direction=2, next_action_element=None), - WalkingElement(position=(3, 9), direction=2, next_action_element=None)) - ] + expected = set([ + ( + WalkingElement(position=(3, 1), direction=3, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 0), next_direction=3)), + WalkingElement(position=(3, 0), direction=3, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 1), next_direction=1)), + WalkingElement(position=(3, 1), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 2), next_direction=1)), + WalkingElement(position=(3, 2), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 3), next_direction=1)), + WalkingElement(position=(3, 3), direction=1, + next_action_element=RailEnvNextAction(action=1, next_position=(2, 3), next_direction=0)), + WalkingElement(position=(2, 3), direction=0, + 
next_action_element=RailEnvNextAction(action=2, next_position=(1, 3), next_direction=0)), + WalkingElement(position=(1, 3), direction=0, + next_action_element=RailEnvNextAction(action=2, next_position=(0, 3), next_direction=0)), + WalkingElement(position=(0, 3), direction=0, + next_action_element=RailEnvNextAction(action=2, next_position=(0, 4), next_direction=1)), + WalkingElement(position=(0, 4), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(0, 5), next_direction=1)), + WalkingElement(position=(0, 5), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(0, 6), next_direction=1)), + WalkingElement(position=(0, 6), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(0, 7), next_direction=1)), + WalkingElement(position=(0, 7), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(0, 8), next_direction=1)), + WalkingElement(position=(0, 8), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(0, 9), next_direction=1)), + WalkingElement(position=(0, 9), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(1, 9), next_direction=2)), + WalkingElement(position=(1, 9), direction=2, + next_action_element=RailEnvNextAction(action=2, next_position=(2, 9), next_direction=2)), + WalkingElement(position=(2, 9), direction=2, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 9), next_direction=2)) + ), + ( + WalkingElement(position=(3, 1), direction=3, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 0), next_direction=3)), + WalkingElement(position=(3, 0), direction=3, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 1), next_direction=1)), + WalkingElement(position=(3, 1), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 2), next_direction=1)), + WalkingElement(position=(3, 2), direction=1, + 
next_action_element=RailEnvNextAction(action=2, next_position=(3, 3), next_direction=1)), + WalkingElement(position=(3, 3), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 4), next_direction=1)), + WalkingElement(position=(3, 4), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 5), next_direction=1)), + WalkingElement(position=(3, 5), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 6), next_direction=1)), + WalkingElement(position=(3, 6), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(4, 6), next_direction=2)), + WalkingElement(position=(4, 6), direction=2, + next_action_element=RailEnvNextAction(action=2, next_position=(5, 6), next_direction=2)), + WalkingElement(position=(5, 6), direction=2, + next_action_element=RailEnvNextAction(action=2, next_position=(6, 6), next_direction=2)), + WalkingElement(position=(6, 6), direction=2, + next_action_element=RailEnvNextAction(action=2, next_position=(5, 6), next_direction=0)), + WalkingElement(position=(5, 6), direction=0, + next_action_element=RailEnvNextAction(action=2, next_position=(4, 6), next_direction=0)), + WalkingElement(position=(4, 6), direction=0, + next_action_element=RailEnvNextAction(action=3, next_position=(4, 7), next_direction=1)), + WalkingElement(position=(4, 7), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(4, 8), next_direction=1)), + WalkingElement(position=(4, 8), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(4, 9), next_direction=1)), + WalkingElement(position=(4, 9), direction=1, + next_action_element=RailEnvNextAction(action=2, next_position=(3, 9), next_direction=0)) + ) + ]) assert actual == expected, "actual={},expected={}".format(actual, expected)