Skip to content
Snippets Groups Projects
Commit d8071638 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch 'master' into 188_refining_generator

parents 391ca4a3 b82c5362
No related branches found
No related tags found
No related merge requests found
Showing
with 191 additions and 153 deletions
......@@ -160,7 +160,7 @@ Once we are set with the environment we can load our preferred agent from either
.. code-block:: python
agent = RandomAgent(env.action_space, env.observation_space)
agent = RandomAgent(state_size, action_size)
We start every trial by resetting the environment
......
......@@ -18,7 +18,7 @@ base class and must implement two methods, :code:`reset(self)` and :code:`get(se
.. _`flatland.core.env_observation_builder.ObservationBuilder` : https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/core/env_observation_builder.py#L13
Below is a simple example that returns observation vectors of size :code:`observation_space = 5` featuring only the ID (handle) of the agent whose
Below is a simple example that returns observation vectors of size 5 featuring only the ID (handle) of the agent whose
observation vector is being computed:
.. code-block:: python
......@@ -28,14 +28,12 @@ observation vector is being computed:
Simplest observation builder. The object returns observation vectors with 5 identical components,
all equal to the ID of the respective agent.
"""
def __init__(self):
self.observation_space = [5]
def reset(self):
return
def get(self, handle):
observation = handle * np.ones((self.observation_space[0],))
observation = handle * np.ones(5)
return observation
We can pass an instance of our custom observation builder :code:`SimpleObs` to the :code:`RailEnv` creator as follows:
......@@ -85,7 +83,6 @@ Note that this simple strategy fails when multiple agents are present, as each a
super().__init__(max_depth=0)
# We set max_depth=0 because we only need to look at the current
# position of the agent to decide what direction is shortest.
self.observation_space = [3]
def reset(self):
# Recompute the distance map, if the environment has changed.
......@@ -189,7 +186,6 @@ In contrast to the previous examples we also implement the :code:`def get_many(s
def __init__(self, predictor):
super().__init__(max_depth=0)
self.observation_space = [10]
self.predictor = predictor
def reset(self):
......
......@@ -18,7 +18,6 @@ class SimpleObs(ObservationBuilder):
def __init__(self):
super().__init__()
self.observation_space = [5]
def reset(self):
return
......
......@@ -28,7 +28,6 @@ class SingleAgentNavigationObs(ObservationBuilder):
def __init__(self):
super().__init__()
self.observation_space = [3]
def reset(self):
pass
......
......@@ -28,7 +28,6 @@ class ObservePredictions(ObservationBuilder):
def __init__(self, predictor):
super().__init__()
self.observation_space = [10]
self.predictor = predictor
def reset(self):
......
......@@ -25,7 +25,6 @@ class SingleAgentNavigationObs(ObservationBuilder):
def __init__(self):
super().__init__()
self.observation_space = [3]
def reset(self):
pass
......
......@@ -11,7 +11,6 @@ class Environment:
Derived environments should implement the following attributes:
action_space: tuple with the dimensions of the actions to be passed to the step method
observation_space: tuple with the dimensions of the observations returned by reset and step
Agents are identified by agent ids (handles).
Examples:
......@@ -46,7 +45,6 @@ class Environment:
def __init__(self):
self.action_space = ()
self.observation_space = ()
pass
def reset(self):
......
......@@ -18,13 +18,9 @@ from flatland.core.env import Environment
class ObservationBuilder:
"""
ObservationBuilder base class.
Derived objects must implement an `observation_space` attribute as a tuple with the dimensions of the returned
observations.
"""
def __init__(self):
self.observation_space = ()
self.env = None
def set_env(self, env: Environment):
......
"""
Collection of environment-specific ObservationBuilder.
"""
import pprint
from typing import Optional, List, Dict, T, Tuple
import collections
from typing import Optional, List, Dict, Tuple
import numpy as np
......@@ -15,6 +15,22 @@ from flatland.utils.ordered_set import OrderedSet
class TreeObsForRailEnv(ObservationBuilder):
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'dist_other_target_encountered '
'dist_other_agent_encountered '
'dist_potential_conflict '
'dist_unusable_switch '
'dist_to_next_branch '
'dist_min_to_target '
'num_agents_same_direction '
'num_agents_opposite_direction '
'num_agents_malfunctioning '
'speed_min_fractional '
'childs')
tree_explorted_actions_char = ['L', 'F', 'R', 'B']
"""
TreeObsForRailEnv object.
......@@ -29,24 +45,15 @@ class TreeObsForRailEnv(ObservationBuilder):
super().__init__()
self.max_depth = max_depth
self.observation_dim = 11
# Compute the size of the returned observation vector
size = 0
pow4 = 1
for i in range(self.max_depth + 1):
size += pow4
pow4 *= 4
self.observation_space = [size * self.observation_dim]
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.predictor = predictor
self.location_has_target = None
self.tree_explored_actions = [1, 2, 3, 0]
self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
def reset(self):
    """Rebuild the target lookup from the agents currently held by ``self.env``.

    Every agent's target, converted to a tuple, becomes a key mapped to 1 in
    ``self.location_has_target``.
    """
    target_cells = (tuple(agent.target) for agent in self.env.agents)
    self.location_has_target = dict.fromkeys(target_cells, 1)
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, List[int]]:
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
"""
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
......@@ -75,7 +82,7 @@ class TreeObsForRailEnv(ObservationBuilder):
observations[h] = self.get(h)
return observations
def get(self, handle: int = 0) -> List[int]:
def get(self, handle: int = 0) -> Node:
"""
Computes the current observation for agent `handle` in env
......@@ -165,11 +172,18 @@ class TreeObsForRailEnv(ObservationBuilder):
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
# Here information about the agent itself is stored
distance_map = self.env.distance_map.get()
observation = [0, 0, 0, 0, 0, 0, distance_map[(handle, *agent.position, agent.direction)], 0, 0,
agent.malfunction_data['malfunction'], agent.speed_data['speed']]
root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
dist_other_agent_encountered=0, dist_potential_conflict=0,
dist_unusable_switch=0, dist_to_next_branch=0,
dist_min_to_target=distance_map[(handle, *agent.position,
agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_data['malfunction'],
speed_min_fractional=agent.speed_data['speed'],
childs={})
visited = OrderedSet()
......@@ -181,28 +195,22 @@ class TreeObsForRailEnv(ObservationBuilder):
if num_transitions == 1:
orientation = np.argmax(possible_transitions)
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
if possible_transitions[branch_direction]:
new_cell = get_new_position(agent.position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, 1, 1)
observation = observation + branch_observation
root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation
visited |= branch_visited
else:
# add cells filled with infinity if no transition is possible
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf
self.env.dev_obs_dict[handle] = visited
return observation
def _num_cells_to_fill_in(self, remaining_depth):
    """Return the length of the padding for an unexplored subtree.

    The value is ``observation_dim * sum_{i=0}^{remaining_depth-1} 4^i``:
    one block of ``observation_dim`` entries per node, with four children
    per node on each of the ``remaining_depth`` levels.
    """
    node_count = sum(4 ** level for level in range(remaining_depth))
    return node_count * self.observation_dim
return root_node_observation
def _explore_branch(self, handle, position, direction, tot_dist, depth):
"""
......@@ -378,53 +386,35 @@ class TreeObsForRailEnv(ObservationBuilder):
# Modify here to append new / different features for each visited cell!
if last_is_target:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
tot_dist,
0,
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
dist_to_next_branch = tot_dist
dist_min_to_target = 0
elif last_is_terminal:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
np.inf,
self.env.distance_map.get()[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
dist_to_next_branch = np.inf
dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
else:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
tot_dist,
self.env.distance_map.get()[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
dist_to_next_branch = tot_dist
dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
node = TreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered,
dist_other_target_encountered=other_target_encountered,
dist_other_agent_encountered=other_agent_encountered,
dist_potential_conflict=potential_conflict,
dist_unusable_switch=unusable_switch,
dist_to_next_branch=dist_to_next_branch,
dist_min_to_target=dist_min_to_target,
num_agents_same_direction=other_agent_same_direction,
num_agents_opposite_direction=other_agent_opposite_direction,
num_agents_malfunctioning=malfunctioning_agent,
speed_min_fractional=min_fractional_speed,
childs={})
# #############################
# #############################
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# Get the possible transitions
possible_transitions = self.env.rail.get_transitions(*position, direction)
for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
if last_is_dead_end and self.env.rail.get_transition((*position, direction),
(branch_direction + 2) % 4):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
......@@ -435,7 +425,7 @@ class TreeObsForRailEnv(ObservationBuilder):
(branch_direction + 2) % 4,
tot_dist + 1,
depth + 1)
observation = observation + branch_observation
node.childs[self.tree_explorted_actions_char[i]] = branch_observation
if len(branch_visited) != 0:
visited |= branch_visited
elif last_is_switch and possible_transitions[branch_direction]:
......@@ -445,51 +435,45 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_direction,
tot_dist + 1,
depth + 1)
observation = observation + branch_observation
node.childs[self.tree_explorted_actions_char[i]] = branch_observation
if len(branch_visited) != 0:
visited |= branch_visited
else:
# no exploring possible, add just cells with infinity
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)
node.childs[self.tree_explorted_actions_char[i]] = -np.inf
return observation, visited
if depth == self.max_depth:
node.childs.clear()
return node, visited
def util_print_obs_subtree(self, tree):
def util_print_obs_subtree(self, tree: Node):
"""
Utility function to pretty-print tree observations returned by this object.
Utility function to print tree observations returned by this object.
"""
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(self.unfold_observation_tree(tree))
self.print_node_features(tree, "root", "")
for direction in self.tree_explorted_actions_char:
self.print_subtree(tree.childs[direction], direction, "\t")
@staticmethod
def print_node_features(node: Node, label, indent):
    """Print the eleven scalar features of ``node`` on one line.

    The line starts with ``indent``, then "Direction ", ``label`` and ": ",
    followed by the feature values separated by ", ". All pieces are passed
    to ``print`` as separate arguments, so each is joined by a single space.
    """
    feature_values = (
        node.dist_own_target_encountered,
        node.dist_other_target_encountered,
        node.dist_other_agent_encountered,
        node.dist_potential_conflict,
        node.dist_unusable_switch,
        node.dist_to_next_branch,
        node.dist_min_to_target,
        node.num_agents_same_direction,
        node.num_agents_opposite_direction,
        node.num_agents_malfunctioning,
        node.speed_min_fractional,
    )
    pieces = [indent, "Direction ", label, ": "]
    for position, value in enumerate(feature_values):
        # A ", " separator goes between consecutive values, not before the first.
        if position > 0:
            pieces.append(", ")
        pieces.append(value)
    print(*pieces)
def print_subtree(self, node, label, indent):
if node == -np.inf or not node:
print(indent, "Direction ", label, ": -np.inf")
return
def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
"""
Utility function to pretty-print tree observations returned by this object.
"""
if len(tree) < self.observation_dim:
self.print_node_features(node, label, indent)
if not node.childs:
return
depth = 0
tmp = len(tree) / self.observation_dim - 1
pow4 = 4
while tmp > 0:
tmp -= pow4
depth += 1
pow4 *= 4
unfolded = {}
unfolded[''] = tree[0:self.observation_dim]
child_size = (len(tree) - self.observation_dim) // 4
for child in range(4):
child_tree = tree[(self.observation_dim + child * child_size):
(self.observation_dim + (child + 1) * child_size)]
observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1)
if observation_tree is not None:
if actions_for_display:
label = self.tree_explorted_actions_char[child]
else:
label = self.tree_explored_actions[child]
unfolded[label] = observation_tree
return unfolded
for direction in self.tree_explorted_actions_char:
self.print_subtree(node.childs[direction], direction, indent + "\t")
def set_env(self, env: Environment):
super().set_env(env)
......@@ -508,23 +492,21 @@ class GlobalObsForRailEnv(ObservationBuilder):
- transition map array with dimensions (env.height, env.width, 16),\
assuming 16 bits encoding of transitions.
- A 3D array (map_height, map_width, 4) with
- first channel containing the agents position and direction
- second channel containing the other agents positions and direction
- third channel containing agent/other agent malfunctions
- fourth channel containing agent/other agent fractional speeds
- Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
target and the positions of the other agents targets.
- A 3D array (map_height, map_width, 4) with
- first channel containing the agents position and direction
- second channel containing the other agents positions and directions
- third channel containing agent malfunctions
- fourth channel containing agent fractional speeds
"""
def __init__(self):
self.observation_space = ()
super(GlobalObsForRailEnv, self).__init__()
def set_env(self, env: Environment):
super().set_env(env)
self.observation_space = [4, self.env.height, self.env.width]
def reset(self):
self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
......@@ -535,22 +517,21 @@ class GlobalObsForRailEnv(ObservationBuilder):
self.rail_obs[i, j] = np.array(bitlist)
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
obs_targets = np.zeros((self.env.height, self.env.width, 2))
obs_agents_state = np.zeros((self.env.height, self.env.width, 4))
agents = self.env.agents
agent = agents[handle]
obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - 1
agent_pos = agents[handle].position
obs_agents_state[agent_pos][0] = agents[handle].direction
agent = self.env.agents[handle]
obs_agents_state[agent.position][0] = agent.direction
obs_targets[agent.target][0] = 1
for i in range(len(agents)):
if i != handle: # TODO: handle used as index...?
agent2 = agents[i]
obs_agents_state[agent2.position][1] = agent2.direction
obs_targets[agent2.target][1] = 1
obs_agents_state[agents[i].position][2] = agents[i].malfunction_data['malfunction']
obs_agents_state[agents[i].position][3] = agents[i].speed_data['speed']
for i in range(len(self.env.agents)):
other_agent = self.env.agents[i]
if i != handle:
obs_agents_state[other_agent.position][1] = other_agent.direction
obs_targets[other_agent.target][1] = 1
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
return self.rail_obs, obs_agents_state, obs_targets
......
......@@ -189,7 +189,6 @@ class RailEnv(Environment):
self.distance_map = DistanceMap(self.agents, self.height, self.width)
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
# Stochastic train malfunctioning parameters
if stochastic_data is not None:
......@@ -300,7 +299,6 @@ class RailEnv(Environment):
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
self.observation_space = self.obs_builder.observation_space # <-- change on reset?
self.distance_map.reset(self.agents, self.rail)
# Return the new observation vectors for each agent
......
......@@ -41,7 +41,9 @@ def test_global_obs():
# If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell
assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0)
obs_agents_state = global_obs[0][1]
obs_agents_state = obs_agents_state + 1
assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0)
def _step_along_shortest_path(env, obs_builder, rail):
......
......@@ -281,9 +281,10 @@ def test_shortest_path_predictor_conflicts(rendering=False):
# get the trees to test
obs_builder: TreeObsForRailEnv = env.obs_builder
pp = pprint.PrettyPrinter(indent=4)
tree_0 = obs_builder.unfold_observation_tree(observations[0])
tree_1 = obs_builder.unfold_observation_tree(observations[1])
pp.pprint(tree_0)
tree_0 = observations[0]
tree_1 = observations[1]
env.obs_builder.util_print_obs_subtree(tree_0)
env.obs_builder.util_print_obs_subtree(tree_1)
# check the expectations
expected_conflicts_0 = [('F', 'R')]
......@@ -292,11 +293,18 @@ def test_shortest_path_predictor_conflicts(rendering=False):
_check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")
def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''):
assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt)
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)
for a_1 in obs_builder.tree_explorted_actions_char:
conflict = tree_0[a_1][''][8]
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
if tree.childs[a_1] == -np.inf:
assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
continue
else:
conflict = tree.childs[a_1].num_agents_opposite_direction
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
for a_2 in obs_builder.tree_explorted_actions_char:
conflict = tree_0[a_1][a_2][''][8]
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
if tree.childs[a_1].childs[a_2] == -np.inf:
assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
else:
conflict = tree.childs[a_1].childs[a_2].num_agents_opposite_direction
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
......@@ -22,7 +22,6 @@ class SingleAgentNavigationObs(ObservationBuilder):
def __init__(self):
super().__init__()
self.observation_space = [3]
def reset(self):
pass
......
import numpy as np
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
def test_get_global_observation():
    """Check the per-agent output of GlobalObsForRailEnv after one env step.

    A 50x50 sparse rail environment with 20 agents is stepped once. For every
    agent handle the test asserts that:

    - the targets array has exactly one non-zero cell in channel 0 and
      ``number_of_agents - 1`` non-zero cells in channel 1;
    - after shifting the agent-state array by +1, channel 0 has one non-zero
      cell, channel 1 has ``number_of_agents - 1``, and channels 2 and 3 have
      ``number_of_agents`` each.
    """
    np.random.seed(1)
    number_of_agents = 20

    malfunction_config = {'prop_malfunction': 1.,  # Percentage of defective agents
                          'malfunction_rate': 30,  # Rate of malfunction occurrence
                          'min_duration': 3,  # Minimal duration of malfunction
                          'max_duration': 20  # Max duration of malfunction
                          }

    speed_ratios = {1.: 0.25,  # Fast passenger train
                    1. / 2.: 0.25,  # Fast freight train
                    1. / 3.: 0.25,  # Slow commuter train
                    1. / 4.: 0.25}  # Slow freight train

    rail_gen = sparse_rail_generator(
        num_cities=25,  # Number of cities in map (where train stations are)
        num_intersections=10,  # Number of intersections (no start / target)
        num_trainstations=50,  # Number of possible start/targets on map
        min_node_dist=3,  # Minimal distance of nodes
        node_radius=4,  # Proximity of stations to city center
        num_neighb=4,  # Number of connections to other cities/intersections
        seed=15,  # Random seed
        grid_mode=True,
        enhance_intersection=False,
    )
    schedule_gen = sparse_schedule_generator(speed_ratios)

    env = RailEnv(
        width=50,
        height=50,
        rail_generator=rail_gen,
        schedule_generator=schedule_gen,
        number_of_agents=number_of_agents,
        stochastic_data=malfunction_config,  # Malfunction data generator
        obs_builder_object=GlobalObsForRailEnv(),
    )

    obs, _rewards, _done, _info = env.step({0: 0})

    # Expected number of non-default cells in agent-state channels 0..3.
    expected_state_counts = (
        1,
        number_of_agents - 1,
        number_of_agents,
        number_of_agents,
    )

    for handle in range(len(env.agents)):
        agent_obs = obs[handle]
        agents_state = agent_obs[1]
        targets = agent_obs[2]

        assert np.count_nonzero(targets[:, :, 0]) == 1
        assert np.count_nonzero(targets[:, :, 1]) == (number_of_agents - 1)

        # The state array is initialised with -1, so shift it by one (in place)
        # to make np.count_nonzero count only the cells that were written.
        agents_state += 1
        for channel, expected in enumerate(expected_state_counts):
            assert np.count_nonzero(agents_state[:, :, channel]) == expected
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment