Skip to content
Snippets Groups Projects
Commit 32f0f49a authored by gmollard's avatar gmollard
Browse files

Merge branch 'local_obs_implementation'

parents 539044ff 33c7a0a1
No related branches found
No related tags found
No related merge requests found
......@@ -490,8 +490,6 @@ class GlobalObsForRailEnv(ObservationBuilder):
of the other agents at their position coordinates.
- A 4 elements array with one of encoding of the direction of the agent of interest.
"""
def __init__(self):
......@@ -536,4 +534,89 @@ class GlobalObsForRailEnv(ObservationBuilder):
direction = np.zeros(4)
direction[agent.direction] = 1
return self.rail_obs, obs_map_state, obs_other_agents_state, direction
return self.rail_obs, obs_map_state, obs_other_agents_state, direction
class LocalObsForRailEnv(ObservationBuilder):
"""
Gives a global observation of the entire rail environment.
The observation is composed of the following elements:
- transition map array of the local environment around the given agent,
with dimensions (2*view_radius + 1, 2*view_radius + 1, 16),
assuming 16 bits encoding of transitions.
- Two 2D arrays containing respectively, if they are in the agent's vision range,
its target position, the positions of the other targets.
- A 3D array (map_height, map_width, 4) containing the one hot encoding of directions
of the other agents at their position coordinates, if they are in the agent's vision range.
- A 4 elements array with one hot encoding of the direction.
"""
def __init__(self, view_radius):
"""
:param view_radius:
"""
super(LocalObsForRailEnv, self).__init__()
self.view_radius = view_radius
def reset(self):
# We build the transition map with a view_radius empty cells expansion on each side.
# This helps to collect the local transition map view when the agent is close to a border.
self.rail_obs = np.zeros((self.env.height + 2*self.view_radius,
self.env.width + 2*self.view_radius, 16))
for i in range(self.env.height):
for j in range(self.env.width):
bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist)
# self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array(
# list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
def get(self, handle):
agents = self.env.agents
agent = agents[handle]
# left_offset = max(0, agent.position[1] - 1 - self.view_radius)
# right_offset = min(self.env.width, agent.position[1] + 1 + self.view_radius)
# top_offset = max(0, agent.position[0] - 1 - self.view_radius)
# bottom_offset = min(0, agent.position[0] + 1 + self.view_radius)
local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0]+2*self.view_radius +1,
agent.position[1]:agent.position[1]+2*self.view_radius +1]
obs_map_state = np.zeros((2*self.view_radius +1, 2*self.view_radius + 1, 2))
obs_other_agents_state = np.zeros((2*self.view_radius +1, 2*self.view_radius +1, 4))
def relative_pos(pos):
return [agent.position[0] - pos[0], agent.position[1] - pos[1]]
def is_in(rel_pos):
return (abs(rel_pos[0]) <= self.view_radius) and (abs(rel_pos[1]) <= self.view_radius)
target_rel_pos = relative_pos(agent.target)
if is_in(target_rel_pos):
obs_map_state[self.view_radius + np.array(target_rel_pos)][0] += 1
for i in range(len(agents)):
if i != handle: # TODO: handle used as index...?
agent2 = agents[i]
agent_2_rel_pos = relative_pos(agent2.position)
if is_in(agent_2_rel_pos):
obs_other_agents_state[self.view_radius + agent_2_rel_pos[0],
self.view_radius + agent_2_rel_pos[1]][agent2.direction] += 1
target_rel_pos_2 = relative_pos(agent2.position)
if is_in(target_rel_pos_2):
obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1
direction = np.zeros(4)
direction[agent.direction] = 1
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment