From e0a101437ab227276e4d3f70ddd6e14504f76514 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sun, 1 Sep 2019 08:10:05 -0400 Subject: [PATCH] updated tree observation now detecting malfunctioning agents now detecting slowest fractional speed --- flatland/envs/observations.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 1fc7a400..506afd68 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -25,14 +25,13 @@ class TreeObsForRailEnv(ObservationBuilder): def __init__(self, max_depth, predictor=None): super().__init__() self.max_depth = max_depth - self.observation_dim = 9 + self.observation_dim = 11 # Compute the size of the returned observation vector size = 0 pow4 = 1 for i in range(self.max_depth + 1): size += pow4 pow4 *= 4 - self.observation_dim = 9 self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} @@ -280,7 +279,9 @@ class TreeObsForRailEnv(ObservationBuilder): num_transitions = np.count_nonzero(possible_transitions) # Root node - current position - observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0] + # Here information about the agent itself is stored + observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, + agent.malfunction_data['malfunction'], agent.speed_data['speed']] visited = set() @@ -357,6 +358,10 @@ class TreeObsForRailEnv(ObservationBuilder): if tot_dist < other_agent_encountered: other_agent_encountered = tot_dist + # Check if any of the observed agents is malfunctioning, store agent with longest duration left + if self.location_has_agent_malfunction[position] > malfunctioning_agent: + malfunctioning_agent = self.location_has_agent_malfunction[position] + if self.location_has_agent_direction[position] == direction: # Cummulate the number of agents on branch with same direction other_agent_same_direction += 1 @@ -365,6 +370,7 @@ class TreeObsForRailEnv(ObservationBuilder): current_fractional_speed = self.location_has_agent_speed[position] if current_fractional_speed < min_fractional_speed: min_fractional_speed = current_fractional_speed + if self.location_has_agent_direction[position] != direction: # Cummulate the number of agents on branch with other direction other_agent_opposite_direction += 1 @@ -492,7 +498,9 @@ class TreeObsForRailEnv(ObservationBuilder): tot_dist, 0, other_agent_same_direction, - other_agent_opposite_direction + other_agent_opposite_direction, + malfunctioning_agent, + min_fractional_speed ] elif last_is_terminal: @@ -504,7 +512,9 @@ class TreeObsForRailEnv(ObservationBuilder): np.inf, self.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, - other_agent_opposite_direction + other_agent_opposite_direction, + malfunctioning_agent, + min_fractional_speed ] else: @@ -517,6 +527,8 @@ class TreeObsForRailEnv(ObservationBuilder): self.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, other_agent_opposite_direction, + malfunctioning_agent, + min_fractional_speed ] # ############################# # ############################# -- GitLab