Skip to content
Snippets Groups Projects
Commit 2de92fce authored by Christian Eichenberger's avatar Christian Eichenberger :badminton:
Browse files

Merge branch '47-agent-directions-in-observations-cont' into 'master'

#56 bugfix wrong length of observation vector; #47 added two flags whether...

Closes #47 and #56

See merge request flatland/flatland!52
parents fd755735 f230fa85
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,10 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth + 1):
size += pow4
pow4 *= 4
self.observation_space = [size * 6]
self.observation_dim = 7
self.observation_space = [size * self.observation_dim]
self.location_has_agent = {}
self.location_has_agent_direction = {}
def reset(self):
agents = self.env.agents
......@@ -181,8 +184,15 @@ class TreeObsForRailEnv(ObservationBuilder):
#5: minimum distance from node to the agent's target (when landing on the node following the corresponding
branch).
#6: agent direction
#6: agent in the same direction
1 = agent present same direction
(possible future use: number of other agents in the same direction in this branch)
0 = no agent present same direction
#7: agent in the opposite direction
1 = agent present other direction than myself (so conflict)
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
0 = no agent present other direction than myself
Missing/padding nodes are filled in with -inf (truncated).
......@@ -195,13 +205,15 @@ class TreeObsForRailEnv(ObservationBuilder):
# Update local lookup table for all agents' positions
self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], agent.direction]
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
root_observation = observation[:]
visited = set()
# Start from the current orientation, and see which transitions are available;
......@@ -212,7 +224,6 @@ class TreeObsForRailEnv(ObservationBuilder):
if num_transitions == 1:
orientation == np.argmax(possible_transitions)
# for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
......@@ -227,7 +238,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth):
num_cells_to_fill_in += pow4
pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
self.env.dev_obs_dict[handle] = visited
return observation
......@@ -275,7 +286,6 @@ class TreeObsForRailEnv(ObservationBuilder):
visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible.
# if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
if np.array_equal(position, self.env.agents[handle].target):
last_isTarget = True
break
......@@ -297,6 +307,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if not last_isDeadEnd:
# Keep walking through the tree along `direction'
exploring = True
# convert one-hot encoding to 0,1,2,3
direction = np.argmax(cell_transitions)
position = self._new_position(position, direction)
num_steps += 1
......@@ -321,36 +332,53 @@ class TreeObsForRailEnv(ObservationBuilder):
# #############################
# Modify here to append new / different features for each visited cell!
"""
other_agent_same_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
other_agent_opposite_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
if last_isTarget:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
0,
direction]
other_agent_same_direction,
other_agent_opposite_direction
]
elif last_isTerminal:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
other_target_encountered,
other_agent_encountered,
np.inf,
np.inf,
direction]
other_agent_same_direction,
other_agent_opposite_direction
]
else:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction],
direction]
other_agent_same_direction,
other_agent_opposite_direction
]
"""
other_agent_same_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
other_agent_opposite_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
if last_isTarget:
observation = [0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
0,
direction
other_agent_same_direction,
other_agent_opposite_direction
]
elif last_isTerminal:
......@@ -359,7 +387,8 @@ class TreeObsForRailEnv(ObservationBuilder):
other_agent_encountered,
np.inf,
np.inf,
direction
other_agent_same_direction,
other_agent_opposite_direction
]
else:
observation = [0,
......@@ -367,7 +396,8 @@ class TreeObsForRailEnv(ObservationBuilder):
other_agent_encountered,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction],
direction
other_agent_same_direction,
other_agent_opposite_direction
]
# #############################
# #############################
......@@ -407,8 +437,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth - depth):
num_cells_to_fill_in += pow4
pow4 *= 4
observation = \
observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
return observation, visited
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment