"""
ObservationBuilder objects are objects that can be passed to environments designed for customizability.
The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle).
+ Reset() is called after each environment reset, to allow for pre-computing relevant data.
+ Get() is called whenever an observation has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""
import numpy as np
from collections import deque
class ObservationBuilder:
    """
    Base class for observation builders.

    A concrete builder is attached to an environment through `_set_env` and must
    override both `reset` and `get`.
    """

    def __init__(self):
        """Create a builder with no environment attached yet."""

    def _set_env(self, env):
        """Attach the environment from whose state observations are computed."""
        self.env = env

    def reset(self):
        """
        Hook called after each environment reset, so subclasses can pre-compute data.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def get(self, handle=0):
        """
        Compute an observation for the `env` environment, possibly for a single agent.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the observation (default 0).

        Returns
        -------
        object
            An observation structure whose format is specific to the concrete builder.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError
class TreeObsForRailEnv(ObservationBuilder):
    """
    TreeObsForRailEnv object.

    This object returns observation vectors for agents in the RailEnv environment.
    The information is local to each agent and exploits the tree structure of the rail
    network to simplify the representation of the state of the environment for each agent.

    Parameters
    ----------
    max_depth : int
        Maximum depth of the observation tree explored from the agent's position.
    """

    def __init__(self, max_depth):
        super(TreeObsForRailEnv, self).__init__()
        self.max_depth = max_depth

    def reset(self):
        """
        Pre-compute the per-agent distance maps and the lookup table of target locations.

        Must be called after `_set_env` and after every environment reset.
        """
        # distance_map[agent, row, col, orientation]: BFS distance from (row, col, orientation)
        # to that agent's target; cells never reached by the search keep +inf.
        self.distance_map = np.full((self.env.number_of_agents,
                                     self.env.height,
                                     self.env.width,
                                     4), np.inf)
        # Per-agent value returned by _distance_map_walker.
        self.max_dist = np.zeros(self.env.number_of_agents)
        for i in range(self.env.number_of_agents):
            self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)

        # Local lookup table: (row, col) -> number of agents whose target lies in that cell.
        # Counting (instead of flagging) lets _explore_branch discount the observed agent's
        # own target while still detecting other agents' targets in the same cell.
        self.location_has_target = {}
        for loc in self.env.agents_target:
            key = (loc[0], loc[1])
            self.location_has_target[key] = self.location_has_target.get(key, 0) + 1

    def _distance_map_walker(self, position, target_nr):
        """
        Fill in `self.distance_map[target_nr]` by a breadth-first search from a target cell.

        The search walks the rail network backwards from `position`. A (row, col, orientation)
        node is reached when an agent in it can move one step towards an already-reached cell.

        Parameters
        ----------
        position : sequence of int
            (row, col) of the target cell of agent `target_nr`.
        target_nr : int
            Index of the agent whose distance map is filled in.

        Returns
        -------
        int
            Maximum, over the expanded nodes that had at least one valid neighbour, of the
            node's distance + 1 (0 if no such node exists).
        """
        target_row, target_col = position[0], position[1]

        # The target itself is at distance 0 in every orientation.
        for ori in range(4):
            self.distance_map[target_nr, target_row, target_col, ori] = 0

        # Queue entries are (row, col, orientation, distance): an agent in cell (row, col) with
        # `orientation` can move towards an already-reached cell.
        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0,
                                                           enforce_target_direction=-1))

        # The target cell is pre-marked as visited in all orientations, so the search stops
        # whenever it reaches the target again.
        visited = {(target_row, target_col, ori) for ori in range(4)}

        max_distance = 0
        while nodes_queue:
            node = nodes_queue.popleft()
            node_id = (node[0], node[1], node[2])
            if node_id in visited:
                continue
            visited.add(node_id)

            # Of the neighbours that have a path into the current cell, keep only those from
            # which the agent lands here with orientation node[2].
            valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr,
                                                             node[3], node[2])
            nodes_queue.extend(valid_neighbors)
            if valid_neighbors:
                max_distance = max(max_distance, node[3] + 1)
        return max_distance

    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
        """
        Return the BFS neighbours of `position` and relax their distance-map entries.

        A neighbour is a (row, col, orientation) triple in an adjacent cell from which an agent
        can move into `position`. Its entry in `self.distance_map[target_nr]` is lowered to
        `current_distance + 1` when that is smaller.

        Parameters
        ----------
        position : sequence of int
            (row, col) of the current cell.
        target_nr : int
            Index of the agent whose distance map is updated.
        current_distance : int or float
            Distance already assigned to the current cell.
        enforce_target_direction : int, optional
            If >= 0, only neighbours from which the agent lands in the current cell with this
            orientation are considered; -1 (default) allows every direction.

        Returns
        -------
        list of tuple
            (row, col, orientation, distance) for each valid neighbour.
        """
        neighbors = []

        if enforce_target_direction >= 0:
            # Landing in the current cell with orientation `enforce_target_direction` is only
            # possible when arriving from the cell lying in the opposite direction.
            possible_directions = [(enforce_target_direction + 2) % 4]
        else:
            possible_directions = [0, 1, 2, 3]

        for neigh_direction in possible_directions:
            new_cell = self._new_position(position, neigh_direction)
            if not (0 <= new_cell[0] < self.env.height and 0 <= new_cell[1] < self.env.width):
                continue

            # Movement an agent in `new_cell` must make to enter the current cell.
            desired_movement_from_new_cell = (neigh_direction + 2) % 4

            # TODO: dead-end cells (where the agent turns around) are not handled specially here;
            # an earlier attempt at doing so was left unfinished.
            for agent_orientation in range(4):
                is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
                                                        desired_movement_from_new_cell)
                if is_valid:
                    new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
                                       current_distance + 1)
                    neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
                    self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance

        return neighbors

    def _new_position(self, position, movement):
        """
        Convert a compass movement (0=N, 1=E, 2=S, 3=W) over the 2D grid into a new (row, col).

        Returns None for any other `movement` value.
        """
        if movement == 0:  # NORTH
            return (position[0] - 1, position[1])
        elif movement == 1:  # EAST
            return (position[0], position[1] + 1)
        elif movement == 2:  # SOUTH
            return (position[0] + 1, position[1])
        elif movement == 3:  # WEST
            return (position[0], position[1] - 1)

    @staticmethod
    def _num_padding_nodes(levels):
        """
        Return the number of nodes in a complete 4-ary tree with `levels` levels,
        i.e. 1 + 4 + ... + 4**(levels - 1); 0 when `levels` <= 0.
        """
        if levels <= 0:
            return 0
        return (4 ** levels - 1) // 3

    @staticmethod
    def _count_set_bits(value):
        """Return the number of 1 bits in the non-negative integer `value` (0 if `value` <= 0)."""
        count = 0
        while value > 0:
            count += value & 1
            value >>= 1
        return count

    def get(self, handle):
        """
        Compute the current observation for agent `handle` in env.

        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
        the transitions. The order is:
            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']

        Each branch data is organized as:
            [root node information] +
            [recursive branch data from 'left'] +
            [... from 'forward'] +
            [... from 'right'] +
            [... from 'back']

        Finally, each node information is composed of 5 floating point values:

        #1: currently unused, always 0.
        #2: 1 if a target of another agent is detected between the previous node and the current one.
        #3: 1 if another agent is detected between the previous node and the current one.
        #4: distance of agent to the current branch node.
        #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
            branch).

        Missing/padding nodes are filled in with -inf.
        Nodes that end in a cycle or an invalid cell have +inf in #4 and #5.

        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
        In case the target node is reached, the values are [0, #2, #3, distance of agent to the target, 0].

        Parameters
        ----------
        handle : int
            Index of the agent to observe.

        Returns
        -------
        list
            Flattened observation tree as described above.
        """
        # Update local lookup table for all agents' positions
        self.location_has_agent = {(loc[0], loc[1]): 1 for loc in self.env.agents_position}

        position = self.env.agents_position[handle]
        orientation = self.env.agents_direction[handle]

        # Root node - current position
        observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
        root_observation = observation[:]

        # Padding for a branch that cannot be taken: a whole subtree rooted at depth 1.
        missing_branch = [-np.inf] * 5 * self._num_padding_nodes(self.max_depth)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # TODO: Adjust this to the novel movement dynamics --> Only Forward present when one transition is possible.
        for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
            if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
                new_cell = self._new_position(position, branch_direction)
                observation.extend(self._explore_branch(handle, new_cell, branch_direction, root_observation, 1))
            else:
                observation.extend(missing_branch)
        return observation

    def _explore_branch(self, handle, position, direction, root_observation, depth):
        """
        Recursively build the observation subtree for the branch entered at `position`.

        The walk follows the rail from `position` in `direction` until it reaches a switch, a
        dead-end, the agent's own target, an already-visited cell/direction (cycle) or a cell
        with no transitions; that cell becomes the branch node, whose children are explored
        in turn up to `self.max_depth`.

        Parameters
        ----------
        handle : int
            Index of the agent being observed.
        position : tuple of int
            (row, col) of the first cell of the branch.
        direction : int
            Direction of movement when entering `position` (0=N, 1=E, 2=S, 3=W).
        root_observation : list
            Feature vector of the parent node; its element 3 (distance travelled) is extended.
        depth : int
            Depth of the node being built (children of the agent's own cell are at depth 1).

        Returns
        -------
        list
            Flattened features of this node followed by its 4 child subtrees, or [] when
            `depth` exceeds `self.max_depth`.
        """
        # [Recursive branch opened]
        if depth >= self.max_depth + 1:
            return []

        own_target = (self.env.agents_target[handle][0], self.env.agents_target[handle][1])

        # Continue along direction until next switch or
        # until no transitions are possible along the current direction (i.e., dead-ends)
        # We treat dead-ends as nodes, instead of going back, to avoid loops
        exploring = True
        last_is_switch = False
        last_is_dead_end = False
        last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
        last_is_target = False

        visited = set()

        other_agent_encountered = False
        other_target_encountered = False
        num_steps = 1
        while exploring:
            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.
            if position in self.location_has_agent:
                other_agent_encountered = True

            # Only targets of *other* agents count: discount the observed agent's own target.
            targets_here = self.location_has_target.get(position, 0)
            if position == own_target:
                targets_here -= 1
            if targets_here > 0:
                other_target_encountered = True
            # #############################
            # #############################

            if (position[0], position[1], direction) in visited:
                last_is_terminal = True
                break
            visited.add((position[0], position[1], direction))

            # If the target node is encountered, pick that as node. Also, no further branching is possible.
            if position[0] == own_target[0] and position[1] == own_target[1]:
                last_is_target = True
                break

            cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
            possible_directions = [d for d in range(4) if cell_transitions[d]]
            num_transitions = len(possible_directions)

            exploring = False
            if num_transitions == 1:
                # Check if dead-end, or if we can go forward along direction
                if self._count_set_bits(self.env.rail.get_transitions((position[0], position[1]))) == 1:
                    # Dead-end!
                    last_is_dead_end = True
                else:
                    # Keep walking through the tree along the only available direction
                    exploring = True
                    direction = possible_directions[0]
                    position = self._new_position(position, direction)
                    num_steps += 1
            elif num_transitions > 1:
                # Switch detected
                last_is_switch = True
                break
            else:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                      position[1], direction)
                last_is_terminal = True
                break

        # `position' is either a terminal node or a switch

        # #############################
        # #############################
        # Modify here to append new / different features for each visited cell!
        if last_is_target:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           root_observation[3] + num_steps,
                           0]
        elif last_is_terminal:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           np.inf,
                           np.inf]
        else:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           root_observation[3] + num_steps,
                           self.distance_map[handle, position[0], position[1], direction]]
        # #############################
        # #############################

        new_root_observation = observation[:]

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
            if last_is_dead_end and self.env.rail.get_transition((position[0], position[1], direction),
                                                                 (branch_direction + 2) % 4):
                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                # it back
                reversed_direction = (branch_direction + 2) % 4
                new_cell = self._new_position(position, reversed_direction)
                observation.extend(self._explore_branch(handle,
                                                        new_cell,
                                                        reversed_direction,
                                                        new_root_observation,
                                                        depth + 1))
            elif last_is_switch and self.env.rail.get_transition((position[0], position[1], direction),
                                                                 branch_direction):
                # The transition must be checked for the same direction that is explored
                # (as in `get`), otherwise invalid exits are followed and valid ones skipped.
                new_cell = self._new_position(position, branch_direction)
                observation.extend(self._explore_branch(handle,
                                                        new_cell,
                                                        branch_direction,
                                                        new_root_observation,
                                                        depth + 1))
            else:
                # Pad a missing child with a whole -inf subtree rooted at depth + 1.
                observation.extend([-np.inf] * 5 * self._num_padding_nodes(self.max_depth - depth))
        return observation

    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
        """
        Pretty-print a tree observation returned by this object (or any of its subtrees).

        Each node is printed on its own line, indented by `current_depth` spaces and prefixed
        with `prompt`; children are labelled 'L:', 'F:', 'R:' and 'B:'.

        Parameters
        ----------
        tree : list
            Flattened subtree: node features followed by 4 equally-sized child subtrees.
        num_features_per_node : int
            Number of features stored per node.
        prompt : str
            Label printed before the node's features.
        current_depth : int
            Indentation (number of spaces) of this node.
        """
        if len(tree) < num_features_per_node:
            return

        child_labels = ['L:', 'F:', 'R:', 'B:']
        print(" " * current_depth + prompt, tree[0:num_features_per_node])

        child_size = (len(tree) - num_features_per_node) // 4
        for child_idx, label in enumerate(child_labels):
            start = num_features_per_node + child_idx * child_size
            self.util_print_obs_subtree(tree[start:start + child_size],
                                        num_features_per_node,
                                        prompt=label,
                                        current_depth=current_depth + 1)
class GlobalObsForRailEnv(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.

    The observation returned by `get` is a tuple of three elements:

        - transition map array with dimensions (env.height, env.width, 16),
          assuming 16 bits encoding of transitions (most significant bit first).

        - A (4, env.height, env.width) array of counts whose channels contain respectively:
            [0] the position of the given agent,
            [1] the position of its target,
            [2] the targets of the other agents,
            [3] the positions of the other agents.

        - A 4 elements array with a one-hot encoding of the agent's direction.
    """

    def __init__(self):
        super(GlobalObsForRailEnv, self).__init__()

    def reset(self):
        """
        Pre-compute the transition-map part of the observation.

        Must be called after `_set_env` and after every environment reset.
        """
        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
        for i in range(self.env.height):
            for j in range(self.env.width):
                # Zero-padded 16-character binary string of the cell's transitions -> 16 ints.
                self.rail_obs[i, j] = np.array(
                    list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)

    def get(self, handle):
        """
        Compute the global observation for agent `handle`.

        Parameters
        ----------
        handle : int
            Index of the agent to observe.

        Returns
        -------
        tuple
            (rail_obs, obs_agents_targets_pos, direction) as described in the class docstring.
        """
        obs_agents_targets_pos = np.zeros((4, self.env.height, self.env.width))

        # Positions are converted to tuples so that NumPy indexes a single (row, col) cell;
        # a list or array used as an index would be advanced indexing over whole rows.
        agent_pos = tuple(self.env.agents_position[handle])
        obs_agents_targets_pos[0][agent_pos] += 1
        for i, pos in enumerate(self.env.agents_position):
            if i != handle:
                obs_agents_targets_pos[3][tuple(pos)] += 1

        agent_target_pos = tuple(self.env.agents_target[handle])
        obs_agents_targets_pos[1][agent_target_pos] += 1
        for i, target in enumerate(self.env.agents_target):
            if i != handle:
                obs_agents_targets_pos[2][tuple(target)] += 1

        direction = np.zeros(4)
        direction[self.env.agents_direction[handle]] = 1
        return self.rail_obs, obs_agents_targets_pos, direction