Skip to content
Snippets Groups Projects
Commit 0a6d0919 authored by Erik Nygren's avatar Erik Nygren
Browse files

minor updates. still ways to go...

parent 6e2c1c9d
No related branches found
No related tags found
No related merge requests found
...@@ -701,7 +701,7 @@ class LocalObsForRailEnv(ObservationBuilder): ...@@ -701,7 +701,7 @@ class LocalObsForRailEnv(ObservationBuilder):
with dimensions (view_height,2*view_width+1, 16), with dimensions (view_height,2*view_width+1, 16),
assuming 16 bits encoding of transitions. assuming 16 bits encoding of transitions.
- Two 3D arrays (view_height,2*view_width+1, 2) containing respectively, - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively,
if they are in the agent's vision range, its target position, the positions of the other targets. if they are in the agent's vision range, its target position, the positions of the other targets.
- A 3D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions - A 3D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions
...@@ -744,16 +744,21 @@ class LocalObsForRailEnv(ObservationBuilder): ...@@ -744,16 +744,21 @@ class LocalObsForRailEnv(ObservationBuilder):
# Collect the rail information in the local field of view # Collect the rail information in the local field of view
local_rail_obs = self.field_of_view(agent_rel_pos, agent.direction, state=self.rail_obs) local_rail_obs = self.field_of_view(agent_rel_pos, agent.direction, state=self.rail_obs)
# Locate observed agents and their coresponding targets
obs_map_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 2))
obs_other_agents_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 4))
# Collect visible cells as set to be plotted # Collect visible cells as set to be plotted
visited = self.field_of_view(agent.position, agent.direction) visited, rel_coords = self.field_of_view(agent.position, agent.direction)
# Add the visible cells to the observed cells # Add the visible cells to the observed cells
self.env.dev_obs_dict[handle] = visited self.env.dev_obs_dict[handle] = set(visited)
# Locate observed agents and their coresponding targets
obs_map_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 2))
obs_other_agents_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 4))
for i in range(len(agents)):
temp_agent = agents[i]
if temp_agent.target in visited:
location = np.where(visited == temp_agent.target)
print("I see my target", location, handle)
direction = self._get_one_hot_for_agent_direction(agent) direction = self._get_one_hot_for_agent_direction(agent)
return local_rail_obs, obs_map_state, obs_other_agents_state, direction return local_rail_obs, obs_map_state, obs_other_agents_state, direction
...@@ -783,30 +788,35 @@ class LocalObsForRailEnv(ObservationBuilder): ...@@ -783,30 +788,35 @@ class LocalObsForRailEnv(ObservationBuilder):
origin = (position[0] - self.center, position[1] + self.view_width) origin = (position[0] - self.center, position[1] + self.view_width)
else: else:
origin = (position[0] + self.view_width, position[1] + self.center) origin = (position[0] + self.view_width, position[1] + self.center)
visible = set() visible = list()
rel_coords = list()
for h in range(self.view_height): for h in range(self.view_height):
for w in range(2 * self.view_width + 1): for w in range(2 * self.view_width + 1):
if direction == 0: if direction == 0:
if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
visible.add((origin[0] - h, origin[1] + w)) visible.append((origin[0] - h, origin[1] + w))
rel_coords.append((h, w))
if data_collection: if data_collection:
temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
elif direction == 1: elif direction == 1:
if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width: if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
visible.add((origin[0] + w, origin[1] + h)) visible.append((origin[0] + w, origin[1] + h))
rel_coords.append((h, w))
if data_collection: if data_collection:
temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
elif direction == 2: elif direction == 2:
if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
visible.add((origin[0] + h, origin[1] - w)) visible.append((origin[0] + h, origin[1] - w))
rel_coords.append((h, w))
if data_collection: if data_collection:
temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
else: else:
if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
visible.add((origin[0] - w, origin[1] - h)) visible.append((origin[0] - w, origin[1] - h))
rel_coords.append((h, w))
if data_collection: if data_collection:
temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
if data_collection: if data_collection:
return temp_visible_data return temp_visible_data
else: else:
return visible return visible, rel_coords
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment