Skip to content
Snippets Groups Projects
Commit 2de92fce authored by Christian Eichenberger's avatar Christian Eichenberger :badminton:
Browse files

Merge branch '47-agent-directions-in-observations-cont' into 'master'

#56 bugfix wrong length of observation vector; #47 added two flags whether...

Closes #47 and #56

See merge request flatland/flatland!52
parents fd755735 f230fa85
No related branches found
No related tags found
No related merge requests found
...@@ -26,7 +26,10 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -26,7 +26,10 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth + 1): for i in range(self.max_depth + 1):
size += pow4 size += pow4
pow4 *= 4 pow4 *= 4
self.observation_space = [size * 6] self.observation_dim = 7
self.observation_space = [size * self.observation_dim]
self.location_has_agent = {}
self.location_has_agent_direction = {}
def reset(self): def reset(self):
agents = self.env.agents agents = self.env.agents
...@@ -181,8 +184,15 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -181,8 +184,15 @@ class TreeObsForRailEnv(ObservationBuilder):
#5: minimum distance from node to the agent's target (when landing to the node following the corresponding #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
branch. branch.
#6: agent direction #6: agent in the same direction
1 = agent present same direction
(possible future use: number of other agents in the same direction in this branch)
0 = no agent present same direction
#7: agent in the opposite direction
1 = agent present in a direction other than mine (so conflict)
(possible future use: number of other agents in another direction in this branch, i.e. number of conflicts)
0 = no agent present in a direction other than mine
Missing/padding nodes are filled in with -inf (truncated). Missing/padding nodes are filled in with -inf (truncated).
...@@ -195,13 +205,15 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -195,13 +205,15 @@ class TreeObsForRailEnv(ObservationBuilder):
# Update local lookup table for all agents' positions # Update local lookup table for all agents' positions
self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents} self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
if handle > len(self.env.agents): if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index agent = self.env.agents[handle] # TODO: handle being treated as index
possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
num_transitions = np.count_nonzero(possible_transitions) num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position # Root node - current position
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], agent.direction] observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
root_observation = observation[:] root_observation = observation[:]
visited = set() visited = set()
# Start from the current orientation, and see which transitions are available; # Start from the current orientation, and see which transitions are available;
...@@ -212,7 +224,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -212,7 +224,6 @@ class TreeObsForRailEnv(ObservationBuilder):
if num_transitions == 1: if num_transitions == 1:
orientation == np.argmax(possible_transitions) orientation == np.argmax(possible_transitions)
# for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
if possible_transitions[branch_direction]: if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction) new_cell = self._new_position(agent.position, branch_direction)
...@@ -227,7 +238,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -227,7 +238,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth): for i in range(self.max_depth):
num_cells_to_fill_in += pow4 num_cells_to_fill_in += pow4
pow4 *= 4 pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
self.env.dev_obs_dict[handle] = visited self.env.dev_obs_dict[handle] = visited
return observation return observation
...@@ -275,7 +286,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -275,7 +286,6 @@ class TreeObsForRailEnv(ObservationBuilder):
visited.add((position[0], position[1], direction)) visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible. # If the target node is encountered, pick that as node. Also, no further branching is possible.
# if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
if np.array_equal(position, self.env.agents[handle].target): if np.array_equal(position, self.env.agents[handle].target):
last_isTarget = True last_isTarget = True
break break
...@@ -297,6 +307,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -297,6 +307,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if not last_isDeadEnd: if not last_isDeadEnd:
# Keep walking through the tree along `direction' # Keep walking through the tree along `direction'
exploring = True exploring = True
# convert one-hot encoding to 0,1,2,3
direction = np.argmax(cell_transitions) direction = np.argmax(cell_transitions)
position = self._new_position(position, direction) position = self._new_position(position, direction)
num_steps += 1 num_steps += 1
...@@ -321,36 +332,53 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -321,36 +332,53 @@ class TreeObsForRailEnv(ObservationBuilder):
# ############################# # #############################
# Modify here to append new / different features for each visited cell! # Modify here to append new / different features for each visited cell!
""" """
other_agent_same_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
other_agent_opposite_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
if last_isTarget: if last_isTarget:
observation = [0, observation = [0,
1 if other_target_encountered else 0, other_target_encountered,
1 if other_agent_encountered else 0, other_agent_encountered,
root_observation[3] + num_steps, root_observation[3] + num_steps,
0, 0,
direction] other_agent_same_direction,
other_agent_opposite_direction
]
elif last_isTerminal: elif last_isTerminal:
observation = [0, observation = [0,
1 if other_target_encountered else 0, other_target_encountered,
1 if other_agent_encountered else 0, other_agent_encountered,
np.inf, np.inf,
np.inf, np.inf,
direction] other_agent_same_direction,
other_agent_opposite_direction
]
else: else:
observation = [0, observation = [0,
1 if other_target_encountered else 0, other_target_encountered,
1 if other_agent_encountered else 0, other_agent_encountered,
root_observation[3] + num_steps, root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction], self.distance_map[handle, position[0], position[1], direction],
direction] other_agent_same_direction,
other_agent_opposite_direction
]
""" """
other_agent_same_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
other_agent_opposite_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
if last_isTarget: if last_isTarget:
observation = [0, observation = [0,
other_target_encountered, other_target_encountered,
other_agent_encountered, other_agent_encountered,
root_observation[3] + num_steps, root_observation[3] + num_steps,
0, 0,
direction other_agent_same_direction,
other_agent_opposite_direction
] ]
elif last_isTerminal: elif last_isTerminal:
...@@ -359,7 +387,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -359,7 +387,8 @@ class TreeObsForRailEnv(ObservationBuilder):
other_agent_encountered, other_agent_encountered,
np.inf, np.inf,
np.inf, np.inf,
direction other_agent_same_direction,
other_agent_opposite_direction
] ]
else: else:
observation = [0, observation = [0,
...@@ -367,7 +396,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -367,7 +396,8 @@ class TreeObsForRailEnv(ObservationBuilder):
other_agent_encountered, other_agent_encountered,
root_observation[3] + num_steps, root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction], self.distance_map[handle, position[0], position[1], direction],
direction other_agent_same_direction,
other_agent_opposite_direction
] ]
# ############################# # #############################
# ############################# # #############################
...@@ -407,8 +437,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -407,8 +437,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth - depth): for i in range(self.max_depth - depth):
num_cells_to_fill_in += pow4 num_cells_to_fill_in += pow4
pow4 *= 4 pow4 *= 4
observation = \ observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
return observation, visited return observation, visited
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment