From 1a9c8c2856d792e4533a75c70c821ff21d0ea3a7 Mon Sep 17 00:00:00 2001 From: Guillaume Mollard <guillaume.mollard2@gmail.com> Date: Wed, 29 May 2019 10:18:31 +0200 Subject: [PATCH] added global obs direction dependent --- flatland/envs/observations.py | 106 +++++++++++++++++++++++++++++----- 1 file changed, 91 insertions(+), 15 deletions(-) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 96d91579..f5f70aa1 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -483,13 +483,12 @@ class GlobalObsForRailEnv(ObservationBuilder): - transition map array with dimensions (env.height, env.width, 16), assuming 16 bits encoding of transitions. - - Three 2D arrays (map_height, map_width, 3) containing respectively the position of the given agent, - the position of its target and the positions of the other agents targets. + - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent's + target and the positions of the other agents' targets. - - A 3D array (map_height, map_width, 4) containing the one hot encoding of directions + - A 3D array (map_height, map_width, 8) with the first 4 channels containing the one hot encoding + of the direction of the given agent and the last 4 channels containing the one hot encoding of the directions of the other agents at their position coordinates. - - - A 4 elements array with one of encoding of the direction of the agent of interest. 
""" def __init__(self): @@ -516,30 +515,100 @@ class GlobalObsForRailEnv(ObservationBuilder): # self.targets[target_pos] += 1 def get(self, handle): - obs_map_state = np.zeros((self.env.height, self.env.width, 3)) - obs_other_agents_state = np.zeros((self.env.height, self.env.width, 4)) + obs_targets = np.zeros((self.env.height, self.env.width, 2)) + obs_agents_state = np.zeros((self.env.height, self.env.width, 8)) agents = self.env.agents agent = agents[handle] + direction = np.zeros(4) + direction[agent.direction] = 1 agent_pos = agents[handle].position - obs_map_state[agent_pos][0] += 1 - obs_map_state[agent.target][1] += 1 + obs_agents_state[agent_pos][:4] = direction + obs_targets[agent.target][0] += 1 for i in range(len(agents)): if i != handle: # TODO: handle used as index...? agent2 = agents[i] - obs_other_agents_state[agent2.position][agent2.direction] = 1 - obs_map_state[agent2.target][2] += 1 + obs_agents_state[agent2.position][4 + agent2.direction] = 1 + obs_targets[agent2.target][1] += 1 - direction = np.zeros(4) - direction[agent.direction] = 1 + return self.rail_obs, obs_agents_state, obs_targets + + +class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): + """ + Gives a global observation of the entire rail environment. + The observation is composed of the following elements: + + - transition map array with dimensions (env.height, env.width, 16), + assuming 16 bits encoding of transitions, flipped in the direction of the agent + (the agent is always heading north on the flipped view). + + - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent's + target and the positions of the other agents' targets, also flipped depending on the agent's direction. 
- return self.rail_obs, obs_map_state, obs_other_agents_state, direction + - A 3D array (map_height, map_width, 5) with the first channel containing the position of the given agent + and the last 4 channels containing the one hot encoding of the direction of the other agents, relative to the direction of the given agent, at their position coordinates. + """ + + def __init__(self): + self.observation_space = () + super(GlobalObsForRailEnvDirectionDependent, self).__init__() + + def _set_env(self, env): + super()._set_env(env) + + self.observation_space = [4, self.env.height, self.env.width] + + def reset(self): + self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) + for i in range(self.rail_obs.shape[0]): + for j in range(self.rail_obs.shape[1]): + bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]] + bitlist = [0] * (16 - len(bitlist)) + bitlist + self.rail_obs[i, j] = np.array(bitlist) + # self.rail_obs[i, j] = np.array( + # list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int) + + # self.targets = np.zeros(self.env.height, self.env.width) + # for target_pos in self.env.agents_target: + # self.targets[target_pos] += 1 + + def get(self, handle): + obs_targets = np.zeros((self.env.height, self.env.width, 2)) + obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) + agents = self.env.agents + agent = agents[handle] + direction = agent.direction + + idx = np.tile(np.arange(16), 2) + + rail_obs = self.rail_obs[:, :, idx[direction * 4: direction * 4 + 16]] + + if direction == 1: + rail_obs = np.flip(rail_obs, axis=1) + elif direction == 2: + rail_obs = np.flip(rail_obs) + elif direction == 3: + rail_obs = np.flip(rail_obs, axis=0) + + agent_pos = agents[handle].position + obs_agents_state[agent_pos][0] = 1 + obs_targets[agent.target][0] += 1 + + idx = np.tile(np.arange(4), 2) + for i in range(len(agents)): + if i != handle: # TODO: handle used as index...? 
+ agent2 = agents[i] + obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1 + obs_targets[agent2.target][1] += 1 + + return rail_obs, obs_agents_state, obs_targets class LocalObsForRailEnv(ObservationBuilder): """ - Gives a global observation of the entire rail environment. + Gives a local observation of the rail environment around the agent. The observation is composed of the following elements: - transition map array of the local environment around the given agent, @@ -620,3 +689,10 @@ class LocalObsForRailEnv(ObservationBuilder): return local_rail_obs, obs_map_state, obs_other_agents_state, direction + +# class LocalObsForRailEnvImproved(ObservationBuilder): +# """ +# Returns a local observation around the given agent +# """ + + -- GitLab