Commit 8137be64 authored by u229589's avatar u229589
Browse files

Refactoring: move distance_map to separate class

parent 53dd3c55
Pipeline #2033 passed with stages
in 62 minutes and 42 seconds
......@@ -110,7 +110,7 @@ Note that this simple strategy fails when multiple agents are present, as each a
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = self._new_position(agent.position, direction)
min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
......
......@@ -81,7 +81,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(agent.position, direction)
min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
......
......@@ -49,7 +49,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(agent.position, direction)
min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
......
......@@ -45,8 +45,6 @@ class Environment:
def __init__(self):
self.action_space = ()
self.observation_space = ()
self.distance_map_computed = False
self.distance_map = None
pass
def reset(self):
......
from collections import deque
from typing import List
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
class DistanceMap:
def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int):
self.env_height = env_height
self.env_width = env_width
self.distance_map = np.inf * np.ones(shape=(len(agents),
self.env_height,
self.env_width,
4))
self.distance_map_computed = False
"""
Set the distance map
"""
def set(self, distance_map: np.array):
self.distance_map = distance_map
"""
Get the distance map
"""
def get(self) -> np.array:
return self.distance_map
"""
Compute the distance map
"""
def compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
self.distance_map_computed = True
self.distance_map = np.inf * np.ones(shape=(len(agents),
self.env_height,
self.env_width,
4))
for i, agent in enumerate(agents):
self._distance_map_walker(rail, agent.target, i)
def _distance_map_walker(self, rail: GridTransitionMap, position, target_nr: int):
"""
Utility function to compute distance maps from each cell in the rail network (and each possible
orientation within it) to each agent's target cell.
"""
# Returns max distance to target, from the farthest away node, while filling in distance_map
self.distance_map[target_nr, position[0], position[1], :] = 0
# Fill in the (up to) 4 neighboring nodes
# direction is the direction of movement, meaning that at least a possible orientation of an agent
# in cell (row,col) allows a movement in direction `direction'
nodes_queue = deque(self._get_and_update_neighbors(rail, position, target_nr, 0, enforce_target_direction=-1))
# BFS from target `position' to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
(position[0], position[1], 3)}
max_distance = 0
while nodes_queue:
node = nodes_queue.popleft()
node_id = (node[0], node[1], node[2])
if node_id not in visited:
visited.add(node_id)
# From the list of possible neighbors that have at least a path to the current node, only keep those
# whose new orientation in the current cell would allow a transition to direction node[2]
valid_neighbors = self._get_and_update_neighbors(rail, (node[0], node[1]), target_nr, node[3], node[2])
for n in valid_neighbors:
nodes_queue.append(n)
if len(valid_neighbors) > 0:
max_distance = max(max_distance, node[3] + 1)
return max_distance
def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance, enforce_target_direction=-1):
"""
Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
minimum distances from each target cell.
"""
neighbors = []
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions = [(enforce_target_direction + 2) % 4]
for neigh_direction in possible_directions:
new_cell = get_new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env_height and new_cell[1] >= 0 and new_cell[1] < self.env_width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
# is_valid = True
if is_valid:
"""
# TODO: check that it works with deadends! -- still bugged!
movement = desired_movement_from_new_cell
if isNextCellDeadEnd:
movement = (desired_movement_from_new_cell+2) % 4
"""
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
current_distance + 1)
neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
return neighbors
......@@ -50,7 +50,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if agents[i].target != self.agents_previous_reset[i].target:
compute_distance_map = True
# Don't compute the distance map if it was loaded
if self.agents_previous_reset is None and self.env.distance_map is not None:
if self.agents_previous_reset is None and self.env.distance_map.get() is not None:
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
compute_distance_map = False
......@@ -167,7 +167,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Root node - current position
# Here information about the agent itself is stored
observation = [0, 0, 0, 0, 0, 0, self.env.distance_map[(handle, *agent.position, agent.direction)], 0, 0,
observation = [0, 0, 0, 0, 0, 0, self.env.distance_map.get()[(handle, *agent.position, agent.direction)], 0, 0,
agent.malfunction_data['malfunction'], agent.speed_data['speed']]
visited = set()
......@@ -397,7 +397,7 @@ class TreeObsForRailEnv(ObservationBuilder):
potential_conflict,
unusable_switch,
np.inf,
self.env.distance_map[handle, position[0], position[1], direction],
self.env.distance_map.get()[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
......@@ -411,7 +411,7 @@ class TreeObsForRailEnv(ObservationBuilder):
potential_conflict,
unusable_switch,
tot_dist,
self.env.distance_map[handle, position[0], position[1], direction],
self.env.distance_map.get()[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
......
......@@ -148,7 +148,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
for direction in range(4):
if cell_transitions[direction] == 1:
neighbour_cell = get_new_position(agent.position, direction)
target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
target_dist = distance_map.get()[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
if target_dist < min_dist or no_dist_found:
min_dist = target_dist
new_direction = direction
......
......@@ -5,7 +5,6 @@ Definition of the RailEnv environment.
import warnings
from enum import IntEnum
from typing import List
from collections import deque
import msgpack
import msgpack_numpy as m
......@@ -15,6 +14,7 @@ from flatland.core.env import Environment
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.distance_map import DistanceMap
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
......@@ -171,6 +171,7 @@ class RailEnv(Environment):
self.agents: List[EnvAgent] = [None] * number_of_agents # live agents
self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents # static agent information
self.num_resets = 0
self.distance_map = DistanceMap(self.agents, self.height, self.width)
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
......@@ -235,7 +236,7 @@ class RailEnv(Environment):
rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
if optionals and 'distance_map' in optionals:
self.distance_map = optionals['distance_map']
self.distance_map.set(optionals['distance_map'])
if regen_rail or self.rail is None:
self.rail = rail
......@@ -576,7 +577,7 @@ class RailEnv(Environment):
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
if "distance_map" in data.keys():
self.distance_map = data["distance_map"]
self.distance_map.set(data["distance_map"])
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
......@@ -590,7 +591,7 @@ class RailEnv(Environment):
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
distance_map_data = self.distance_map
distance_map_data = self.distance_map.get()
msgpack.packb(distance_map_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
......@@ -601,8 +602,8 @@ class RailEnv(Environment):
return msgpack.packb(msg_data, use_bin_type=True)
def save(self, filename):
if self.distance_map is not None:
if len(self.distance_map) > 0:
if self.distance_map.get() is not None:
if len(self.distance_map.get()) > 0:
with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_dist_msg())
else:
......@@ -626,95 +627,7 @@ class RailEnv(Environment):
self.set_full_state_msg(load_data)
def compute_distance_map(self):
agents = self.agents
# For testing only --> To assert if a distance map need to be recomputed.
self.distance_map_computed = True
nb_agents = len(agents)
self.distance_map = np.inf * np.ones(shape=(nb_agents,
self.height,
self.width,
4))
max_dist = np.zeros(nb_agents)
max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
self.distance_map.compute(self.agents, self.rail)
# Update local lookup table for all agents' target locations
self.obs_builder.location_has_target = {tuple(agent.target): 1 for agent in agents}
self.obs_builder.location_has_target = {tuple(agent.target): 1 for agent in self.agents}
def _distance_map_walker(self, position, target_nr):
"""
Utility function to compute distance maps from each cell in the rail network (and each possible
orientation within it) to each agent's target cell.
"""
# Returns max distance to target, from the farthest away node, while filling in distance_map
self.distance_map[target_nr, position[0], position[1], :] = 0
# Fill in the (up to) 4 neighboring nodes
# direction is the direction of movement, meaning that at least a possible orientation of an agent
# in cell (row,col) allows a movement in direction `direction'
nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
# BFS from target `position' to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
(position[0], position[1], 3)}
max_distance = 0
while nodes_queue:
node = nodes_queue.popleft()
node_id = (node[0], node[1], node[2])
if node_id not in visited:
visited.add(node_id)
# From the list of possible neighbors that have at least a path to the current node, only keep those
# whose new orientation in the current cell would allow a transition to direction node[2]
valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
for n in valid_neighbors:
nodes_queue.append(n)
if len(valid_neighbors) > 0:
max_distance = max(max_distance, node[3] + 1)
return max_distance
def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
"""
Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
minimum distances from each target cell.
"""
neighbors = []
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions = [(enforce_target_direction + 2) % 4]
for neigh_direction in possible_directions:
new_cell = get_new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.height and new_cell[1] >= 0 and new_cell[1] < self.width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
is_valid = self.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
if is_valid:
"""
# TODO: check that it works with deadends! -- still bugged!
movement = desired_movement_from_new_cell
if isNextCellDeadEnd:
movement = (desired_movement_from_new_cell+2) % 4
"""
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
current_distance + 1)
neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
return neighbors
......@@ -44,7 +44,7 @@ def test_walker():
# reset to set agents from agents_static
env.reset(False, False)
print(env.distance_map[(0, *[0, 1], 1)])
assert env.distance_map[(0, *[0, 1], 1)] == 3
print(env.distance_map[(0, *[0, 2], 3)])
assert env.distance_map[(0, *[0, 2], 1)] == 2
print(env.distance_map.get()[(0, *[0, 1], 1)])
assert env.distance_map.get()[(0, *[0, 1], 1)] == 3
print(env.distance_map.get()[(0, *[0, 2], 3)])
assert env.distance_map.get()[(0, *[0, 2], 1)] == 2
......@@ -63,7 +63,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation),
desired_movement_from_new_cell)
if is_valid:
distance_to_target = obs_builder.env.distance_map[
distance_to_target = obs_builder.env.distance_map.get()[
(agent.handle, *agent.position, exit_direction)]
print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position,
agent.direction,
......
......@@ -137,7 +137,7 @@ def test_shortest_path_predictor(rendering=False):
input("Continue?")
# compute the observations and predictions
distance_map = env.distance_map
distance_map = env.distance_map.get()
assert distance_map[0, agent.position[0], agent.position[
1], agent.direction] == 5.0, "found {} instead of {}".format(
distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
......
......@@ -42,7 +42,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(agent.position, direction)
min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
......
......@@ -129,7 +129,7 @@ def tests_rail_from_file():
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.save(file_name)
dist_map_shape = np.shape(env.distance_map)
dist_map_shape = np.shape(env.distance_map.get())
# initialize agents_static
rails_initial = env.rail.grid
agents_initial = env.agents
......@@ -148,9 +148,9 @@ def tests_rail_from_file():
assert agents_initial == agents_loaded
# Check that distance map was not recomputed
assert env.distance_map_computed is False
assert np.shape(env.distance_map) == dist_map_shape
assert env.distance_map is not None
assert env.distance_map.distance_map_computed is False
assert np.shape(env.distance_map.get()) == dist_map_shape
assert env.distance_map.get() is not None
# Test to save and load file without distance map.
......@@ -222,6 +222,6 @@ def tests_rail_from_file():
assert agents_initial_2 == agents_loaded_4
# Check that distance map was generated with correct shape
assert env4.distance_map_computed is True
assert env4.distance_map is not None
assert np.shape(env4.distance_map) == dist_map_shape
assert env4.distance_map.distance_map_computed is True
assert env4.distance_map.get() is not None
assert np.shape(env4.distance_map.get()) == dist_map_shape
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment