Commit 8fb117ba authored by u229589's avatar u229589
Browse files

Refactoring: move distance_map from ObservationBuilder to RailEnv

parent 8f8465df
Pipeline #2002 passed with stages
in 32 minutes and 8 seconds
......@@ -63,7 +63,7 @@ cells and orientations to the target of each agent, i.e. a distance map for each
In this example we exploit these distance maps by implementing an observation builder that shows the current shortest path for each agent as a one-hot observation vector of length 3, whose components represent the possible directions an agent can take (LEFT, FORWARD, RIGHT). All values of the observation vector are set to :code:`0` except for the shortest direction where it is set to :code:`1`.
Using this observation with highly engineered features indicating the agent's shortest path, an agent can then learn to take the corresponding action at each time-step; or we could even hardcode the optimal policy.
Using this observation with highly engineered features indicating the agent's shortest path, an agent can then learn to take the corresponding action at each time-step; or we could even hardcode the optimal policy.
Note that this simple strategy fails when multiple agents are present, as each agent would only attempt its greedy solution, which is not usually `Pareto-optimal <https://en.wikipedia.org/wiki/Pareto_efficiency>`_ in this context.
.. _TreeObsForRailEnv: https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
......@@ -71,7 +71,7 @@ Note that this simple strategy fails when multiple agents are present, as each a
.. code-block:: python
from flatland.envs.observations import TreeObsForRailEnv
class SingleAgentNavigationObs(TreeObsForRailEnv):
"""
We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
......@@ -84,7 +84,7 @@ Note that this simple strategy fails when multiple agents are present, as each a
"""
def __init__(self):
super().__init__(max_depth=0)
# We set max_depth=0 in because we only need to look at the current
# We set max_depth=0 in because we only need to look at the current
# position of the agent to decide what direction is shortest.
self.observation_space = [3]
......@@ -110,7 +110,7 @@ Note that this simple strategy fails when multiple agents are present, as each a
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = self._new_position(agent.position, direction)
min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
......@@ -175,28 +175,28 @@ In contrast to the previous examples we also implement the :code:`def get_many(s
.. _example: https://gitlab.aicrowd.com/flatland/flatland/blob/master/examples/custom_observation_example.py#L110
.. code-block:: python
class ObservePredictions(TreeObsForRailEnv):
"""
We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation.
We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
the minimum distances from each grid node to each agent's target.
This is necessary so that we can pass the distance map to the ShortestPathPredictor
Here we also want to highlight how you can visualize your observation
"""
def __init__(self, predictor):
super().__init__(max_depth=0)
self.observation_space = [10]
self.predictor = predictor
def reset(self):
# Recompute the distance map, if the environment has changed.
super().reset()
def get_many(self, handles=None):
'''
Because we do not want to call the predictor seperately for every agent we implement the get_many function
......@@ -204,9 +204,9 @@ In contrast to the previous examples we also implement the :code:`def get_many(s
:param handles:
:return:
'''
self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map})
self.predicted_pos = {}
for t in range(len(self.predictions[0])):
pos_list = []
......@@ -215,47 +215,47 @@ In contrast to the previous examples we also implement the :code:`def get_many(s
# We transform (x,y) coodrinates to a single integer number for simpler comparison
self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
observations = {}
# Collect all the different observation for all the agents
for h in handles:
observations[h] = self.get(h)
return observations
def get(self, handle):
'''
Lets write a simple observation which just indicates whether or not the own predicted path
overlaps with other predicted paths at any time. This is useless for the task of navigation but might
help when looking for conflicts. A more complex implementation can be found in the TreeObsForRailEnv class
Each agent recieves an observation of length 10, where each element represents a prediction step and its value
is:
- 0 if no overlap is happening
- 1 where n i the number of other paths crossing the predicted cell
:param handle: handeled as an index of an agent
:return: Observation of handle
'''
observation = np.zeros(10)
# We are going to track what cells where considered while building the obervation and make them accesible
# For rendering
visited = set()
for _idx in range(10):
# Check if any of the other prediction overlap with agents own predictions
x_coord = self.predictions[handle][_idx][1]
y_coord = self.predictions[handle][_idx][2]
# We add every observed cell to the observation rendering
visited.add((x_coord, y_coord))
if self.predicted_pos[_idx][handle] in np.delete(self.predicted_pos[_idx], handle, 0):
# We detect if another agent is predicting to pass through the same cell at the same predicted time
observation[handle] = 1
# This variable will be access by the renderer to visualize the observation
self.env.dev_obs_dict[handle] = visited
return observation
We can then use this new observation builder and the renderer to visualize the observation of each agent.
......@@ -265,23 +265,23 @@ We can then use this new observation builder and the renderer to visualize the o
# Initiate the Predictor
CustomPredictor = ShortestPathPredictorForRailEnv(10)
# Pass the Predictor to the observation builder
CustomObsBuilder = ObservePredictions(CustomPredictor)
# Initiate Environment
env = RailEnv(width=10,
height=10,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
number_of_agents=3,
obs_builder_object=CustomObsBuilder)
obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
# We render the initial step and show the obsered cells as colored boxes
env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
action_dict = {}
for step in range(100):
for a in range(env.get_num_agents()):
......@@ -321,7 +321,7 @@ These two objects can be used for example to detect switches that are usable by
cell_transitions = self.env.rail.get_transitions(*position, direction)
transition_bit = bin(self.env.rail.get_full_transitions(*position))
total_transitions = transition_bit.count("1")
num_transitions = np.count_nonzero(cell_transitions)
......@@ -357,7 +357,7 @@ Beyond the basic agent information we can also access more details about the age
Similar to the speed data you can also access individual data about the malfunctions of an agent. All data is available through :code:`agent.malfunction_data` with:
- Indication how long the agent is still malfunctioning :code:`'malfunction'` by an integer counting down at each time step. 0 means the agent is ok and can move.
- Indication how long the agent is still malfunctioning :code:`'malfunction'` by an integer counting down at each time step. 0 means the agent is ok and can move.
- Possion rate at which malfunctions happen for this agent :code:`'malfunction_rate'`
- Number of steps untill next malfunction will occur :code:`'next_malfunction'`
- Number of malfunctions an agent have occured for this agent so far :code:`nr_malfunctions'`
......
......@@ -79,8 +79,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = self._new_position(agent.position, direction)
min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
new_position = self.new_position(agent.position, direction)
min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
......@@ -140,7 +140,7 @@ class ObservePredictions(TreeObsForRailEnv):
:return:
'''
self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map})
self.predicted_pos = {}
for t in range(len(self.predictions[0])):
......
......@@ -47,8 +47,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = self._new_position(agent.position, direction)
min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
new_position = self.new_position(agent.position, direction)
min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
......
......@@ -45,6 +45,8 @@ class Environment:
def __init__(self):
self.action_space = ()
self.observation_space = ()
self.distance_map_computed = False
self.distance_map = None
pass
def reset(self):
......
......@@ -2,7 +2,6 @@
Collection of environment-specific ObservationBuilder.
"""
import pprint
from collections import deque
import numpy as np
......@@ -39,8 +38,6 @@ class TreeObsForRailEnv(ObservationBuilder):
self.agents_previous_reset = None
self.tree_explored_actions = [1, 2, 3, 0]
self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
self.distance_map = None
self.distance_map_computed = False
def reset(self):
agents = self.env.agents
......@@ -52,110 +49,16 @@ class TreeObsForRailEnv(ObservationBuilder):
if agents[i].target != self.agents_previous_reset[i].target:
compute_distance_map = True
# Don't compute the distance map if it was loaded
if self.agents_previous_reset is None and self.distance_map is not None:
if self.agents_previous_reset is None and self.env.distance_map is not None:
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
compute_distance_map = False
if compute_distance_map:
self._compute_distance_map()
self.env.compute_distance_map()
self.agents_previous_reset = agents
def _compute_distance_map(self):
agents = self.env.agents
# For testing only --> To assert if a distance map need to be recomputed.
self.distance_map_computed = True
nb_agents = len(agents)
self.distance_map = np.inf * np.ones(shape=(nb_agents,
self.env.height,
self.env.width,
4))
self.max_dist = np.zeros(nb_agents)
self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
# Update local lookup table for all agents' target locations
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
def _distance_map_walker(self, position, target_nr):
"""
Utility function to compute distance maps from each cell in the rail network (and each possible
orientation within it) to each agent's target cell.
"""
# Returns max distance to target, from the farthest away node, while filling in distance_map
self.distance_map[target_nr, position[0], position[1], :] = 0
# Fill in the (up to) 4 neighboring nodes
# direction is the direction of movement, meaning that at least a possible orientation of an agent
# in cell (row,col) allows a movement in direction `direction'
nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
# BFS from target `position' to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
(position[0], position[1], 3)}
max_distance = 0
while nodes_queue:
node = nodes_queue.popleft()
node_id = (node[0], node[1], node[2])
if node_id not in visited:
visited.add(node_id)
# From the list of possible neighbors that have at least a path to the current node, only keep those
# whose new orientation in the current cell would allow a transition to direction node[2]
valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
for n in valid_neighbors:
nodes_queue.append(n)
if len(valid_neighbors) > 0:
max_distance = max(max_distance, node[3] + 1)
return max_distance
def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
"""
Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
minimum distances from each target cell.
"""
neighbors = []
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions = [(enforce_target_direction + 2) % 4]
for neigh_direction in possible_directions:
new_cell = self._new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
if is_valid:
"""
# TODO: check that it works with deadends! -- still bugged!
movement = desired_movement_from_new_cell
if isNextCellDeadEnd:
movement = (desired_movement_from_new_cell+2) % 4
"""
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
current_distance + 1)
neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
return neighbors
def _new_position(self, position, movement):
def new_position(self, position, movement):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
......@@ -180,7 +83,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.max_prediction_depth = 0
self.predicted_pos = {}
self.predicted_dir = {}
self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map})
if self.predictions:
for t in range(len(self.predictions[0])):
......@@ -276,7 +179,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Root node - current position
# Here information about the agent itself is stored
observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0,
observation = [0, 0, 0, 0, 0, 0, self.env.distance_map[(handle, *agent.position, agent.direction)], 0, 0,
agent.malfunction_data['malfunction'], agent.speed_data['speed']]
visited = set()
......@@ -291,7 +194,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
new_cell = self.new_position(agent.position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, 1, 1)
observation = observation + branch_observation
......@@ -464,7 +367,7 @@ class TreeObsForRailEnv(ObservationBuilder):
exploring = True
# convert one-hot encoding to 0,1,2,3
direction = np.argmax(cell_transitions)
position = self._new_position(position, direction)
position = self.new_position(position, direction)
num_steps += 1
tot_dist += 1
elif num_transitions > 0:
......@@ -506,7 +409,7 @@ class TreeObsForRailEnv(ObservationBuilder):
potential_conflict,
unusable_switch,
np.inf,
self.distance_map[handle, position[0], position[1], direction],
self.env.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
......@@ -520,7 +423,7 @@ class TreeObsForRailEnv(ObservationBuilder):
potential_conflict,
unusable_switch,
tot_dist,
self.distance_map[handle, position[0], position[1], direction],
self.env.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
......@@ -537,7 +440,7 @@ class TreeObsForRailEnv(ObservationBuilder):
(branch_direction + 2) % 4):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
# it back
new_cell = self._new_position(position, (branch_direction + 2) % 4)
new_cell = self.new_position(position, (branch_direction + 2) % 4)
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
(branch_direction + 2) % 4,
......@@ -547,7 +450,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if len(branch_visited) != 0:
visited = visited.union(branch_visited)
elif last_is_switch and possible_transitions[branch_direction]:
new_cell = self._new_position(position, branch_direction)
new_cell = self.new_position(position, branch_direction)
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
branch_direction,
......
......@@ -5,6 +5,7 @@ Definition of the RailEnv environment.
import warnings
from enum import IntEnum
from typing import List
from collections import deque
import msgpack
import msgpack_numpy as m
......@@ -143,6 +144,7 @@ class RailEnv(Environment):
file_name: you can load a pickle file. from previously saved *.pkl file
"""
super().__init__()
self.rail_generator: RailGenerator = rail_generator
self.schedule_generator: ScheduleGenerator = schedule_generator
......@@ -233,7 +235,7 @@ class RailEnv(Environment):
rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
if optionals and 'distance_maps' in optionals:
self.obs_builder.distance_map = optionals['distance_maps']
self.distance_map = optionals['distance_maps']
if regen_rail or self.rail is None:
self.rail = rail
......@@ -573,8 +575,8 @@ class RailEnv(Environment):
# agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
if hasattr(self.obs_builder, 'distance_map') and "distance_maps" in data.keys():
self.obs_builder.distance_map = data["distance_maps"]
if hasattr(self, 'distance_map') and "distance_maps" in data.keys():
self.distance_map = data["distance_maps"]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
......@@ -588,8 +590,8 @@ class RailEnv(Environment):
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
if hasattr(self.obs_builder, 'distance_map'):
distance_map_data = self.obs_builder.distance_map
if hasattr(self, 'distance_map'):
distance_map_data = self.distance_map
msgpack.packb(distance_map_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
......@@ -605,8 +607,8 @@ class RailEnv(Environment):
return msgpack.packb(msg_data, use_bin_type=True)
def save(self, filename):
if hasattr(self.obs_builder, 'distance_map'):
if len(self.obs_builder.distance_map) > 0:
if hasattr(self, 'distance_map') and self.distance_map is not None:
if len(self.distance_map) > 0:
with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_dist_msg())
else:
......@@ -617,7 +619,7 @@ class RailEnv(Environment):
file_out.write(self.get_full_state_msg())
def load(self, filename):
if hasattr(self.obs_builder, 'distance_map'):
if hasattr(self, 'distance_map'):
with open(filename, "rb") as file_in:
load_data = file_in.read()
self.set_full_state_dist_msg(load_data)
......@@ -633,3 +635,97 @@ class RailEnv(Environment):
from importlib_resources import read_binary
load_data = read_binary(package, resource)
self.set_full_state_msg(load_data)
def compute_distance_map(self):
agents = self.agents
# For testing only --> To assert if a distance map need to be recomputed.
self.distance_map_computed = True
nb_agents = len(agents)
self.distance_map = np.inf * np.ones(shape=(nb_agents,
self.height,
self.width,
4))
max_dist = np.zeros(nb_agents)
max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
# Update local lookup table for all agents' target locations
self.obs_builder.location_has_target = {tuple(agent.target): 1 for agent in agents}
def _distance_map_walker(self, position, target_nr):
"""
Utility function to compute distance maps from each cell in the rail network (and each possible
orientation within it) to each agent's target cell.
"""
# Returns max distance to target, from the farthest away node, while filling in distance_map
self.distance_map[target_nr, position[0], position[1], :] = 0
# Fill in the (up to) 4 neighboring nodes
# direction is the direction of movement, meaning that at least a possible orientation of an agent
# in cell (row,col) allows a movement in direction `direction'
nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
# BFS from target `position' to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
(position[0], position[1], 3)}
max_distance = 0
while nodes_queue:
node = nodes_queue.popleft()
node_id = (node[0], node[1], node[2])
if node_id not in visited:
visited.add(node_id)
# From the list of possible neighbors that have at least a path to the current node, only keep those
# whose new orientation in the current cell would allow a transition to direction node[2]
valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
for n in valid_neighbors:
nodes_queue.append(n)
if len(valid_neighbors) > 0:
max_distance = max(max_distance, node[3] + 1)
return max_distance
def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
"""
Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
minimum distances from each target cell.
"""
neighbors = []
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions = [(enforce_target_direction + 2) % 4]
for neigh_direction in possible_directions:
new_cell = self.obs_builder.new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.height and new_cell[1] >= 0 and new_cell[1] < self.width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
is_valid = self.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
if is_valid:
"""
# TODO: check that it works with deadends! -- still bugged!
movement = desired_movement_from_new_cell
if isNextCellDeadEnd:
movement = (desired_movement_from_new_cell+2) % 4
"""
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
current_distance + 1)
neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
return neighbors
......@@ -43,9 +43,8 @@ def test_walker():
# reset to set agents from agents_static
env.reset(False, False)
obs_builder: TreeObsForRailEnv = env.obs_builder
print(obs_builder.distance_map[(0, *[0, 1], 1)])
assert obs_builder.distance_map[(0, *[0, 1], 1)] == 3
print(obs_builder.distance_map[(0, *[0, 2], 3)])
assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2
print(env.distance_map[(0, *[0, 1], 1)])
assert env.distance_map[(0, *[0, 1], 1)] == 3
print(env.distance_map[(0, *[0, 2], 3)])
assert env.distance_map[(0, *[0, 2], 1)] == 2
......@@ -51,7 +51,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
shortest_distance = np.inf