Commit 61b75d2f authored by u214892's avatar u214892
Browse files

observations only for active agents

parent 7b49965c
Pipeline #2330 failed with stages
in 12 minutes and 2 seconds
......@@ -11,11 +11,20 @@ from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
from flatland.utils.ordered_set import OrderedSet
class TreeObsForRailEnv(ObservationBuilder):
"""
TreeObsForRailEnv object.
This object returns observation vectors for agents in the RailEnv environment.
The information is local to each agent and exploits the graph structure of the rail
network to simplify the representation of the state of the environment for each agent.
For details about the features in the tree observation see the get() function.
"""
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'dist_other_target_encountered '
'dist_other_agent_encountered '
......@@ -27,19 +36,10 @@ class TreeObsForRailEnv(ObservationBuilder):
'num_agents_opposite_direction '
'num_agents_malfunctioning '
'speed_min_fractional '
'num_agents_ready_to_depart '
'childs')
tree_explorted_actions_char = ['L', 'F', 'R', 'B']
"""
TreeObsForRailEnv object.
This object returns observation vectors for agents in the RailEnv environment.
The information is local to each agent and exploits the graph structure of the rail
network to simplify the representation of the state of the environment for each agent.
For details about the features in the tree observation see the get() function.
"""
tree_explored_actions_char = ['L', 'F', 'R', 'B']
def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
super().__init__()
......@@ -152,6 +152,8 @@ class TreeObsForRailEnv(ObservationBuilder):
1 if no agent is observed
min_fractional speed otherwise
#12:
number of agents ready to depart but no yet active
Missing/padding nodes are filled in with -inf (truncated).
Missing values in present node are filled in with +inf (truncated).
......@@ -163,6 +165,11 @@ class TreeObsForRailEnv(ObservationBuilder):
# Update local lookup table for all agents' positions
self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
self.location_has_agent_ready_to_depart = {}
for agent in self.env.agents:
if agent.status == RailAgentStatus.READY_TO_DEPART:
self.location_has_agent_ready_to_depart = \
self.location_has_agent_ready_to_depart.get(tuple(agent.initial_position), 0) + 1
self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in
......@@ -185,6 +192,7 @@ class TreeObsForRailEnv(ObservationBuilder):
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_data['malfunction'],
speed_min_fractional=agent.speed_data['speed'],
num_agents_ready_to_depart=0,
childs={})
visited = OrderedSet()
......@@ -204,12 +212,12 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, 1, 1)
root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation
root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation
visited |= branch_visited
else:
# add cells filled with infinity if no transition is possible
root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf
root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf
self.env.dev_obs_dict[handle] = visited
return root_node_observation
......@@ -247,6 +255,7 @@ class TreeObsForRailEnv(ObservationBuilder):
malfunctioning_agent = 0
min_fractional_speed = 1.
num_steps = 1
other_agent_ready_to_depart_encountered = 0
while exploring:
# #############################
# #############################
......@@ -260,6 +269,8 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.location_has_agent_malfunction[position] > malfunctioning_agent:
malfunctioning_agent = self.location_has_agent_malfunction[position]
other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)
if self.location_has_agent_direction[position] == direction:
# Cummulate the number of agents on branch with same direction
other_agent_same_direction += 1
......@@ -408,6 +419,7 @@ class TreeObsForRailEnv(ObservationBuilder):
num_agents_opposite_direction=other_agent_opposite_direction,
num_agents_malfunctioning=malfunctioning_agent,
speed_min_fractional=min_fractional_speed,
num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
childs={})
# #############################
......@@ -427,7 +439,7 @@ class TreeObsForRailEnv(ObservationBuilder):
(branch_direction + 2) % 4,
tot_dist + 1,
depth + 1)
node.childs[self.tree_explorted_actions_char[i]] = branch_observation
node.childs[self.tree_explored_actions_char[i]] = branch_observation
if len(branch_visited) != 0:
visited |= branch_visited
elif last_is_switch and possible_transitions[branch_direction]:
......@@ -437,12 +449,12 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_direction,
tot_dist + 1,
depth + 1)
node.childs[self.tree_explorted_actions_char[i]] = branch_observation
node.childs[self.tree_explored_actions_char[i]] = branch_observation
if len(branch_visited) != 0:
visited |= branch_visited
else:
# no exploring possible, add just cells with infinity
node.childs[self.tree_explorted_actions_char[i]] = -np.inf
node.childs[self.tree_explored_actions_char[i]] = -np.inf
if depth == self.max_depth:
node.childs.clear()
......@@ -453,7 +465,7 @@ class TreeObsForRailEnv(ObservationBuilder):
Utility function to print tree observations returned by this object.
"""
self.print_node_features(tree, "root", "")
for direction in self.tree_explorted_actions_char:
for direction in self.tree_explored_actions_char:
self.print_subtree(tree.childs[direction], direction, "\t")
@staticmethod
......@@ -462,7 +474,8 @@ class TreeObsForRailEnv(ObservationBuilder):
node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ",
node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ",
node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction,
", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional)
", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional, ", ",
node.num_agents_ready_to_depart)
def print_subtree(self, node, label, indent):
if node == -np.inf or not node:
......@@ -474,7 +487,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if not node.childs:
return
for direction in self.tree_explorted_actions_char:
for direction in self.tree_explored_actions_char:
self.print_subtree(node.childs[direction], direction, indent + "\t")
def set_env(self, env: Environment):
......@@ -499,6 +512,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
- second channel containing the other agents positions and diretion
- third channel containing agent/other agent malfunctions
- fourth channel containing agent/other agent fractional speeds
- fifth channel containing number of other agents ready to depart
- Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
target and the positions of the other agents targets.
......@@ -521,17 +535,20 @@ class GlobalObsForRailEnv(ObservationBuilder):
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
obs_targets = np.zeros((self.env.height, self.env.width, 2))
obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - 1
obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1
agent = self.env.agents[handle]
obs_agents_state[agent.position][0] = agent.direction
obs_targets[agent.target][0] = 1
for i in range(len(self.env.agents)):
other_agent = self.env.agents[i]
other_agent: EnvAgent = self.env.agents[i]
if i != handle:
obs_agents_state[other_agent.position][1] = other_agent.direction
obs_targets[other_agent.target][1] = 1
if other_agent.status == RailAgentStatus.READY_TO_DEPART:
obs_agents_state[other_agent.initial_position] += 1
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
......
......@@ -303,14 +303,14 @@ def test_shortest_path_predictor_conflicts(rendering=False):
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)
for a_1 in obs_builder.tree_explorted_actions_char:
for a_1 in obs_builder.tree_explored_actions_char:
if tree.childs[a_1] == -np.inf:
assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
continue
else:
conflict = tree.childs[a_1].num_agents_opposite_direction
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
for a_2 in obs_builder.tree_explorted_actions_char:
for a_2 in obs_builder.tree_explored_actions_char:
if tree.childs[a_1].childs[a_2] == -np.inf:
assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment