diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index c23d4345a03c761ad4c4ac1d936db817f8acc529..548c3e29a72aaa0390ab73f5550738f46f52c6c4 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -492,11 +492,12 @@ class GlobalObsForRailEnv(ObservationBuilder): - transition map array with dimensions (env.height, env.width, 16),\ assuming 16 bits encoding of transitions. - - A 3D array (map_height, map_width, 4) with + - A 3D array (map_height, map_width, 5) with - first channel containing the agents position and direction - second channel containing the other agents positions and diretion - third channel containing agent/other agent malfunctions - fourth channel containing agent/other agent fractional speeds + - fifth channel containing number of agents in cell (only larger than one at start position) - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ target and the positions of the other agents targets. @@ -519,14 +520,16 @@ class GlobalObsForRailEnv(ObservationBuilder): def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - 1 - + obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1 + obs_agents_state[0] -= 1 # Set all values to -1 to avoid confusion with orientation + obs_agents_state[1] -= 1 # Set all values to -1 to avoid confusion with orientation agent = self.env.agents[handle] obs_agents_state[agent.position][0] = agent.direction obs_targets[agent.target][0] = 1 for i in range(len(self.env.agents)): other_agent = self.env.agents[i] + obs_agents_state[other_agent.position][4] += 1 if i != handle: obs_agents_state[other_agent.position][1] = other_agent.direction obs_targets[other_agent.target][1] = 1