diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index c9c6b00375ef4577880e2b8c98c2ff9dc946a7fa..52ab5fadeab17c153e5bba0414d10a2c593ad338 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -21,12 +21,14 @@ class DistanceMap: """ Set the distance map """ + def set(self, distance_map: np.ndarray): self.distance_map = distance_map """ Get the distance map """ + def get(self) -> np.ndarray: if self.reset_was_called: @@ -54,9 +56,10 @@ class DistanceMap: """ Reset the distance map """ + def reset(self, agents: List[EnvAgent], rail: GridTransitionMap): self.reset_was_called = True - self.agents = agents + self.agents: List[EnvAgent] = agents self.rail = rail self.env_height = rail.height self.env_width = rail.width @@ -110,7 +113,8 @@ class DistanceMap: return max_distance - def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance, enforce_target_direction=-1): + def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance, + enforce_target_direction=-1): """ Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the minimum distances from each target cell. @@ -134,8 +138,7 @@ class DistanceMap: for agent_orientation in range(4): # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation), - desired_movement_from_new_cell) - # is_valid = True + desired_movement_from_new_cell) if is_valid: """ diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 1805e8c6b01a9fb88db082dc0e7de7909800c0b6..f8c54f33088a97f989cd10a120ed848bbf9bf4ae 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -4,7 +4,7 @@ Definition of the RailEnv environment. # TODO: _ this is a global method --> utils or remove later import warnings from enum import IntEnum -from typing import List, Set, NamedTuple, Optional, Tuple, Dict +from typing import List, NamedTuple, Optional, Tuple, Dict import msgpack import msgpack_numpy as m @@ -20,7 +20,6 @@ from flatland.envs.distance_map import DistanceMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator -from flatland.utils.ordered_set import OrderedSet m.patch() @@ -587,60 +586,6 @@ class RailEnv(Environment): transition_valid = True return new_direction, transition_valid - @staticmethod - def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, - agent_position: Tuple[int, int], - rail: GridTransitionMap) -> Set[RailEnvNextAction]: - """ - Get the valid move actions (forward, left, right) for an agent. - - Parameters - ---------- - agent_direction : Grid4TransitionsEnum - agent_position: Tuple[int,int] - rail : GridTransitionMap - - - Returns - ------- - Set of `RailEnvNextAction` (tuples of (action,position,direction)) - Possible move actions (forward,left,right) and the next position/direction they lead to. - It is not checked that the next cell is free. - """ - valid_actions: Set[RailEnvNextAction] = OrderedSet() - possible_transitions = rail.get_transitions(*agent_position, agent_direction) - num_transitions = np.count_nonzero(possible_transitions) - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right], relative to the current orientation - # If only one transition is possible, the forward branch is aligned with it. - if rail.is_dead_end(agent_position): - action = RailEnvActions.MOVE_FORWARD - exit_direction = (agent_direction + 2) % 4 - if possible_transitions[exit_direction]: - new_position = get_new_position(agent_position, exit_direction) - valid_actions.add(RailEnvNextAction(action, new_position, exit_direction)) - elif num_transitions == 1: - action = RailEnvActions.MOVE_FORWARD - for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: - if possible_transitions[new_direction]: - new_position = get_new_position(agent_position, new_direction) - valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) - else: - for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: - if possible_transitions[new_direction]: - if new_direction == agent_direction: - action = RailEnvActions.MOVE_FORWARD - elif new_direction == (agent_direction + 1) % 4: - action = RailEnvActions.MOVE_RIGHT - elif new_direction == (agent_direction - 1) % 4: - action = RailEnvActions.MOVE_LEFT - else: - raise Exception("Illegal state") - - new_position = get_new_position(agent_position, new_direction) - valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) - return valid_actions - def _get_observations(self): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index dc1cff12c0c8b1860859208a13d6403734a2d2ad..200c19dc21de48559dc1e641ab01eb244208f5aa 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -1,8 +1,18 @@ +import math +from typing import Tuple, Set, Dict, List + +import numpy as np + +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.distance_map import DistanceMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv -from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_env import RailEnv, RailEnvNextAction, RailEnvActions from flatland.envs.rail_generators import rail_from_file from flatland.envs.schedule_generators import schedule_from_file +from flatland.utils.ordered_set import OrderedSet def load_flatland_environment_from_file(file_name, load_from_package=None, obs_builder_object=None): @@ -17,3 +27,78 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b schedule_generator=schedule_from_file(file_name, load_from_package), obs_builder_object=obs_builder_object) return environment + + +def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, + agent_position: Tuple[int, int], + rail: GridTransitionMap) -> Set[RailEnvNextAction]: + """ + Get the valid move actions (forward, left, right) for an agent. + + Parameters + ---------- + agent_direction : Grid4TransitionsEnum + agent_position: Tuple[int,int] + rail : GridTransitionMap + + + Returns + ------- + Set of `RailEnvNextAction` (tuples of (action,position,direction)) + Possible move actions (forward,left,right) and the next position/direction they lead to. + It is not checked that the next cell is free. + """ + valid_actions: Set[RailEnvNextAction] = OrderedSet() + possible_transitions = rail.get_transitions(*agent_position, agent_direction) + num_transitions = np.count_nonzero(possible_transitions) + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right], relative to the current orientation + # If only one transition is possible, the forward branch is aligned with it. + if rail.is_dead_end(agent_position): + action = RailEnvActions.MOVE_FORWARD + exit_direction = (agent_direction + 2) % 4 + if possible_transitions[exit_direction]: + new_position = get_new_position(agent_position, exit_direction) + valid_actions.add(RailEnvNextAction(action, new_position, exit_direction)) + elif num_transitions == 1: + action = RailEnvActions.MOVE_FORWARD + for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[new_direction]: + new_position = get_new_position(agent_position, new_direction) + valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) + else: + for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[new_direction]: + if new_direction == agent_direction: + action = RailEnvActions.MOVE_FORWARD + elif new_direction == (agent_direction + 1) % 4: + action = RailEnvActions.MOVE_RIGHT + elif new_direction == (agent_direction - 1) % 4: + action = RailEnvActions.MOVE_LEFT + else: + raise Exception("Illegal state") + + new_position = get_new_position(agent_position, new_direction) + valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) + return valid_actions + + +def get_shorts_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextAction]]: + shortest_paths = dict() + for a in distance_map.agents: + position = a.position + direction = a.direction + shortest_paths[a.handle] = [] + + while (position != a.target): + next_actions = get_valid_move_actions_(direction, position, distance_map.rail) + best = math.inf + + best_next_action = None + for next_action in next_actions: + if distance_map.get()[a.handle, position[0], position[1], direction] < best: + best_next_action = next_action + position = best_next_action.next_position + direction = best_next_action.next_direction + shortest_paths[a.handle].append(best_next_action) + return shortest_paths diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index c31494673e63a17dc07eb6d89eeb581c640b1e13..b2a9b612e923e5a4290b938e2ddd723f03b3c1b1 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -7,7 +7,8 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv -from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction +from flatland.envs.rail_env_utils import get_shorts_paths from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool @@ -142,6 +143,14 @@ def test_shortest_path_predictor(rendering=False): 1], agent.direction] == 5.0, "found {} instead of {}".format( distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0) + paths = get_shorts_paths(env.distance_map)[0] + assert paths == [ + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), next_direction=0), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), next_direction=0), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8), next_direction=1), + RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9), next_direction=1)] + # extract the data predictions = env.obs_builder.predictions positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))