From ea2d5328743e7caf983500f99859e5c6e164d117 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sun, 1 Sep 2019 08:56:21 -0400 Subject: [PATCH] updated global observation to represent malfunctions and differential speeds of agents. local grid observations will not be updated. It was never used for the challenge and is not well suited for the task at hand. --- flatland/envs/observations.py | 93 +++++------------------------------ 1 file changed, 13 insertions(+), 80 deletions(-) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index dc0edc9b..d7f5ce84 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -623,9 +623,11 @@ class GlobalObsForRailEnv(ObservationBuilder): - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent target and the positions of the other agents targets. - - A 3D array (map_height, map_width, 8) with the 4 first channels containing the one hot encoding - of the direction of the given agent and the 4 second channels containing the positions - of the other agents at their position coordinates. + - A 3D array (map_height, map_width, 4) wtih + - first channel containing the agents position and direction + - second channel containing the other agents positions and diretions + - third channel containing agent malfunctions + - fourth channel containing agent fractional speeds """ def __init__(self): @@ -647,94 +649,25 @@ class GlobalObsForRailEnv(ObservationBuilder): def get(self, handle): obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 8)) + obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) agents = self.env.agents agent = agents[handle] direction = np.zeros(4) direction[agent.direction] = 1 agent_pos = agents[handle].position - obs_agents_state[agent_pos][:4] = direction - obs_targets[agent.target][0] += 1 + obs_agents_state[agent_pos][0] = direction + obs_targets[agent.target][0] = 1 for i in range(len(agents)): if i != handle: # TODO: handle used as index...? agent2 = agents[i] - obs_agents_state[agent2.position][4 + agent2.direction] = 1 - obs_targets[agent2.target][1] += 1 + obs_agents_state[agent2.position][1] = agent2.direction + obs_targets[agent2.target][1] = 1 + obs_agents_state[agent2.position][2] = agent2.malfunction_data['malfunction'] + obs_agents_state[agent2.position][3] = agent2.speed_data['speed'] - direction = self._get_one_hot_for_agent_direction(agent) - - return self.rail_obs, obs_agents_state, obs_targets, direction - - -class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): - """ - Gives a global observation of the entire rail environment. - The observation is composed of the following elements: - - - transition map array with dimensions (env.height, env.width, 16), - assuming 16 bits encoding of transitions, flipped in the direction of the agent - (the agent is always heading north on the flipped view). - - - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent - target and the positions of the other agents targets, also flipped depending on the agent's direction. - - - A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other - agents at their position coordinates, and the last channel containing the position of the given agent. - - - A 4 elements array with one hot encoding of the direction. - """ - - def __init__(self): - self.observation_space = () - super(GlobalObsForRailEnvDirectionDependent, self).__init__() - - def _set_env(self, env): - super()._set_env(env) - - self.observation_space = [4, self.env.height, self.env.width] - - def reset(self): - self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) - for i in range(self.rail_obs.shape[0]): - for j in range(self.rail_obs.shape[1]): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle): - obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - agents = self.env.agents - agent = agents[handle] - direction = agent.direction - - idx = np.tile(np.arange(16), 2) - - rail_obs = self.rail_obs[:, :, idx[direction * 4: direction * 4 + 16]] - - if direction == 1: - rail_obs = np.flip(rail_obs, axis=1) - elif direction == 2: - rail_obs = np.flip(rail_obs) - elif direction == 3: - rail_obs = np.flip(rail_obs, axis=0) - - agent_pos = agents[handle].position - obs_agents_state[agent_pos][0] = 1 - obs_targets[agent.target][0] += 1 - - idx = np.tile(np.arange(4), 2) - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1 - obs_targets[agent2.target][1] += 1 - - direction = self._get_one_hot_for_agent_direction(agent) - - return rail_obs, obs_agents_state, obs_targets, direction + return self.rail_obs, obs_agents_state, obs_targets class LocalObsForRailEnv(ObservationBuilder): -- GitLab