Skip to content
Snippets Groups Projects
Commit ae53bf1a authored by Erik Nygren's avatar Erik Nygren
Browse files

Refactoring of the LocalObForRailEnv observation.

Now turns with the agent, encodes local information which can easily be enhanced by predictions, distance map and other things.
parent 0a6d0919
No related branches found
No related tags found
No related merge requests found
......@@ -14,11 +14,11 @@ np.random.seed(1)
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
env = RailEnv(width=20,
height=20,
env = RailEnv(width=50,
height=50,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
obs_builder_object=LocalGridObs,
number_of_agents=2)
number_of_agents=5)
env_renderer = RenderTool(env, gl="PILSVG", )
......
......@@ -704,7 +704,7 @@ class LocalObsForRailEnv(ObservationBuilder):
- Two 2D arrays (view_height,2*view_width+1, 2) containing respectively,
if they are in the agent's vision range, its target position, the positions of the other targets.
- A 3D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions
- A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions
of the other agents at their position coordinates, if they are in the agent's vision range.
- A 4 elements array with one hot encoding of the direction.
......@@ -724,43 +724,52 @@ class LocalObsForRailEnv(ObservationBuilder):
# We build the transition map with a view_radius empty cells expansion on each side.
# This helps to collect the local transition map view when the agent is close to a border.
self.max_padding = max(self.view_width, self.view_height)
self.rail_obs = np.zeros((self.env.height + 2 * self.max_padding,
self.env.width + 2 * self.max_padding, 16))
self.rail_obs = np.zeros((self.env.height,
self.env.width, 16))
for i in range(self.env.height):
for j in range(self.env.width):
bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i + self.view_height, j + self.view_width] = np.array(bitlist)
self.rail_obs[i, j] = np.array(bitlist)
def get(self, handle):
agents = self.env.agents
agent = agents[handle]
agent_rel_pos = [0, 0]
# Correct agents position for padding
agent_rel_pos[0] = agent.position[0] + self.max_padding
agent_rel_pos[1] = agent.position[1] + self.max_padding
# Collect the rail information in the local field of view
local_rail_obs = self.field_of_view(agent_rel_pos, agent.direction, state=self.rail_obs)
# agent_rel_pos[0] = agent.position[0] + self.max_padding
# agent_rel_pos[1] = agent.position[1] + self.max_padding
# Collect visible cells as set to be plotted
visited, rel_coords = self.field_of_view(agent.position, agent.direction)
visited, rel_coords = self.field_of_view(agent.position, agent.direction, )
local_rail_obs = None
# Add the visible cells to the observed cells
self.env.dev_obs_dict[handle] = set(visited)
# Locate observed agents and their coresponding targets
obs_map_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 2))
obs_other_agents_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 4))
for i in range(len(agents)):
temp_agent = agents[i]
if temp_agent.target in visited:
location = np.where(visited == temp_agent.target)
print("I see my target", location, handle)
direction = self._get_one_hot_for_agent_direction(agent)
local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
_idx = 0
for pos in visited:
curr_rel_coord = rel_coords[_idx]
local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
if pos == agent.target:
obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
else:
for tmp_agent in agents:
if pos == tmp_agent.target:
obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
if pos != agent.position:
for tmp_agent in agents:
if pos == tmp_agent.position:
obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
tmp_agent.direction]
_idx += 1
direction = np.identity(4)[agent.direction]
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
def get_many(self, handles=None):
......@@ -796,26 +805,26 @@ class LocalObsForRailEnv(ObservationBuilder):
if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
visible.append((origin[0] - h, origin[1] + w))
rel_coords.append((h, w))
if data_collection:
temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
elif direction == 1:
if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
visible.append((origin[0] + w, origin[1] + h))
rel_coords.append((h, w))
if data_collection:
temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
elif direction == 2:
if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
visible.append((origin[0] + h, origin[1] - w))
rel_coords.append((h, w))
if data_collection:
temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
else:
if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
visible.append((origin[0] - w, origin[1] - h))
rel_coords.append((h, w))
if data_collection:
temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
if data_collection:
return temp_visible_data
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment