diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index dc0edc9b4960749a868635edb8df95a5608b629c..d7f5ce84eed1f94a6a35d4ca7a1657665ba1b347 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -623,9 +623,11 @@ class GlobalObsForRailEnv(ObservationBuilder): - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent target and the positions of the other agents targets. - - A 3D array (map_height, map_width, 8) with the 4 first channels containing the one hot encoding - of the direction of the given agent and the 4 second channels containing the positions - of the other agents at their position coordinates. + - A 3D array (map_height, map_width, 4) wtih + - first channel containing the agents position and direction + - second channel containing the other agents positions and diretions + - third channel containing agent malfunctions + - fourth channel containing agent fractional speeds """ def __init__(self): @@ -647,94 +649,25 @@ class GlobalObsForRailEnv(ObservationBuilder): def get(self, handle): obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 8)) + obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) agents = self.env.agents agent = agents[handle] direction = np.zeros(4) direction[agent.direction] = 1 agent_pos = agents[handle].position - obs_agents_state[agent_pos][:4] = direction - obs_targets[agent.target][0] += 1 + obs_agents_state[agent_pos][0] = direction + obs_targets[agent.target][0] = 1 for i in range(len(agents)): if i != handle: # TODO: handle used as index...? agent2 = agents[i] - obs_agents_state[agent2.position][4 + agent2.direction] = 1 - obs_targets[agent2.target][1] += 1 + obs_agents_state[agent2.position][1] = agent2.direction + obs_targets[agent2.target][1] = 1 + obs_agents_state[agent2.position][2] = agent2.malfunction_data['malfunction'] + obs_agents_state[agent2.position][3] = agent2.speed_data['speed'] - direction = self._get_one_hot_for_agent_direction(agent) - - return self.rail_obs, obs_agents_state, obs_targets, direction - - -class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): - """ - Gives a global observation of the entire rail environment. - The observation is composed of the following elements: - - - transition map array with dimensions (env.height, env.width, 16), - assuming 16 bits encoding of transitions, flipped in the direction of the agent - (the agent is always heading north on the flipped view). - - - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent - target and the positions of the other agents targets, also flipped depending on the agent's direction. - - - A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other - agents at their position coordinates, and the last channel containing the position of the given agent. - - - A 4 elements array with one hot encoding of the direction. - """ - - def __init__(self): - self.observation_space = () - super(GlobalObsForRailEnvDirectionDependent, self).__init__() - - def _set_env(self, env): - super()._set_env(env) - - self.observation_space = [4, self.env.height, self.env.width] - - def reset(self): - self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) - for i in range(self.rail_obs.shape[0]): - for j in range(self.rail_obs.shape[1]): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle): - obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - agents = self.env.agents - agent = agents[handle] - direction = agent.direction - - idx = np.tile(np.arange(16), 2) - - rail_obs = self.rail_obs[:, :, idx[direction * 4: direction * 4 + 16]] - - if direction == 1: - rail_obs = np.flip(rail_obs, axis=1) - elif direction == 2: - rail_obs = np.flip(rail_obs) - elif direction == 3: - rail_obs = np.flip(rail_obs, axis=0) - - agent_pos = agents[handle].position - obs_agents_state[agent_pos][0] = 1 - obs_targets[agent.target][0] += 1 - - idx = np.tile(np.arange(4), 2) - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1 - obs_targets[agent2.target][1] += 1 - - direction = self._get_one_hot_for_agent_direction(agent) - - return rail_obs, obs_agents_state, obs_targets, direction + return self.rail_obs, obs_agents_state, obs_targets class LocalObsForRailEnv(ObservationBuilder):