Skip to content
Snippets Groups Projects
Commit 1a9c8c28 authored by gmollard's avatar gmollard
Browse files

added global obs direction dependent

parent 238b9fd5
No related branches found
No related tags found
No related merge requests found
......@@ -483,13 +483,12 @@ class GlobalObsForRailEnv(ObservationBuilder):
- transition map array with dimensions (env.height, env.width, 16),
assuming 16 bits encoding of transitions.
- Three 2D arrays (map_height, map_width, 3) containing respectively the position of the given agent,
the position of its target and the positions of the other agents targets.
- Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent
target and the positions of the other agents targets.
- A 3D array (map_height, map_width, 4) containing the one hot encoding of directions
- A 3D array (map_height, map_width, 8) with the 4 first channels containing the one hot encoding
of the direction of the given agent and the 4 second channels containing the positions
of the other agents at their position coordinates.
- A 4-element array with the one-hot encoding of the direction of the agent of interest.
"""
def __init__(self):
......@@ -516,30 +515,100 @@ class GlobalObsForRailEnv(ObservationBuilder):
# self.targets[target_pos] += 1
def get(self, handle):
obs_map_state = np.zeros((self.env.height, self.env.width, 3))
obs_other_agents_state = np.zeros((self.env.height, self.env.width, 4))
obs_targets = np.zeros((self.env.height, self.env.width, 2))
obs_agents_state = np.zeros((self.env.height, self.env.width, 8))
agents = self.env.agents
agent = agents[handle]
direction = np.zeros(4)
direction[agent.direction] = 1
agent_pos = agents[handle].position
obs_map_state[agent_pos][0] += 1
obs_map_state[agent.target][1] += 1
obs_agents_state[agent_pos][:4] = direction
obs_targets[agent.target][0] += 1
for i in range(len(agents)):
if i != handle: # TODO: handle used as index...?
agent2 = agents[i]
obs_other_agents_state[agent2.position][agent2.direction] = 1
obs_map_state[agent2.target][2] += 1
obs_agents_state[agent2.position][4 + agent2.direction] = 1
obs_targets[agent2.target][1] += 1
direction = np.zeros(4)
direction[agent.direction] = 1
return self.rail_obs, obs_agents_state, obs_targets
class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
"""
Gives a global observation of the entire rail environment.
The observation is composed of the following elements:
- transition map array with dimensions (env.height, env.width, 16),
assuming 16 bits encoding of transitions, flipped in the direction of the agent
(the agent is always heading north on the flipped view).
- Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent
target and the positions of the other agents targets, also flipped depending on the agent's direction.
return self.rail_obs, obs_map_state, obs_other_agents_state, direction
- A 3D array (map_height, map_width, 5) whose first channel contains the position of the given agent and
whose last four channels contain the one hot encoding of the direction of the other agents (relative to the direction of the given agent) at their position coordinates.
"""
def __init__(self):
    """
    Create the builder with an empty observation space.

    The observation space is only a placeholder here; it is overwritten
    by `_set_env` once the environment dimensions are known.
    """
    self.observation_space = ()
    super().__init__()
def _set_env(self, env):
    """
    Attach the environment and derive the observation space from its size.

    The parent implementation is called first; `self.env` is then read to
    obtain the map height and width.
    """
    super()._set_env(env)
    height, width = self.env.height, self.env.width
    # NOTE(review): get() returns three arrays with 16, 5 and 2 channels,
    # which does not obviously match a leading dimension of 4 -- TODO confirm
    # what consumers of observation_space expect.
    self.observation_space = [4, height, width]
def reset(self):
    """
    Rebuild the cached transition map of the environment.

    For every cell (row, col) the integer returned by
    `self.env.rail.get_transitions((row, col))` is expanded into its
    16 binary digits, most significant bit first, and stored in
    `self.rail_obs[row, col]`. The result has shape
    (env.height, env.width, 16).
    """
    height, width = self.env.height, self.env.width
    self.rail_obs = np.zeros((height, width, 16))
    for row in range(height):
        for col in range(width):
            transitions = self.env.rail.get_transitions((row, col))
            # Fixed-width binary formatting zero-pads to the 16 channels
            # directly, instead of slicing bin() output and padding by hand.
            self.rail_obs[row, col] = [int(bit) for bit in f"{int(transitions):016b}"]
def get(self, handle):
    """
    Build the direction-dependent global observation for one agent.

    All three returned maps are expressed in the same flipped view, chosen
    from the direction of the agent identified by `handle`:

    - direction 0: no flip,
    - direction 1: columns mirrored (axis 1),
    - direction 2: rows and columns mirrored (axes 0 and 1),
    - direction 3: rows mirrored (axis 0).

    Only the spatial axes are flipped; the channel axis keeps its layout.
    The agent and target maps are flipped together with the rail map so
    that the same cell has the same coordinates in every returned array.

    Parameters
    ----------
    handle : int
        Index of the agent of interest in `self.env.agents`.

    Returns
    -------
    tuple of three arrays
        - rail_obs, shape (height, width, 16): `self.rail_obs` with its
          channels rotated by 4 * direction positions.
        - obs_agents_state, shape (height, width, 5): channel 0 marks the
          position of the given agent; channels 1-4 mark the other agents
          at their positions, one channel per direction relative to the
          direction of the given agent.
        - obs_targets, shape (height, width, 2): channel 0 counts the
          target of the given agent, channel 1 counts the targets of the
          other agents.
    """
    height, width = self.env.height, self.env.width
    obs_targets = np.zeros((height, width, 2))
    obs_agents_state = np.zeros((height, width, 5))

    agents = self.env.agents
    agent = agents[handle]
    direction = agent.direction

    # Rotate the 16 transition channels by four positions per direction step.
    channel_order = (np.arange(16) + 4 * direction) % 16
    rail_obs = self.rail_obs[:, :, channel_order]

    obs_agents_state[agent.position][0] = 1
    obs_targets[agent.target][0] += 1

    for other_handle, other in enumerate(agents):
        if other_handle == handle:  # TODO: handle used as index...?
            continue
        # Direction of the other agent relative to the agent of interest.
        relative_direction = (other.direction - direction) % 4
        obs_agents_state[other.position][1 + relative_direction] = 1
        obs_targets[other.target][1] += 1

    # Spatial axes to mirror for this direction. Direction 2 lists both
    # spatial axes explicitly: np.flip without `axis` would also reverse
    # the channel axis and scramble the channel layout.
    if direction == 1:
        flip_axes = 1
    elif direction == 2:
        flip_axes = (0, 1)
    elif direction == 3:
        flip_axes = 0
    else:
        flip_axes = None

    if flip_axes is not None:
        # Flip every map the same way so they stay aligned cell by cell.
        rail_obs = np.flip(rail_obs, axis=flip_axes)
        obs_agents_state = np.flip(obs_agents_state, axis=flip_axes)
        obs_targets = np.flip(obs_targets, axis=flip_axes)

    return rail_obs, obs_agents_state, obs_targets
class LocalObsForRailEnv(ObservationBuilder):
"""
Gives a global observation of the entire rail environment.
Gives a local observation of the rail environment around the agent.
The observation is composed of the following elements:
- transition map array of the local environment around the given agent,
......@@ -620,3 +689,10 @@ class LocalObsForRailEnv(ObservationBuilder):
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
# class LocalObsForRailEnvImproved(ObservationBuilder):
# """
# Returns a local observation around the given agent
# """
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment