From 34a3f75b2df6645e3b9625bdb6db41fc1f82a567 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Fri, 14 Jun 2019 00:20:58 +0200 Subject: [PATCH] simple implementation of broader conflict detection window --- flatland/envs/observations.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index b5d9c2e..8d9ed26 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -178,6 +178,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.predicted_pos = {} self.predicted_dir = {} self.predictions = self.predictor.get(self.distance_map) + print(self.predictions) for t in range(len(self.predictions[0])): pos_list = [] dir_list = [] @@ -334,6 +335,14 @@ class TreeObsForRailEnv(ObservationBuilder): if coordinate_to_position(self.env.width, [position]) in np.delete(self.predicted_pos[num_steps], handle): potential_conflict = 1 + if coordinate_to_position(self.env.width, [position]) in np.delete( + self.predicted_pos[max(0, num_steps - 1)], + handle): + potential_conflict = 1 + if coordinate_to_position(self.env.width, [position]) in np.delete( + self.predicted_pos[min(self.max_prediction_depth - 1, num_steps + 1)], + handle): + potential_conflict = 1 if position in self.location_has_target and position != agent.target: if num_steps < other_target_encountered: -- GitLab