diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 651d83520ee1f492ea71ba6ddf82cfa5f9093964..fa77b58a53ef63767c2690edc782b6699f6c8459 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -492,8 +492,11 @@ class GlobalObsForRailEnv(ObservationBuilder): self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) for i in range(self.rail_obs.shape[0]): for j in range(self.rail_obs.shape[1]): - self.rail_obs[i, j] = np.array( - list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int) + bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]] + bitlist = [0] * (16 - len(bitlist)) + bitlist + self.rail_obs[i, j] = np.array(bitlist) + # self.rail_obs[i, j] = np.array( + # list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int) # self.targets = np.zeros(self.env.height, self.env.width) # for target_pos in self.env.agents_target: