Skip to content
Snippets Groups Projects
Commit aadff790 authored by u214892's avatar u214892
Browse files

47 agent directions in observation

parent f5022411
No related branches found
No related tags found
No related merge requests found
......@@ -7,13 +7,14 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and
+ Get() is called whenever an observation has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""
import numpy as np
class ObservationBuilder:
"""
ObservationBuilder base class.
Derived objects must implement and `observation_space' attribute as a tuple with the dimensuions of the returned
Derived objects must implement an `observation_space` attribute as a tuple with the dimensions of the returned
observations.
"""
......@@ -45,3 +46,9 @@ class ObservationBuilder:
An observation structure, specific to the corresponding environment.
"""
raise NotImplementedError()
def _get_one_hot_for_agent_direction(self, agent):
    """Return the agent's direction as a one-hot encoded vector.

    Builds a NumPy array of four zeros and sets the entry at index
    ``agent.direction`` to 1.

    Parameters
    ----------
    agent
        Object exposing a ``direction`` attribute that is used directly as
        an index into the 4-element array (presumably an integer in 0..3 --
        verify against the agent class).

    Returns
    -------
    numpy.ndarray
        Array of shape (4,) with a single 1 at position ``agent.direction``.
    """
    one_hot = np.zeros(4)
    # Mark the slot matching the agent's current heading.
    one_hot[agent.direction] = 1
    return one_hot
"""
Collection of environment-specific ObservationBuilder.
"""
import numpy as np
from collections import deque
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
......@@ -22,10 +23,10 @@ class TreeObsForRailEnv(ObservationBuilder):
# Compute the size of the returned observation vector
size = 0
pow4 = 1
for i in range(self.max_depth+1):
for i in range(self.max_depth + 1):
size += pow4
pow4 *= 4
self.observation_space = [size * 5]
self.observation_space = [size * 6]
def reset(self):
agents = self.env.agents
......@@ -186,6 +187,10 @@ class TreeObsForRailEnv(ObservationBuilder):
#5: minimum distance from node to the agent's target (when landing on the node following the corresponding
branch).
#6: agent direction
Missing/padding nodes are filled in with -inf (truncated).
Missing values in present node are filled in with +inf (truncated).
......@@ -202,13 +207,10 @@ class TreeObsForRailEnv(ObservationBuilder):
if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
# position = self.env.agents_position[handle]
# orientation = self.env.agents_direction[handle]
possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
# observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], agent.direction]
root_observation = observation[:]
visited = set()
# Start from the current orientation, and see which transitions are available;
......@@ -337,40 +339,49 @@ class TreeObsForRailEnv(ObservationBuilder):
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
root_observation[3] + num_steps,
0]
0,
direction]
elif last_isTerminal:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
np.inf,
np.inf]
np.inf,
direction]
else:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction]]
self.distance_map[handle, position[0], position[1], direction],
direction]
"""
if last_isTarget:
observation = [0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
0]
0,
direction
]
elif last_isTerminal:
observation = [0,
other_target_encountered,
other_agent_encountered,
np.inf,
np.inf]
np.inf,
direction
]
else:
observation = [0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction]]
self.distance_map[handle, position[0], position[1], direction],
direction
]
# #############################
# #############################
......@@ -409,7 +420,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth - depth):
num_cells_to_fill_in += pow4
pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
return observation, visited
......@@ -532,7 +543,9 @@ class GlobalObsForRailEnv(ObservationBuilder):
obs_agents_state[agent2.position][4 + agent2.direction] = 1
obs_targets[agent2.target][1] += 1
return self.rail_obs, obs_agents_state, obs_targets
direction = self._get_one_hot_for_agent_direction(agent)
return self.rail_obs, obs_agents_state, obs_targets, direction
class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
......@@ -542,13 +555,15 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
- transition map array with dimensions (env.height, env.width, 16),
assuming 16 bits encoding of transitions, flipped in the direction of the agent
(the agent is always heding north on the flipped view).
(the agent is always heading north on the flipped view).
- Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent
target and the positions of the other agents targets, also flipped depending on the agent's direction.
- A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other
agents at their position coordinates, and the last channel containing the position of the given agent.
- A 4 elements array with one hot encoding of the direction.
"""
def __init__(self):
......@@ -603,7 +618,9 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1
obs_targets[agent2.target][1] += 1
return rail_obs, obs_agents_state, obs_targets
direction = self._get_one_hot_for_agent_direction(agent)
return rail_obs, obs_agents_state, obs_targets, direction
class LocalObsForRailEnv(ObservationBuilder):
......@@ -635,8 +652,8 @@ class LocalObsForRailEnv(ObservationBuilder):
# We build the transition map with a view_radius empty cells expansion on each side.
# This helps to collect the local transition map view when the agent is close to a border.
self.rail_obs = np.zeros((self.env.height + 2*self.view_radius,
self.env.width + 2*self.view_radius, 16))
self.rail_obs = np.zeros((self.env.height + 2 * self.view_radius,
self.env.width + 2 * self.view_radius, 16))
for i in range(self.env.height):
for j in range(self.env.width):
bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
......@@ -654,12 +671,12 @@ class LocalObsForRailEnv(ObservationBuilder):
# top_offset = max(0, agent.position[0] - 1 - self.view_radius)
# bottom_offset = min(0, agent.position[0] + 1 + self.view_radius)
local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0]+2*self.view_radius + 1,
agent.position[1]:agent.position[1]+2*self.view_radius + 1]
local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0] + 2 * self.view_radius + 1,
agent.position[1]:agent.position[1] + 2 * self.view_radius + 1]
obs_map_state = np.zeros((2*self.view_radius + 1, 2*self.view_radius + 1, 2))
obs_map_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 2))
obs_other_agents_state = np.zeros((2*self.view_radius + 1, 2*self.view_radius + 1, 4))
obs_other_agents_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 4))
def relative_pos(pos):
return [agent.position[0] - pos[0], agent.position[1] - pos[1]]
......@@ -684,15 +701,11 @@ class LocalObsForRailEnv(ObservationBuilder):
if is_in(target_rel_pos_2):
obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1
direction = np.zeros(4)
direction[agent.direction] = 1
direction = self._get_one_hot_for_agent_direction(agent)
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
# class LocalObsForRailEnvImproved(ObservationBuilder):
# """
# Returns a local observation around the given agent
# """
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment