Skip to content
Snippets Groups Projects
Commit aadff790 authored by u214892's avatar u214892
Browse files

47 agent directions in observation

parent f5022411
No related branches found
No related tags found
No related merge requests found
...@@ -7,13 +7,14 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and ...@@ -7,13 +7,14 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and
+ Get() is called whenever an observation has to be computed, potentially for each agent independently in + Get() is called whenever an observation has to be computed, potentially for each agent independently in
case of multi-agent environments. case of multi-agent environments.
""" """
import numpy as np
class ObservationBuilder: class ObservationBuilder:
""" """
ObservationBuilder base class. ObservationBuilder base class.
Derived objects must implement and `observation_space' attribute as a tuple with the dimensuions of the returned Derived objects must implement and `observation_space' attribute as a tuple with the dimensions of the returned
observations. observations.
""" """
...@@ -45,3 +46,9 @@ class ObservationBuilder: ...@@ -45,3 +46,9 @@ class ObservationBuilder:
An observation structure, specific to the corresponding environment. An observation structure, specific to the corresponding environment.
""" """
raise NotImplementedError() raise NotImplementedError()
def _get_one_hot_for_agent_direction(self, agent):
"""Retuns the agent's direction to one-hot encoding."""
direction = np.zeros(4)
direction[agent.direction] = 1
return direction
""" """
Collection of environment-specific ObservationBuilder. Collection of environment-specific ObservationBuilder.
""" """
import numpy as np
from collections import deque from collections import deque
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_observation_builder import ObservationBuilder
...@@ -22,10 +23,10 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -22,10 +23,10 @@ class TreeObsForRailEnv(ObservationBuilder):
# Compute the size of the returned observation vector # Compute the size of the returned observation vector
size = 0 size = 0
pow4 = 1 pow4 = 1
for i in range(self.max_depth+1): for i in range(self.max_depth + 1):
size += pow4 size += pow4
pow4 *= 4 pow4 *= 4
self.observation_space = [size * 5] self.observation_space = [size * 6]
def reset(self): def reset(self):
agents = self.env.agents agents = self.env.agents
...@@ -186,6 +187,10 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -186,6 +187,10 @@ class TreeObsForRailEnv(ObservationBuilder):
#5: minimum distance from node to the agent's target (when landing to the node following the corresponding #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
branch. branch.
#6: agent direction
Missing/padding nodes are filled in with -inf (truncated). Missing/padding nodes are filled in with -inf (truncated).
Missing values in present node are filled in with +inf (truncated). Missing values in present node are filled in with +inf (truncated).
...@@ -202,13 +207,10 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -202,13 +207,10 @@ class TreeObsForRailEnv(ObservationBuilder):
if handle > len(self.env.agents): if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index agent = self.env.agents[handle] # TODO: handle being treated as index
# position = self.env.agents_position[handle]
# orientation = self.env.agents_direction[handle]
possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
num_transitions = np.count_nonzero(possible_transitions) num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position # Root node - current position
# observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], agent.direction]
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
root_observation = observation[:] root_observation = observation[:]
visited = set() visited = set()
# Start from the current orientation, and see which transitions are available; # Start from the current orientation, and see which transitions are available;
...@@ -337,40 +339,49 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -337,40 +339,49 @@ class TreeObsForRailEnv(ObservationBuilder):
1 if other_target_encountered else 0, 1 if other_target_encountered else 0,
1 if other_agent_encountered else 0, 1 if other_agent_encountered else 0,
root_observation[3] + num_steps, root_observation[3] + num_steps,
0] 0,
direction]
elif last_isTerminal: elif last_isTerminal:
observation = [0, observation = [0,
1 if other_target_encountered else 0, 1 if other_target_encountered else 0,
1 if other_agent_encountered else 0, 1 if other_agent_encountered else 0,
np.inf, np.inf,
np.inf] np.inf,
direction]
else: else:
observation = [0, observation = [0,
1 if other_target_encountered else 0, 1 if other_target_encountered else 0,
1 if other_agent_encountered else 0, 1 if other_agent_encountered else 0,
root_observation[3] + num_steps, root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction]] self.distance_map[handle, position[0], position[1], direction],
direction]
""" """
if last_isTarget: if last_isTarget:
observation = [0, observation = [0,
other_target_encountered, other_target_encountered,
other_agent_encountered, other_agent_encountered,
root_observation[3] + num_steps, root_observation[3] + num_steps,
0] 0,
direction
]
elif last_isTerminal: elif last_isTerminal:
observation = [0, observation = [0,
other_target_encountered, other_target_encountered,
other_agent_encountered, other_agent_encountered,
np.inf, np.inf,
np.inf] np.inf,
direction
]
else: else:
observation = [0, observation = [0,
other_target_encountered, other_target_encountered,
other_agent_encountered, other_agent_encountered,
root_observation[3] + num_steps, root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction]] self.distance_map[handle, position[0], position[1], direction],
direction
]
# ############################# # #############################
# ############################# # #############################
...@@ -409,7 +420,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -409,7 +420,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth - depth): for i in range(self.max_depth - depth):
num_cells_to_fill_in += pow4 num_cells_to_fill_in += pow4
pow4 *= 4 pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
return observation, visited return observation, visited
...@@ -532,7 +543,9 @@ class GlobalObsForRailEnv(ObservationBuilder): ...@@ -532,7 +543,9 @@ class GlobalObsForRailEnv(ObservationBuilder):
obs_agents_state[agent2.position][4 + agent2.direction] = 1 obs_agents_state[agent2.position][4 + agent2.direction] = 1
obs_targets[agent2.target][1] += 1 obs_targets[agent2.target][1] += 1
return self.rail_obs, obs_agents_state, obs_targets direction = self._get_one_hot_for_agent_direction(agent)
return self.rail_obs, obs_agents_state, obs_targets, direction
class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
...@@ -542,13 +555,15 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): ...@@ -542,13 +555,15 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
- transition map array with dimensions (env.height, env.width, 16), - transition map array with dimensions (env.height, env.width, 16),
assuming 16 bits encoding of transitions, flipped in the direction of the agent assuming 16 bits encoding of transitions, flipped in the direction of the agent
(the agent is always heding north on the flipped view). (the agent is always heading north on the flipped view).
- Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent
target and the positions of the other agents targets, also flipped depending on the agent's direction. target and the positions of the other agents targets, also flipped depending on the agent's direction.
- A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other - A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other
agents at their position coordinates, and the last channel containing the position of the given agent. agents at their position coordinates, and the last channel containing the position of the given agent.
- A 4 elements array with one hot encoding of the direction.
""" """
def __init__(self): def __init__(self):
...@@ -603,7 +618,9 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): ...@@ -603,7 +618,9 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1 obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1
obs_targets[agent2.target][1] += 1 obs_targets[agent2.target][1] += 1
return rail_obs, obs_agents_state, obs_targets direction = self._get_one_hot_for_agent_direction(agent)
return rail_obs, obs_agents_state, obs_targets, direction
class LocalObsForRailEnv(ObservationBuilder): class LocalObsForRailEnv(ObservationBuilder):
...@@ -635,8 +652,8 @@ class LocalObsForRailEnv(ObservationBuilder): ...@@ -635,8 +652,8 @@ class LocalObsForRailEnv(ObservationBuilder):
# We build the transition map with a view_radius empty cells expansion on each side. # We build the transition map with a view_radius empty cells expansion on each side.
# This helps to collect the local transition map view when the agent is close to a border. # This helps to collect the local transition map view when the agent is close to a border.
self.rail_obs = np.zeros((self.env.height + 2*self.view_radius, self.rail_obs = np.zeros((self.env.height + 2 * self.view_radius,
self.env.width + 2*self.view_radius, 16)) self.env.width + 2 * self.view_radius, 16))
for i in range(self.env.height): for i in range(self.env.height):
for j in range(self.env.width): for j in range(self.env.width):
bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]] bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
...@@ -654,12 +671,12 @@ class LocalObsForRailEnv(ObservationBuilder): ...@@ -654,12 +671,12 @@ class LocalObsForRailEnv(ObservationBuilder):
# top_offset = max(0, agent.position[0] - 1 - self.view_radius) # top_offset = max(0, agent.position[0] - 1 - self.view_radius)
# bottom_offset = min(0, agent.position[0] + 1 + self.view_radius) # bottom_offset = min(0, agent.position[0] + 1 + self.view_radius)
local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0]+2*self.view_radius + 1, local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0] + 2 * self.view_radius + 1,
agent.position[1]:agent.position[1]+2*self.view_radius + 1] agent.position[1]:agent.position[1] + 2 * self.view_radius + 1]
obs_map_state = np.zeros((2*self.view_radius + 1, 2*self.view_radius + 1, 2)) obs_map_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 2))
obs_other_agents_state = np.zeros((2*self.view_radius + 1, 2*self.view_radius + 1, 4)) obs_other_agents_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 4))
def relative_pos(pos): def relative_pos(pos):
return [agent.position[0] - pos[0], agent.position[1] - pos[1]] return [agent.position[0] - pos[0], agent.position[1] - pos[1]]
...@@ -684,15 +701,11 @@ class LocalObsForRailEnv(ObservationBuilder): ...@@ -684,15 +701,11 @@ class LocalObsForRailEnv(ObservationBuilder):
if is_in(target_rel_pos_2): if is_in(target_rel_pos_2):
obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1 obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1
direction = np.zeros(4) direction = self._get_one_hot_for_agent_direction(agent)
direction[agent.direction] = 1
return local_rail_obs, obs_map_state, obs_other_agents_state, direction return local_rail_obs, obs_map_state, obs_other_agents_state, direction
# class LocalObsForRailEnvImproved(ObservationBuilder): # class LocalObsForRailEnvImproved(ObservationBuilder):
# """ # """
# Returns a local observation around the given agent # Returns a local observation around the given agent
# """ # """
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment