diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 1fc7a4008e7da703e4bc005c07dc40eed67b9e2d..506afd686a0c5f5a5b35f2081698aa42f745e04d 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -25,14 +25,13 @@ class TreeObsForRailEnv(ObservationBuilder): def __init__(self, max_depth, predictor=None): super().__init__() self.max_depth = max_depth - self.observation_dim = 9 + self.observation_dim = 11 # Compute the size of the returned observation vector size = 0 pow4 = 1 for i in range(self.max_depth + 1): size += pow4 pow4 *= 4 - self.observation_dim = 9 self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} @@ -280,7 +279,9 @@ class TreeObsForRailEnv(ObservationBuilder): num_transitions = np.count_nonzero(possible_transitions) # Root node - current position - observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0] + # Here information about the agent itself is stored + observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, + agent.malfunction_data['malfunction'], agent.speed_data['speed']] visited = set() @@ -357,6 +358,10 @@ class TreeObsForRailEnv(ObservationBuilder): if tot_dist < other_agent_encountered: other_agent_encountered = tot_dist + # Check if any of the observed agents is malfunctioning, store agent with longest duration left + if self.location_has_agent_malfunction[position] > malfunctioning_agent: + malfunctioning_agent = self.location_has_agent_malfunction[position] + if self.location_has_agent_direction[position] == direction: # Cummulate the number of agents on branch with same direction other_agent_same_direction += 1 @@ -365,6 +370,7 @@ class TreeObsForRailEnv(ObservationBuilder): current_fractional_speed = self.location_has_agent_speed[position] if current_fractional_speed < min_fractional_speed: min_fractional_speed = current_fractional_speed + if self.location_has_agent_direction[position] != direction: # Cummulate the number of agents on branch with other direction other_agent_opposite_direction += 1 @@ -492,7 +498,9 @@ class TreeObsForRailEnv(ObservationBuilder): tot_dist, 0, other_agent_same_direction, - other_agent_opposite_direction + other_agent_opposite_direction, + malfunctioning_agent, + min_fractional_speed ] elif last_is_terminal: @@ -504,7 +512,9 @@ class TreeObsForRailEnv(ObservationBuilder): np.inf, self.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, - other_agent_opposite_direction + other_agent_opposite_direction, + malfunctioning_agent, + min_fractional_speed ] else: @@ -517,6 +527,8 @@ class TreeObsForRailEnv(ObservationBuilder): self.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, other_agent_opposite_direction, + malfunctioning_agent, + min_fractional_speed ] # ############################# # #############################