diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 3e8583d996635bf58552922707815984f8e76f85..2887610b06d5455bd7612d3c6f0ad386e086c6be 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -358,21 +358,26 @@ class TreeObsForRailEnv(ObservationBuilder): if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0): conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position) for ca in conflicting_agent[0]: - if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict: potential_conflict = tot_dist + if self.env.dones[ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist # Look for opposing paths at distance num_step-1 elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0): conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) for ca in conflicting_agent[0]: if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict: potential_conflict = tot_dist + if self.env.dones[ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist # Look for opposing paths at distance num_step+1 elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) for ca in conflicting_agent[0]: if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict: potential_conflict = tot_dist + if self.env.dones[ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist if position in self.location_has_target and position != agent.target: if tot_dist < other_target_encountered: