Skip to content
Snippets Groups Projects
Commit c07e281f authored by gmollard's avatar gmollard
Browse files

basic test work

parents 493b8a20 c794eb4e
No related branches found
No related tags found
Loading
import numpy as np import numpy as np
from collections import deque
# TODO: add docstrings, pylint, etc... # TODO: add docstrings, pylint, etc...
...@@ -15,15 +17,131 @@ class ObservationBuilder: ...@@ -15,15 +17,131 @@ class ObservationBuilder:
class TreeObsForRailEnv(ObservationBuilder): class TreeObsForRailEnv(ObservationBuilder):
def __init__(self, env):
    """Create the builder and keep a reference to the environment.

    Args:
        env: environment object; other methods of this builder read
            attributes such as `number_of_agents`, `height`, `width`,
            `agents_target` and `rail` from it.
    """
    self.env = env
def reset(self):
    """Recompute the per-agent distance maps.

    Allocates `self.distance_map` with shape
    (number_of_agents, height, width), initialised to +inf, then lets
    `_distance_map_walker` fill in the map of every agent starting from
    that agent's target. The value returned by the walker for agent `i`
    is stored in `self.max_dist[i]`.
    """
    n_agents = self.env.number_of_agents
    self.distance_map = np.full(
        (n_agents, self.env.height, self.env.width), np.inf)
    self.max_dist = np.zeros(n_agents)

    for i in range(n_agents):
        self.max_dist[i] = self._distance_map_walker(
            self.env.agents_target[i], i)
def _distance_map_walker(self, position, target_nr):
    """Breadth-first walk from a target cell that fills `distance_map`.

    The distance entries themselves are written by
    `_get_and_update_neighbors`; this method only sets the target cell to
    zero and drives the queue.

    Args:
        position: (row, col) of the target cell the walk starts from.
        target_nr: index into the first axis of `self.distance_map`.

    Returns:
        The largest `node distance + 1` seen over all expanded nodes that
        produced at least one neighbor (0 if there were none).
        NOTE(review): this is intended as the maximum distance to the
        target, but it is not read back from `distance_map` - confirm that
        it matches for layouts where a first-level neighbor has no further
        neighbors.
    """
    # The target cell itself is at distance zero.
    self.distance_map[target_nr, position[0], position[1]] = 0

    # Queue entries are tuples (row, col, direction, distance); `direction`
    # is the direction of movement, meaning that at least one possible
    # orientation of an agent in cell (row, col) allows a movement in
    # that direction.
    nodes_queue = deque(self._get_and_update_neighbors(
        position, target_nr, 0, enforce_target_direction=-1))

    # Mark the target as visited for every direction, so the search stops
    # whenever it comes back to the target position.
    visited = {(position[0], position[1], orientation)
               for orientation in range(4)}

    max_distance = 0
    while nodes_queue:
        node = nodes_queue.popleft()
        node_id = (node[0], node[1], node[2])
        if node_id in visited:
            continue
        visited.add(node_id)

        # From the possible neighbors that have at least a path to the
        # current node, only keep those whose new orientation in the
        # current cell would allow a transition to direction node[2].
        valid_neighbors = self._get_and_update_neighbors(
            (node[0], node[1]), target_nr, node[3], node[2])
        nodes_queue.extend(valid_neighbors)

        if valid_neighbors:
            max_distance = max(max_distance, node[3] + 1)

    return max_distance
def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
    """Collect the cells from which `position` can be entered and update their distance.

    For each of the four directions the candidate cell is the one obtained
    by stepping from `position` in the opposite direction, i.e. the cell an
    agent moving in `direction` would come from.

    Args:
        position: (row, col) of the cell being expanded.
        target_nr: index into the first axis of `self.distance_map`.
        current_distance: distance already assigned to `position`.
        enforce_target_direction: when >= 0, a candidate is only kept if
            `env.rail.get_transition` allows leaving it (entered with
            orientation `direction`) in this direction, or if it is a
            dead-end that turns the agent around. -1 disables the check.

    Returns:
        List of tuples (row, col, direction, new_distance) for every
        accepted candidate. As a side effect, `self.distance_map` is
        lowered to `new_distance` for those cells.
    """
    neighbors = []
    rail = self.env.rail

    for direction in range(4):
        new_cell = self._new_position(position, (direction + 2) % 4)
        row, col = new_cell[0], new_cell[1]

        if not (0 <= row < self.env.height and 0 <= col < self.env.width):
            continue

        # The two cells are connected if some orientation in the candidate
        # cell permits a move in `direction`.
        connected = any(
            rail.get_transitions((row, col, orientation))[direction]
            for orientation in range(4))
        if not connected:
            continue

        # Check if a transition in the enforced direction is possible for
        # an agent that lands in the candidate cell with orientation
        # `direction`; this only applies to cells that are not dead-ends.
        direction_match = True
        if enforce_target_direction >= 0:
            direction_match = rail.get_transition(
                (row, col, direction), enforce_target_direction)

        # If the transition is found to be invalid, check whether the cell
        # is a dead-end, in which case the direction of movement is rotated
        # by 180 degrees (moving forward turns the agent around and makes
        # it step back into the previous cell).
        if not direction_match:
            # A cell whose transition bitmask has exactly one bit set is
            # treated as a dead-end.
            set_bits = bin(int(rail.get_transitions((row, col)))).count("1")
            if set_bits == 1:
                direction_match = rail.get_transition(
                    (row, col, (direction + 2) % 4), direction)

        if direction_match:
            new_distance = min(self.distance_map[target_nr, row, col],
                               current_distance + 1)
            neighbors.append((row, col, direction, new_distance))
            self.distance_map[target_nr, row, col] = new_distance

    return neighbors
def _new_position(self, position, movement):
    """Return the cell reached from `position` by one step in `movement`.

    Args:
        position: (row, col) pair.
        movement: 0 = NORTH (row - 1), 1 = EAST (col + 1),
            2 = SOUTH (row + 1), 3 = WEST (col - 1).

    Returns:
        The new (row, col) tuple, or None if `movement` is not one of the
        four codes above. No bounds checking is done here.
    """
    # Index in this tuple is the movement code.
    offsets = ((-1, 0), (0, 1), (1, 0), (0, -1))
    for code, (d_row, d_col) in enumerate(offsets):
        if movement == code:
            return (position[0] + d_row, position[1] + d_col)
    return None
def get(self, handle):
    """Return the observation for agent `handle`.

    Not implemented yet: an empty list is returned for every handle.
    """
    # TODO: compute the observation for agent `handle'
    return []
...@@ -38,12 +156,235 @@ class GlobalObsForRailEnv(ObservationBuilder): ...@@ -38,12 +156,235 @@ class GlobalObsForRailEnv(ObservationBuilder):
- Four 2D arrays containing respectively the position of the given agent, - Four 2D arrays containing respectively the position of the given agent,
the position of its target, the positions of the other agents and of the position of its target, the positions of the other agents and of
their target. their target.
- A 4-element array with a one-hot encoding of the agent's direction.
""" """
def __init__(self, env):
    """Create the builder by delegating to the parent class constructor.

    Args:
        env: environment object, passed through unchanged to the parent
            `__init__`.
    """
    super(GlobalObsForRailEnv, self).__init__(env)
def reset(self):
    """Rebuild the static rail observation.

    `self.rail_obs` gets shape (height, width, 16). Entry [i, j] holds the
    16 binary digits of `env.rail.get_transitions((i, j))`, most
    significant bit first, one digit per channel.
    """
    self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
    n_rows, n_cols = self.rail_obs.shape[0], self.rail_obs.shape[1]

    for i in range(n_rows):
        for j in range(n_cols):
            bits = f'{self.env.rail.get_transitions((i, j)):016b}'
            self.rail_obs[i, j] = np.array(list(bits)).astype(int)

    # TODO: a per-cell count of targets was sketched here but is disabled:
    # self.targets = np.zeros(self.env.height, self.env.width)
    # for target_pos in self.env.agents_target:
    #     self.targets[target_pos] += 1
def get(self, handle):
    """Build the global observation for agent `handle`.

    Args:
        handle: index of the agent in `env.agents_position`,
            `env.agents_target` and `env.agents_direction`.

    Returns:
        A tuple (rail_obs, obs_agents_targets_pos, direction):
        - rail_obs: `self.rail_obs` as built by `reset` (same object).
        - obs_agents_targets_pos: float array of shape (4, height, width)
          with channel 0 = position of this agent, channel 1 = target of
          this agent, channel 2 = targets of the other agents,
          channel 3 = positions of the other agents. Cells are counts, so
          overlapping entries add up.
          NOTE(review): the class docstring lists other agents' positions
          before their targets; the channel order here is the reverse and
          is kept as-is so existing consumers are unaffected.
        - direction: length-4 one-hot array of `env.agents_direction[handle]`.
    """
    env = self.env
    obs_agents_targets_pos = np.zeros((4, env.height, env.width))

    # Positions are coerced to tuples so that a list-valued (row, col) is
    # used as a single cell index rather than NumPy fancy indexing.
    obs_agents_targets_pos[0][tuple(env.agents_position[handle])] += 1
    for i, other_position in enumerate(env.agents_position):
        if i != handle:
            obs_agents_targets_pos[3][tuple(other_position)] += 1

    obs_agents_targets_pos[1][tuple(env.agents_target[handle])] += 1
    for i, other_target in enumerate(env.agents_target):
        if i != handle:
            obs_agents_targets_pos[2][tuple(other_target)] += 1

    direction = np.zeros(4)
    direction[env.agents_direction[handle]] = 1

    return self.rail_obs, obs_agents_targets_pos, direction
"""
def get_observation(self, agent):
# Get the current observation for an agent
current_position = self.internal_position[agent]
#target_heading = self._compass(agent).tolist()
coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
agent_distance = self.distance_map[agent][coordinate][0]
# Start tree search
if current_position == self.target[agent]:
agent_tree = Node(current_position, [-np.inf, -np.inf, -np.inf, -np.inf, -1])
else:
agent_tree = Node(current_position, [0, 0, 0, 0, agent_distance])
initial_tree_state = Tree_State(agent, current_position, -1, 0, 0)
self._tree_search(initial_tree_state, agent_tree, agent)
observation = []
distance_data = []
self._flatten_tree(agent_tree, observation, distance_data, self.max_depth+1)
# This is probably very slow!!!!
#max_obs = np.max([i for i in observation if i < np.inf])
#if max_obs != 0:
# observation = np.array(observation)/ max_obs
#print([i for i in distance_data if i >= 0])
observation = np.concatenate((observation, distance_data))
#observation = np.concatenate((observation, np.identity(5)[int(self.last_action[agent])]))
#return np.clip(observation / self.max_dist[agent], -1, 1)
return np.clip(observation / 15., -1, 1)
def _tree_search(self, in_tree_state, parent_node, agent):
if in_tree_state.depth >= self.max_depth:
return
target_distance = np.inf
other_target = np.inf
other_agent = np.inf
coordinate = tuple(np.transpose(self._position_to_coordinate([in_tree_state.position])))
curr_target_dist = self.distance_map[agent][coordinate][0]
forbidden_action = (in_tree_state.direction + 2) % 4
# Update the position
failed_move = 0
leaf_distance = in_tree_state.distance
for child_idx in range(4):
if child_idx != forbidden_action or in_tree_state.direction == -1:
tree_state = copy.deepcopy(in_tree_state)
tree_state.direction = child_idx
current_position, invalid_move = self._detect_path(
tree_state.position, tree_state.direction)
if tree_state.initial_direction == None:
tree_state.initial_direction = child_idx
if not invalid_move:
coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
curr_target_dist = self.distance_map[agent][coordinate][0]
#if tree_state.initial_direction == None:
# tree_state.initial_direction = child_idx
tree_state.position = current_position
tree_state.distance += 1
# Collect information at the current position
detection_distance = tree_state.distance
if current_position == self.target[tree_state.agent]:
target_distance = detection_distance
elif current_position in self.target:
other_target = detection_distance
if current_position in self.internal_position:
other_agent = detection_distance
tree_state.data[0] = self._min_greater_zero(target_distance, tree_state.data[0])
tree_state.data[1] = self._min_greater_zero(other_target, tree_state.data[1])
tree_state.data[2] = self._min_greater_zero(other_agent, tree_state.data[2])
tree_state.data[3] = tree_state.distance
tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
if self._switch_detection(tree_state.position):
tree_state.depth += 1
new_tree_state = copy.deepcopy(tree_state)
new_node = parent_node.insert(tree_state.position,
tree_state.data, tree_state.initial_direction)
new_tree_state.initial_direction = None
new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
self._tree_search(new_tree_state, new_node, agent)
else:
self._tree_search(tree_state, parent_node, agent)
else:
failed_move += 1
if failed_move == 3 and in_tree_state.direction != -1:
tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction)
return
return
def _flatten_tree(self, node, observation_vector, distance_sensor, depth):
if depth <= 0:
return
if node != None:
observation_vector.extend(node.data[:-1])
distance_sensor.extend([node.data[-1]])
else:
observation_vector.extend([-np.inf, -np.inf, -np.inf, -np.inf])
distance_sensor.extend([-np.inf])
for child_idx in range(4):
if node != None:
child = node.children[child_idx]
else:
child = None
self._flatten_tree(child, observation_vector, distance_sensor, depth -1)
def _switch_detection(self, position):
# Hack to detect switches
# This can later directly be derived from the transition matrix
paths = 0
for i in range(4):
_, invalid_move = self._detect_path(position, i)
if not invalid_move:
paths +=1
if paths >= 3:
return True
return False
def _min_greater_zero(self, x, y):
if x <= 0 and y <= 0:
return 0
if x < 0:
return y
if y < 0:
return x
return min(x, y)
"""
class Tree_State:
    """
    Keep track of the current state while building the tree
    """

    def __init__(self, agent, position, direction, depth, distance):
        """Store the search state.

        Args:
            agent: agent the tree is built for.
            position: current position of the search.
            direction: direction of the last move.
            depth: current depth in the tree.
            distance: number of steps taken so far.
        """
        self.agent = agent
        self.position = position
        self.direction = direction
        self.depth = depth
        # Not known until the first move is chosen.
        self.initial_direction = None
        self.distance = distance
        # Five accumulated values, all starting at +inf; a fresh list per
        # instance so states never share it.
        self.data = [np.inf] * 5

    def __repr__(self):
        return (f"{type(self).__name__}(agent={self.agent!r}, "
                f"position={self.position!r}, direction={self.direction!r}, "
                f"depth={self.depth!r}, distance={self.distance!r})")
class Node():
    """
    Define a tree node to get populated during search
    """

    def __init__(self, position, data):
        """Create a node with four empty child slots.

        Args:
            position: position associated with this node.
            data: node data object.
        """
        self.n_children = 4
        self.children = [None] * self.n_children
        self.data = data
        self.position = position

    def insert(self, position, data, child_idx):
        """
        Insert new node with data

        @param position position associated with the new node
        @param data node data object to insert
        @param child_idx child slot (0..n_children-1) to store the node in;
               an existing child in that slot is replaced
        @return the newly created node
        """
        new_node = Node(position, data)
        self.children[child_idx] = new_node
        return new_node

    def print_tree(self, i=0, depth=0):
        """
        Print this node and its subtree, one node per line

        Each line shows the node position and data, indented by two spaces
        per tree level. Empty child slots are skipped.

        @param i index of the first child slot of this node to descend
               into (default 0: all children)
        @param depth indentation level used for this node
        """
        print("  " * depth + f"{self.position}: {self.data}")
        for child in self.children[i:]:
            if child is not None:
                child.print_tree(depth=depth + 1)
...@@ -2,8 +2,11 @@ ...@@ -2,8 +2,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from flatland.core.env_observation_builder import GlobalObsForRailEnv from flatland.core.env_observation_builder import GlobalObsForRailEnv
from flatland.core.transitions import Grid4Transitions # from flatland.core.transitions import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.core.env import RailEnv
import numpy as np import numpy as np
from flatland.utils.rendertools import *
"""Tests for `flatland` package.""" """Tests for `flatland` package."""
...@@ -43,18 +46,44 @@ def test_global_obs(): ...@@ -43,18 +46,44 @@ def test_global_obs():
double_switch_north_horizontal_straight = transitions.rotate_transition( double_switch_north_horizontal_straight = transitions.rotate_transition(
double_switch_south_horizontal_straight, 180) double_switch_south_horizontal_straight, 180)
rail_map = np.array( rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6]*2 +
[[horizontal_straight] * 3 + [double_switch_north_horizontal_straight] +
[horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
[horizontal_straight] * 3] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[empty] * 3 + [dead_end_from_south] + [empty] * 6], dtype=np.uint16) [[dead_end_from_east] + [horizontal_straight] * 2 +
[double_switch_north_horizontal_straight] +
[horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
env = RailEnv(rail, number_of_agents=1)
env.reset()
# env_renderer = RenderTool(env)
# env_renderer.renderEnv(show=True)
global_obs = GlobalObsForRailEnv(env)
global_obs.reset()
assert(global_obs.rail_obs.shape == rail_map.shape + (16,))
rail_map_recons = np.zeros_like(rail_map)
for i in range(global_obs.rail_obs.shape[0]):
for j in range(global_obs.rail_obs.shape[1]):
rail_map_recons[i,j] = int(
''.join(global_obs.rail_obs[i, j].astype(int).astype(str)), 2)
assert(rail_map_recons.all() == rail_map.all())
obs = global_obs.get(0)
# If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell
assert(np.sum(rail_map * obs[1][0]) > 0)
print(rail_map.shape)
test_global_obs() test_global_obs()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment