Skip to content
Snippets Groups Projects
Commit e0a10143 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated tree observation

now detecting malfunctioning agents
now detecting slowest fractional speed
parent a42db14d
No related branches found
No related tags found
No related merge requests found
......@@ -25,14 +25,13 @@ class TreeObsForRailEnv(ObservationBuilder):
def __init__(self, max_depth, predictor=None):
super().__init__()
self.max_depth = max_depth
self.observation_dim = 9
self.observation_dim = 11
# Compute the size of the returned observation vector
size = 0
pow4 = 1
for i in range(self.max_depth + 1):
size += pow4
pow4 *= 4
self.observation_dim = 9
self.observation_space = [size * self.observation_dim]
self.location_has_agent = {}
self.location_has_agent_direction = {}
......@@ -280,7 +279,9 @@ class TreeObsForRailEnv(ObservationBuilder):
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
# Here information about the agent itself is stored
observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0,
agent.malfunction_data['malfunction'], agent.speed_data['speed']]
visited = set()
......@@ -357,6 +358,10 @@ class TreeObsForRailEnv(ObservationBuilder):
if tot_dist < other_agent_encountered:
other_agent_encountered = tot_dist
# Check if any of the observed agents is malfunctioning, store agent with longest duration left
if self.location_has_agent_malfunction[position] > malfunctioning_agent:
malfunctioning_agent = self.location_has_agent_malfunction[position]
if self.location_has_agent_direction[position] == direction:
# Cummulate the number of agents on branch with same direction
other_agent_same_direction += 1
......@@ -365,6 +370,7 @@ class TreeObsForRailEnv(ObservationBuilder):
current_fractional_speed = self.location_has_agent_speed[position]
if current_fractional_speed < min_fractional_speed:
min_fractional_speed = current_fractional_speed
if self.location_has_agent_direction[position] != direction:
# Cummulate the number of agents on branch with other direction
other_agent_opposite_direction += 1
......@@ -492,7 +498,9 @@ class TreeObsForRailEnv(ObservationBuilder):
tot_dist,
0,
other_agent_same_direction,
other_agent_opposite_direction
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
elif last_is_terminal:
......@@ -504,7 +512,9 @@ class TreeObsForRailEnv(ObservationBuilder):
np.inf,
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
else:
......@@ -517,6 +527,8 @@ class TreeObsForRailEnv(ObservationBuilder):
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
# #############################
# #############################
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment