From 8137be647e3447af0194aedab84f5e4aed5913fe Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Tue, 17 Sep 2019 12:10:43 +0200 Subject: [PATCH] Refactoring: move distance_map to separate class --- docs/intro_observationbuilder.rst | 2 +- examples/custom_observation_example.py | 2 +- examples/debugging_example_DELETE.py | 2 +- flatland/core/env.py | 2 - flatland/envs/distance_map.py | 124 +++++++++++++++++++++++ flatland/envs/observations.py | 8 +- flatland/envs/predictions.py | 2 +- flatland/envs/rail_env.py | 105 ++----------------- tests/test_distance_map.py | 8 +- tests/test_flatland_envs_observations.py | 2 +- tests/test_flatland_envs_predictions.py | 2 +- tests/test_flatland_malfunction.py | 2 +- tests/tests_generators.py | 14 +-- 13 files changed, 155 insertions(+), 120 deletions(-) create mode 100644 flatland/envs/distance_map.py diff --git a/docs/intro_observationbuilder.rst b/docs/intro_observationbuilder.rst index 64e953da..4386f9e0 100644 --- a/docs/intro_observationbuilder.rst +++ b/docs/intro_observationbuilder.rst @@ -110,7 +110,7 @@ Note that this simple strategy fails when multiple agents are present, as each a for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: new_position = self._new_position(agent.position, direction) - min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) + min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 18e96a2b..25238d42 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -81,7 +81,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: new_position = get_new_position(agent.position, direction) - min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) + min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 1f0f89de..8aef94c2 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -49,7 +49,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: new_position = get_new_position(agent.position, direction) - min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) + min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) diff --git a/flatland/core/env.py b/flatland/core/env.py index f1f1b270..1bc5b6f3 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -45,8 +45,6 @@ class Environment: def __init__(self): self.action_space = () self.observation_space = () - self.distance_map_computed = False - self.distance_map = None pass def reset(self): diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py new file mode 100644 index 00000000..27820206 --- /dev/null +++ b/flatland/envs/distance_map.py @@ -0,0 +1,124 @@ +from collections import deque +from typing import List + +import numpy as np + +from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_utils import EnvAgent + + +class DistanceMap: + def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int): + self.env_height = env_height + self.env_width = env_width + self.distance_map = np.inf * np.ones(shape=(len(agents), + self.env_height, + self.env_width, + 4)) + self.distance_map_computed = False + + """ + Set the distance map + """ + def set(self, distance_map: np.array): + self.distance_map = distance_map + + """ + Get the distance map + """ + def get(self) -> np.array: + return self.distance_map + + """ + Compute the distance map + """ + def compute(self, agents: List[EnvAgent], rail: GridTransitionMap): + self.distance_map_computed = True + self.distance_map = np.inf * np.ones(shape=(len(agents), + self.env_height, + self.env_width, + 4)) + for i, agent in enumerate(agents): + self._distance_map_walker(rail, agent.target, i) + + def _distance_map_walker(self, rail: GridTransitionMap, position, target_nr: int): + """ + Utility function to compute distance maps from each cell in the rail network (and each possible + orientation within it) to each agent's target cell. + """ + # Returns max distance to target, from the farthest away node, while filling in distance_map + self.distance_map[target_nr, position[0], position[1], :] = 0 + + # Fill in the (up to) 4 neighboring nodes + # direction is the direction of movement, meaning that at least a possible orientation of an agent + # in cell (row,col) allows a movement in direction `direction' + nodes_queue = deque(self._get_and_update_neighbors(rail, position, target_nr, 0, enforce_target_direction=-1)) + + # BFS from target `position' to all the reachable nodes in the grid + # Stop the search if the target position is re-visited, in any direction + visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), + (position[0], position[1], 3)} + + max_distance = 0 + + while nodes_queue: + node = nodes_queue.popleft() + + node_id = (node[0], node[1], node[2]) + + if node_id not in visited: + visited.add(node_id) + + # From the list of possible neighbors that have at least a path to the current node, only keep those + # whose new orientation in the current cell would allow a transition to direction node[2] + valid_neighbors = self._get_and_update_neighbors(rail, (node[0], node[1]), target_nr, node[3], node[2]) + + for n in valid_neighbors: + nodes_queue.append(n) + + if len(valid_neighbors) > 0: + max_distance = max(max_distance, node[3] + 1) + + return max_distance + + def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance, enforce_target_direction=-1): + """ + Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the + minimum distances from each target cell. + """ + neighbors = [] + + possible_directions = [0, 1, 2, 3] + if enforce_target_direction >= 0: + # The agent must land into the current cell with orientation `enforce_target_direction'. + # This is only possible if the agent has arrived from the cell in the opposite direction! + possible_directions = [(enforce_target_direction + 2) % 4] + + for neigh_direction in possible_directions: + new_cell = get_new_position(position, neigh_direction) + + if new_cell[0] >= 0 and new_cell[0] < self.env_height and new_cell[1] >= 0 and new_cell[1] < self.env_width: + + desired_movement_from_new_cell = (neigh_direction + 2) % 4 + + # Check all possible transitions in new_cell + for agent_orientation in range(4): + # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? + is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation), + desired_movement_from_new_cell) + # is_valid = True + + if is_valid: + """ + # TODO: check that it works with deadends! -- still bugged! + movement = desired_movement_from_new_cell + if isNextCellDeadEnd: + movement = (desired_movement_from_new_cell+2) % 4 + """ + new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], + current_distance + 1) + neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) + self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance + + return neighbors diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 3f398c74..ce0ce9b1 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -50,7 +50,7 @@ class TreeObsForRailEnv(ObservationBuilder): if agents[i].target != self.agents_previous_reset[i].target: compute_distance_map = True # Don't compute the distance map if it was loaded - if self.agents_previous_reset is None and self.env.distance_map is not None: + if self.agents_previous_reset is None and self.env.distance_map.get() is not None: self.location_has_target = {tuple(agent.target): 1 for agent in agents} compute_distance_map = False @@ -167,7 +167,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Root node - current position # Here information about the agent itself is stored - observation = [0, 0, 0, 0, 0, 0, self.env.distance_map[(handle, *agent.position, agent.direction)], 0, 0, + observation = [0, 0, 0, 0, 0, 0, self.env.distance_map.get()[(handle, *agent.position, agent.direction)], 0, 0, agent.malfunction_data['malfunction'], agent.speed_data['speed']] visited = set() @@ -397,7 +397,7 @@ class TreeObsForRailEnv(ObservationBuilder): potential_conflict, unusable_switch, np.inf, - self.env.distance_map[handle, position[0], position[1], direction], + self.env.distance_map.get()[handle, position[0], position[1], direction], other_agent_same_direction, other_agent_opposite_direction, malfunctioning_agent, @@ -411,7 +411,7 @@ class TreeObsForRailEnv(ObservationBuilder): potential_conflict, unusable_switch, tot_dist, - self.env.distance_map[handle, position[0], position[1], direction], + self.env.distance_map.get()[handle, position[0], position[1], direction], other_agent_same_direction, other_agent_opposite_direction, malfunctioning_agent, diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index dc5a13b3..c77b5787 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -148,7 +148,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): for direction in range(4): if cell_transitions[direction] == 1: neighbour_cell = get_new_position(agent.position, direction) - target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] + target_dist = distance_map.get()[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] if target_dist < min_dist or no_dist_found: min_dist = target_dist new_direction = direction diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 8265dffa..0dfd4535 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -5,7 +5,6 @@ Definition of the RailEnv environment. import warnings from enum import IntEnum from typing import List -from collections import deque import msgpack import msgpack_numpy as m @@ -15,6 +14,7 @@ from flatland.core.env import Environment from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent +from flatland.envs.distance_map import DistanceMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator @@ -171,6 +171,7 @@ class RailEnv(Environment): self.agents: List[EnvAgent] = [None] * number_of_agents # live agents self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents # static agent information self.num_resets = 0 + self.distance_map = DistanceMap(self.agents, self.height, self.width) self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? @@ -235,7 +236,7 @@ class RailEnv(Environment): rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) if optionals and 'distance_map' in optionals: - self.distance_map = optionals['distance_map'] + self.distance_map.set(optionals['distance_map']) if regen_rail or self.rail is None: self.rail = rail @@ -576,7 +577,7 @@ class RailEnv(Environment): self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]] if "distance_map" in data.keys(): - self.distance_map = data["distance_map"] + self.distance_map.set(data["distance_map"]) # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -590,7 +591,7 @@ class RailEnv(Environment): msgpack.packb(grid_data, use_bin_type=True) msgpack.packb(agent_data, use_bin_type=True) msgpack.packb(agent_static_data, use_bin_type=True) - distance_map_data = self.distance_map + distance_map_data = self.distance_map.get() msgpack.packb(distance_map_data, use_bin_type=True) msg_data = { "grid": grid_data, @@ -601,8 +602,8 @@ class RailEnv(Environment): return msgpack.packb(msg_data, use_bin_type=True) def save(self, filename): - if self.distance_map is not None: - if len(self.distance_map) > 0: + if self.distance_map.get() is not None: + if len(self.distance_map.get()) > 0: with open(filename, "wb") as file_out: file_out.write(self.get_full_state_dist_msg()) else: @@ -626,95 +627,7 @@ class RailEnv(Environment): self.set_full_state_msg(load_data) def compute_distance_map(self): - agents = self.agents - # For testing only --> To assert if a distance map need to be recomputed. - self.distance_map_computed = True - nb_agents = len(agents) - self.distance_map = np.inf * np.ones(shape=(nb_agents, - self.height, - self.width, - 4)) - max_dist = np.zeros(nb_agents) - max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] + self.distance_map.compute(self.agents, self.rail) # Update local lookup table for all agents' target locations - self.obs_builder.location_has_target = {tuple(agent.target): 1 for agent in agents} + self.obs_builder.location_has_target = {tuple(agent.target): 1 for agent in self.agents} - def _distance_map_walker(self, position, target_nr): - """ - Utility function to compute distance maps from each cell in the rail network (and each possible - orientation within it) to each agent's target cell. - """ - # Returns max distance to target, from the farthest away node, while filling in distance_map - self.distance_map[target_nr, position[0], position[1], :] = 0 - - # Fill in the (up to) 4 neighboring nodes - # direction is the direction of movement, meaning that at least a possible orientation of an agent - # in cell (row,col) allows a movement in direction `direction' - nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) - - # BFS from target `position' to all the reachable nodes in the grid - # Stop the search if the target position is re-visited, in any direction - visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), - (position[0], position[1], 3)} - - max_distance = 0 - - while nodes_queue: - node = nodes_queue.popleft() - - node_id = (node[0], node[1], node[2]) - - if node_id not in visited: - visited.add(node_id) - - # From the list of possible neighbors that have at least a path to the current node, only keep those - # whose new orientation in the current cell would allow a transition to direction node[2] - valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) - - for n in valid_neighbors: - nodes_queue.append(n) - - if len(valid_neighbors) > 0: - max_distance = max(max_distance, node[3] + 1) - - return max_distance - - def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): - """ - Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the - minimum distances from each target cell. - """ - neighbors = [] - - possible_directions = [0, 1, 2, 3] - if enforce_target_direction >= 0: - # The agent must land into the current cell with orientation `enforce_target_direction'. - # This is only possible if the agent has arrived from the cell in the opposite direction! - possible_directions = [(enforce_target_direction + 2) % 4] - - for neigh_direction in possible_directions: - new_cell = get_new_position(position, neigh_direction) - - if new_cell[0] >= 0 and new_cell[0] < self.height and new_cell[1] >= 0 and new_cell[1] < self.width: - - desired_movement_from_new_cell = (neigh_direction + 2) % 4 - - # Check all possible transitions in new_cell - for agent_orientation in range(4): - # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? - is_valid = self.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), - desired_movement_from_new_cell) - - if is_valid: - """ - # TODO: check that it works with deadends! -- still bugged! - movement = desired_movement_from_new_cell - if isNextCellDeadEnd: - movement = (desired_movement_from_new_cell+2) % 4 - """ - new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], - current_distance + 1) - neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) - self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance - - return neighbors diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 2653aaec..3bed89b8 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -44,7 +44,7 @@ def test_walker(): # reset to set agents from agents_static env.reset(False, False) - print(env.distance_map[(0, *[0, 1], 1)]) - assert env.distance_map[(0, *[0, 1], 1)] == 3 - print(env.distance_map[(0, *[0, 2], 3)]) - assert env.distance_map[(0, *[0, 2], 1)] == 2 + print(env.distance_map.get()[(0, *[0, 1], 1)]) + assert env.distance_map.get()[(0, *[0, 1], 1)] == 3 + print(env.distance_map.get()[(0, *[0, 2], 3)]) + assert env.distance_map.get()[(0, *[0, 2], 1)] == 2 diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 5d35bfb1..46000de4 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -63,7 +63,7 @@ def _step_along_shortest_path(env, obs_builder, rail): is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation), desired_movement_from_new_cell) if is_valid: - distance_to_target = obs_builder.env.distance_map[ + distance_to_target = obs_builder.env.distance_map.get()[ (agent.handle, *agent.position, exit_direction)] print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position, agent.direction, diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 0221bf6d..c3149467 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -137,7 +137,7 @@ def test_shortest_path_predictor(rendering=False): input("Continue?") # compute the observations and predictions - distance_map = env.distance_map + distance_map = env.distance_map.get() assert distance_map[0, agent.position[0], agent.position[ 1], agent.direction] == 5.0, "found {} instead of {}".format( distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index d1f487ce..81b61381 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -42,7 +42,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: new_position = get_new_position(agent.position, direction) - min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) + min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 43e6e720..8918a585 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -129,7 +129,7 @@ def tests_rail_from_file(): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) env.save(file_name) - dist_map_shape = np.shape(env.distance_map) + dist_map_shape = np.shape(env.distance_map.get()) # initialize agents_static rails_initial = env.rail.grid agents_initial = env.agents @@ -148,9 +148,9 @@ def tests_rail_from_file(): assert agents_initial == agents_loaded # Check that distance map was not recomputed - assert env.distance_map_computed is False - assert np.shape(env.distance_map) == dist_map_shape - assert env.distance_map is not None + assert env.distance_map.distance_map_computed is False + assert np.shape(env.distance_map.get()) == dist_map_shape + assert env.distance_map.get() is not None # Test to save and load file without distance map. @@ -222,6 +222,6 @@ def tests_rail_from_file(): assert agents_initial_2 == agents_loaded_4 # Check that distance map was generated with correct shape - assert env4.distance_map_computed is True - assert env4.distance_map is not None - assert np.shape(env4.distance_map) == dist_map_shape + assert env4.distance_map.distance_map_computed is True + assert env4.distance_map.get() is not None + assert np.shape(env4.distance_map.get()) == dist_map_shape -- GitLab