Skip to content
Snippets Groups Projects
Commit 51c747b6 authored by gmollard's avatar gmollard
Browse files

added direction of other agents in global observation

parent 72edfe4b
No related branches found
No related tags found
No related merge requests found
...@@ -203,7 +203,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -203,7 +203,7 @@ class TreeObsForRailEnv(ObservationBuilder):
num_transitions = np.count_nonzero(possible_transitions) num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position # Root node - current position
# observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]] observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position,direc agent.direction)]]
root_observation = observation[:] root_observation = observation[:]
visited = set() visited = set()
# Start from the current orientation, and see which transitions are available; # Start from the current orientation, and see which transitions are available;
...@@ -478,11 +478,15 @@ class GlobalObsForRailEnv(ObservationBuilder): ...@@ -478,11 +478,15 @@ class GlobalObsForRailEnv(ObservationBuilder):
- transition map array with dimensions (env.height, env.width, 16), - transition map array with dimensions (env.height, env.width, 16),
assuming 16 bits encoding of transitions. assuming 16 bits encoding of transitions.
- Four 2D arrays containing respectively the position of the given agent, - Three 2D arrays (map_height, map_width, 3) containing respectively the position of the given agent,
the position of its target, the positions of the other agents and of the position of its target and the positions of the other agents targets.
their target.
- A 3D array (map_height, map_width, 4) containing the one hot encoding of directions
of the other agents at their position coordinates.
- A 4 elements array with one of encoding of the direction of the agent of interest.
- A 4 elements array with one of encoding of the direction.
""" """
def __init__(self): def __init__(self):
...@@ -503,21 +507,22 @@ class GlobalObsForRailEnv(ObservationBuilder): ...@@ -503,21 +507,22 @@ class GlobalObsForRailEnv(ObservationBuilder):
# self.targets[target_pos] += 1 # self.targets[target_pos] += 1
def get(self, handle): def get(self, handle):
obs = np.zeros((4, self.env.height, self.env.width)) obs_map_state = np.zeros((self.env.height, self.env.width, 3))
obs_other_agents_state = np.zeros((self.env.height, self.env.width, 4))
agents = self.env.agents agents = self.env.agents
agent = agents[handle] agent = agents[handle]
agent_pos = agents[handle].position agent_pos = agents[handle].position
obs[0][agent_pos] += 1 obs_map_state[agent_pos][0] += 1
obs[1][agent.target] += 1 obs_map_state[agent.target][1] += 1
for i in range(len(agents)): for i in range(len(agents)):
if i != handle: # TODO: handle used as index...? if i != handle: # TODO: handle used as index...?
agent2 = agents[i] agent2 = agents[i]
obs[3][agent2.position] += 1 obs_other_agents_state[agent2.position][agent2.direction] = 1
obs[2][agent2.target] += 1 obs_map_state[agent2.target][2] += 1
direction = np.zeros(4) direction = np.zeros(4)
direction[agent.direction] = 1 direction[agent.direction] = 1
return self.rail_obs, obs, direction return self.rail_obs, obs_map_state, obs_other_agents_state, direction
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment