Skip to content
Snippets Groups Projects
Commit 3d99e8f3 authored by u214892's avatar u214892
Browse files

#56 bugfix wrong length of observation vector; #47 added two flags whether...

#56 bugfix wrong length of observation vector; #47 added two flags whether there is another agent on the same position in the same or opposite direction
parent 4145c9ea
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,10 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth + 1):
size += pow4
pow4 *= 4
self.observation_space = [size * 6]
self.observation_dim = 7
self.observation_space = [size * self.observation_dim]
self.location_has_agent = {}
self.location_has_agent_direction = {}
def reset(self):
agents = self.env.agents
......@@ -181,8 +184,15 @@ class TreeObsForRailEnv(ObservationBuilder):
#5: minimum distance from node to the agent's target (when landing to the node following the corresponding
branch).
#6: agent direction
#6: agent in the same direction
1 = agent present same direction
(possible future use: number of other agents in the same direction in this branch)
0 = no agent present same direction
#7: agent in the opposite direction
1 = agent present other direction than myself (so conflict)
(possible future use: number of other agents in other direction in this branch, i.e. number of conflicts)
0 = no agent present other direction than myself
Missing/padding nodes are filled in with -inf (truncated).
......@@ -195,13 +205,15 @@ class TreeObsForRailEnv(ObservationBuilder):
# Update local lookup table for all agents' positions
self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], agent.direction]
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
root_observation = observation[:]
visited = set()
# Start from the current orientation, and see which transitions are available;
......@@ -212,7 +224,6 @@ class TreeObsForRailEnv(ObservationBuilder):
if num_transitions == 1:
orientation == np.argmax(possible_transitions)
# for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
......@@ -227,7 +238,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth):
num_cells_to_fill_in += pow4
pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
self.env.dev_obs_dict[handle] = visited
return observation
......@@ -275,7 +286,6 @@ class TreeObsForRailEnv(ObservationBuilder):
visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible.
# if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
if np.array_equal(position, self.env.agents[handle].target):
last_isTarget = True
break
......@@ -297,6 +307,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if not last_isDeadEnd:
# Keep walking through the tree along `direction'
exploring = True
# convert one-hot encoding to 0,1,2,3
direction = np.argmax(cell_transitions)
position = self._new_position(position, direction)
num_steps += 1
......@@ -321,36 +332,53 @@ class TreeObsForRailEnv(ObservationBuilder):
# #############################
# Modify here to append new / different features for each visited cell!
"""
other_agent_same_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
other_agent_opposite_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
if last_isTarget:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
0,
direction]
other_agent_same_direction,
other_agent_opposite_direction
]
elif last_isTerminal:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
other_target_encountered,
other_agent_encountered,
np.inf,
np.inf,
direction]
other_agent_same_direction,
other_agent_opposite_direction
]
else:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction],
direction]
other_agent_same_direction,
other_agent_opposite_direction
]
"""
other_agent_same_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
other_agent_opposite_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
if last_isTarget:
observation = [0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
0,
direction
other_agent_same_direction,
other_agent_opposite_direction
]
elif last_isTerminal:
......@@ -359,7 +387,8 @@ class TreeObsForRailEnv(ObservationBuilder):
other_agent_encountered,
np.inf,
np.inf,
direction
other_agent_same_direction,
other_agent_opposite_direction
]
else:
observation = [0,
......@@ -367,7 +396,8 @@ class TreeObsForRailEnv(ObservationBuilder):
other_agent_encountered,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction],
direction
other_agent_same_direction,
other_agent_opposite_direction
]
# #############################
# #############################
......@@ -407,8 +437,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth - depth):
num_cells_to_fill_in += pow4
pow4 *= 4
observation = \
observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
return observation, visited
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment