diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index b5d9c2e0ba6db708c394aa99ca73f1e42bb4cb43..8d9ed261acd61578de11ac6c348903be8b2cd42d 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -178,6 +178,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.predicted_pos = {} self.predicted_dir = {} self.predictions = self.predictor.get(self.distance_map) + print(self.predictions) for t in range(len(self.predictions[0])): pos_list = [] dir_list = [] @@ -334,6 +335,14 @@ class TreeObsForRailEnv(ObservationBuilder): if coordinate_to_position(self.env.width, [position]) in np.delete(self.predicted_pos[num_steps], handle): potential_conflict = 1 + if coordinate_to_position(self.env.width, [position]) in np.delete( + self.predicted_pos[max(0, num_steps - 1)], + handle): + potential_conflict = 1 + if coordinate_to_position(self.env.width, [position]) in np.delete( + self.predicted_pos[min(self.max_prediction_depth - 1, num_steps + 1)], + handle): + potential_conflict = 1 if position in self.location_has_target and position != agent.target: if num_steps < other_target_encountered: