Commit d714be4d authored by Christian Eichenberger's avatar Christian Eichenberger 🏸
Browse files

Merge branch 'feature/observation-multiple-agents' into 'master'

Feature/observation multiple agents

Closes #67

See merge request flatland/flatland!94
parents 24cb4f84 15a725f0
Pipeline #1349 passed with stages
in 9 minutes and 8 seconds
import runpy
import sys
from io import StringIO
from test.support import swap_attr
from time import sleep
import importlib_resources
......@@ -9,6 +8,8 @@ import pkg_resources
from benchmarker import Benchmarker
from importlib_resources import path
from benchmarks.benchmark_utils import swap_attr
for entry in [entry for entry in importlib_resources.contents('examples') if
not pkg_resources.resource_isdir('examples', entry)
and entry.endswith(".py")
......
......@@ -19,7 +19,7 @@ env = RailEnv(width=7,
# Print the observation vector for agent 0
obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()):
env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=7)
env.obs_builder.util_print_obs_subtree(tree=obs[i])
env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True, frames=True)
......
......@@ -11,6 +11,13 @@ class Grid4TransitionsEnum(IntEnum):
SOUTH = 2
WEST = 3
@staticmethod
def to_char(int: int):
return {0: 'N',
1: 'E',
2: 'S',
3: 'W'}[int]
class Grid4Transitions(Transitions):
"""
......
"""
Collection of environment-specific ObservationBuilder.
"""
import pprint
from collections import deque
import numpy as np
......@@ -19,6 +20,8 @@ class TreeObsForRailEnv(ObservationBuilder):
network to simplify the representation of the state of the environment for each agent.
"""
observation_dim = 9
def __init__(self, max_depth, predictor=None):
self.max_depth = max_depth
......@@ -28,12 +31,13 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth + 1):
size += pow4
pow4 *= 4
self.observation_dim = 9
self.observation_space = [size * self.observation_dim]
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.predictor = predictor
self.agents_previous_reset = None
self.tree_explored_actions = [1, 2, 3, 0]
self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
def reset(self):
agents = self.env.agents
......@@ -126,19 +130,6 @@ class TreeObsForRailEnv(ObservationBuilder):
desired_movement_from_new_cell = (neigh_direction + 2) % 4
"""
# Is the next cell a dead-end?
isNextCellDeadEnd = False
nbits = 0
tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
# Dead-end!
isNextCellDeadEnd = True
"""
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
......@@ -213,7 +204,7 @@ class TreeObsForRailEnv(ObservationBuilder):
[... from 'right] +
[... from 'back']
Finally, each node information is composed of 5 floating point values:
Finally, each node information is composed of 9 floating point values:
#1: if own target lies on the explored branch the current distance from the agent in number of cells is stored.
......@@ -240,7 +231,7 @@ class TreeObsForRailEnv(ObservationBuilder):
(possible future use: number of other agents in the same direction in this branch)
0 = no agent present same direction
#9: agent in the opposite drection
#9: agent in the opposite direction
n = number of agents present other direction than myself (so conflict)
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
0 = no agent present other direction than myself
......@@ -273,7 +264,6 @@ class TreeObsForRailEnv(ObservationBuilder):
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# If only one transition is possible, the tree is oriented with this transition as the forward branch.
# TODO: Test if this works as desired!
orientation = agent.direction
if num_transitions == 1:
......@@ -287,15 +277,20 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = observation + branch_observation
visited = visited.union(branch_visited)
else:
num_cells_to_fill_in = 0
pow4 = 1
for i in range(self.max_depth):
num_cells_to_fill_in += pow4
pow4 *= 4
observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
# add cells filled with infinity if no transition is possible
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
self.env.dev_obs_dict[handle] = visited
return observation
def _num_cells_to_fill_in(self, remaining_depth):
"""Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
num_observations = 0
pow4 = 1
for i in range(remaining_depth):
num_observations += pow4
pow4 *= 4
return num_observations * self.observation_dim
def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
"""
Utility function to compute tree-based observations.
......@@ -343,7 +338,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Accumulate the number of agents on the branch moving in the other direction
other_agent_opposite_direction += 1
# Register possible conflict
# Register possible future conflict
if self.predictor and num_steps < self.max_prediction_depth:
int_position = coordinate_to_position(self.env.width, [position])
if tot_dist < self.max_prediction_depth:
......@@ -505,41 +500,47 @@ class TreeObsForRailEnv(ObservationBuilder):
if len(branch_visited) != 0:
visited = visited.union(branch_visited)
else:
num_cells_to_fill_in = 0
pow4 = 1
for i in range(self.max_depth - depth):
num_cells_to_fill_in += pow4
pow4 *= 4
observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
# no exploring possible, add just cells with infinity
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)
return observation, visited
def util_print_obs_subtree(self, tree):
    """
    Pretty-print a tree observation returned by this object.

    The flat observation is first converted into a nested dictionary by
    `self.unfold_observation_tree` and the result is written to standard
    output with an indentation of 4.

    Parameters
    ----------
    tree
        Flat tree observation of a single agent.
    """
    pprint.pprint(self.unfold_observation_tree(tree), indent=4)
def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
    """
    Convert a flat tree observation into a nested dictionary.

    The flat observation starts with the features of the current node
    (`self.observation_dim` values) followed by four equally sized
    sub-trees. Each sub-tree is unfolded recursively in the same way.

    Parameters
    ----------
    tree
        Flat, sliceable sequence holding the observation of a (sub-)tree.
    current_depth : int
        Depth of `tree` inside the overall observation; it is only passed
        down the recursion.
    actions_for_display : bool
        If True, children are keyed by the entries of
        `self.tree_explorted_actions_char`; otherwise by the entries of
        `self.tree_explored_actions`. The choice applies to every level.

    Returns
    -------
    dict or None
        Dictionary whose '' entry holds the features of the node and whose
        remaining entries hold the unfolded children. None if `tree` is
        too short to contain a single node.
    """
    if len(tree) < self.observation_dim:
        return None

    if actions_for_display:
        labels = self.tree_explorted_actions_char
    else:
        labels = self.tree_explored_actions

    unfolded = {'': tree[0:self.observation_dim]}
    child_size = (len(tree) - self.observation_dim) // 4
    for child in range(4):
        start = self.observation_dim + child * child_size
        child_tree = tree[start:start + child_size]
        # Forward actions_for_display so nested levels use the same kind of key.
        observation_tree = self.unfold_observation_tree(child_tree,
                                                        current_depth=current_depth + 1,
                                                        actions_for_display=actions_for_display)
        if observation_tree is not None:
            unfolded[labels[child]] = observation_tree
    return unfolded
def _set_env(self, env):
self.env = env
......@@ -708,8 +709,6 @@ class LocalObsForRailEnv(ObservationBuilder):
bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist)
# self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array(
# list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
def get(self, handle):
agents = self.env.agents
......
......@@ -19,12 +19,22 @@ from flatland.envs.observations import TreeObsForRailEnv
class RailEnvActions(IntEnum):
    """Integer codes of the actions an agent can be given."""

    # Each member is defined exactly once: an Enum body raises TypeError
    # when a member name is assigned a second time.
    DO_NOTHING = 0  # implies change of direction in a dead-end!
    MOVE_LEFT = 1
    MOVE_FORWARD = 2
    MOVE_RIGHT = 3
    STOP_MOVING = 4

    @staticmethod
    def to_char(a: int) -> str:
        """
        Return the one-letter label of an action code.

        The mapping is 0 -> 'B', 1 -> 'L', 2 -> 'F', 3 -> 'R', 4 -> 'S'.
        Any other value raises a KeyError.
        """
        return {
            0: 'B',
            1: 'L',
            2: 'F',
            3: 'R',
            4: 'S',
        }[a]
class RailEnv(Environment):
"""
......
......@@ -48,11 +48,11 @@
obs_builder_object=TreeObsForRailEnv(max_depth=2))
# Print the observation vector for agent 0
obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()):
env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=7)
env.obs_builder.util_print_obs_subtree(tree=obs[i])
env_renderer = RenderTool(env, gl="PIL")
# env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.renderEnv(show=True, frames=True)
......
import numpy as np
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
def make_simple_rail():
    """
    Build the small hand-made rail network used as a test fixture.

    The network lives on a 7x10 grid. Row 3 carries a horizontal line with a
    dead end at each end. A vertical branch leaves it at column 3 and runs
    up to a dead end in row 0; a second vertical branch leaves it at
    column 6 and runs down to a dead end in row 6.

    Returns
    -------
    (rail, rail_map)
        rail : GridTransitionMap whose grid is set to `rail_map`
        rail_map : uint16 numpy array with one 16-bit transition value per cell
    """
    # 16-bit transition values of the basic cell types that are needed here.
    empty = int('0000000000000000', 2)                # Case 0 - empty cell
    vertical_straight = int('1000000000100000', 2)    # Case 1 - straight
    symmetrical_switch = int('0101001000000010', 2)   # Case 6 - symmetrical switch
    dead_end_from_south = int('0010000000000000', 2)  # Case 7 - dead end

    transitions = Grid4Transitions([])

    # All other orientations are derived by rotating the basic cells.
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
    double_switch_south_horizontal_straight = horizontal_straight + symmetrical_switch
    double_switch_north_horizontal_straight = transitions.rotate_transition(
        double_switch_south_horizontal_straight, 180)

    # The grid is assembled row by row.
    top_row = [empty] * 3 + [dead_end_from_south] + [empty] * 6
    upper_branch_row = [empty] * 3 + [vertical_straight] + [empty] * 6
    main_line_row = ([dead_end_from_east] + [horizontal_straight] * 2 +
                     [double_switch_north_horizontal_straight] +
                     [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
                     [horizontal_straight] * 2 + [dead_end_from_west])
    lower_branch_row = [empty] * 6 + [vertical_straight] + [empty] * 3
    bottom_row = [empty] * 6 + [dead_end_from_north] + [empty] * 3

    rail_map = np.array(
        [top_row,
         upper_branch_row, upper_branch_row,
         main_line_row,
         lower_branch_row, lower_branch_row,
         bottom_row], dtype=np.uint16)

    num_rows, num_columns = rail_map.shape[0], rail_map.shape[1]
    rail = GridTransitionMap(width=num_columns, height=num_rows, transitions=transitions)
    rail.grid = rail_map
    return rail, rail_map
......@@ -3,62 +3,17 @@
import numpy as np
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from tests.simple_rail import make_simple_rail
"""Tests for `flatland` package."""
def test_global_obs():
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
# |
# _ _ _ /_\ _ _ _ _ _ _
# \ /
# |
# |
# |
rail, rail_map = make_simple_rail()
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
transitions = Grid4Transitions([])
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
double_switch_south_horizontal_straight = horizontal_straight + cells[6]
double_switch_north_horizontal_straight = transitions.rotate_transition(
double_switch_south_horizontal_straight, 180)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 +
[double_switch_north_horizontal_straight] +
[horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pprint
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from tests.simple_rail import make_simple_rail
"""Test predictions for `flatland` package."""
def make_simple_rail():
    """
    Create a very simple rail network on a 7x10 grid.

    A horizontal line runs along row 3 between two dead ends. One vertical
    branch connects it at column 3 to a dead end in row 0, and another
    connects it at column 6 to a dead end in row 6.

    Returns
    -------
    (rail, rail_map)
        rail : GridTransitionMap using `rail_map` as its grid
        rail_map : uint16 numpy array of the per-cell transition values
    """
    # Transition values of the basic cell types, indexed by case number.
    cells = (
        0b0000000000000000,  # Case 0 - empty cell
        0b1000000000100000,  # Case 1 - straight
        0b1001001000100000,  # Case 2 - simple switch
        0b1000010000100001,  # Case 3 - diamond crossing
        0b1001011000100001,  # Case 4 - single slip switch
        0b1100110000110011,  # Case 5 - double slip switch
        0b0101001000000010,  # Case 6 - symmetrical switch
        0b0010000000000000,  # Case 7 - dead end
    )

    transitions = Grid4Transitions([])
    rotate = transitions.rotate_transition

    empty = cells[0]
    dead_end_from_south = cells[7]
    dead_end_from_west = rotate(dead_end_from_south, 90)
    dead_end_from_north = rotate(dead_end_from_south, 180)
    dead_end_from_east = rotate(dead_end_from_south, 270)
    vertical_straight = cells[1]
    horizontal_straight = rotate(vertical_straight, 90)
    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
    double_switch_north_horizontal_straight = rotate(double_switch_south_horizontal_straight, 180)

    # Rows are collected top to bottom; repeated rows are added twice.
    rows = []
    rows.append([empty] * 3 + [dead_end_from_south] + [empty] * 6)
    rows.extend([[empty] * 3 + [vertical_straight] + [empty] * 6] * 2)
    rows.append([dead_end_from_east] + [horizontal_straight] * 2 +
                [double_switch_north_horizontal_straight] +
                [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
                [horizontal_straight] * 2 + [dead_end_from_west])
    rows.extend([[empty] * 6 + [vertical_straight] + [empty] * 3] * 2)
    rows.append([empty] * 6 + [dead_end_from_north] + [empty] * 3)
    rail_map = np.array(rows, dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0],
                             transitions=transitions)
    rail.grid = rail_map
    return rail, rail_map
def test_dummy_predictor(rendering=False):
rail, rail_map = make_simple_rail()
......@@ -67,12 +24,16 @@ def test_dummy_predictor(rendering=False):
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
# reset to initialize agents_static
env.reset()
# set initial position and direction for testing...
env.agents[0].position = (5, 6)
env.agents[0].direction = 0
env.agents[0].target = (3, 0)
env.agents_static[0].position = (5, 6)
env.agents_static[0].direction = 0
env.agents_static[0].target = (3, 0)
# reset to set agents from agents_static
env.reset(False, False)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
......@@ -153,40 +114,38 @@ def test_shortest_path_predictor(rendering=False):
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
# reset to initialize agents_static
env.reset()
agent = env.agents[0]
# set the initial position
agent = env.agents_static[0]
agent.position = (5, 6) # south dead-end
agent.direction = 0 # north
agent.target = (3, 9) # east dead-end
agent.moving = True
# reset to set agents from agents_static
env.reset(False, False)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=False)
input("Continue?")
agent = env.agents[0]
assert agent.position == (5, 6)
assert agent.direction == 0
assert agent.target == (3, 9)
assert agent.moving
env.obs_builder._compute_distance_map()
# compute the observations and predictions
distance_map = env.obs_builder.distance_map
assert distance_map[agent.handle, agent.position[0], agent.position[
assert distance_map[0, agent.position[0], agent.position[
1], agent.direction] == 5.0, "found {} instead of {}".format(
distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
# test assertions
env.obs_builder.get_many()
# extract the data
predictions = env.obs_builder.predictions
positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
# test if data meets expectations
expected_positions = [
[5, 6],
[4, 6],
......@@ -264,3 +223,59 @@ def test_shortest_path_predictor(rendering=False):
"directions {}, expected {}".format(directions, expected_directions)
assert np.array_equal(time_offsets, expected_time_offsets), \
"time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
def test_shortest_path_predictor_conflicts(rendering=False):
    """
    Check which tree nodes report a conflict for two agents on the simple rail.

    Two agents are placed on the network built by `make_simple_rail`, the
    environment is reset so that they are taken over from `agents_static`,
    and the unfolded tree observation of each agent is compared with the
    expected conflict locations.

    Parameters
    ----------
    rendering : bool
        If True, the environment is rendered and the test waits for input.
    """
    rail, rail_map = make_simple_rail()
    num_rows, num_columns = rail_map.shape[0], rail_map.shape[1]
    env = RailEnv(width=num_columns,
                  height=num_rows,
                  rail_generator=rail_from_GridTransitionMap_generator(rail),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                       predictor=ShortestPathPredictorForRailEnv()),
                  )

    # the first reset initializes agents_static
    env.reset()

    # (position, direction, target) per agent handle
    placements = [
        ((5, 6), 0, (3, 9)),  # one cell above the south dead-end, facing north, heading to the east dead-end
        ((3, 8), 3, (6, 6)),  # one cell before the east dead-end, facing west, heading to the south dead-end
    ]
    for handle, (position, direction, target) in enumerate(placements):
        static_agent = env.agents_static[handle]
        static_agent.position = position
        static_agent.direction = direction
        static_agent.target = target
        static_agent.moving = True

    # the second reset takes the agents over from agents_static
    observations = env.reset(False, False)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.renderEnv(show=True, show_observations=False)
        input("Continue?")

    # unfold the flat observations into nested trees
    obs_builder: TreeObsForRailEnv = env.obs_builder
    tree_0 = obs_builder.unfold_observation_tree(observations[0])
    tree_1 = obs_builder.unfold_observation_tree(observations[1])
    pprint.pprint(tree_0, indent=4)

    # compare with the expected conflict locations
    expected_conflicts_0 = [('F', 'R')]
    expected_conflicts_1 = [('F', 'L')]
    _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0, "agent[0]: ")
    _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")
def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''):
assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt)
for a_1 in obs_builder.tree_explorted_actions_char:
conflict = tree_0[a_1][''][8]
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
for a_2 in obs_builder.tree_explorted_actions_char:
conflict = tree_0[a_1][a_2][''][8]
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment