diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 3e8583d996635bf58552922707815984f8e76f85..2887610b06d5455bd7612d3c6f0ad386e086c6be 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -358,21 +358,26 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
                     conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
                     for ca in conflicting_agent[0]:
-                        if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict: potential_conflict = tot_dist
+                        if self.env.dones[ca] and tot_dist < potential_conflict:
+                            potential_conflict = tot_dist

                 # Look for opposing paths at distance num_step-1
                 elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                     conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                     for ca in conflicting_agent[0]:
                         if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict:
                             potential_conflict = tot_dist
+                        if self.env.dones[ca] and tot_dist < potential_conflict:
+                            potential_conflict = tot_dist

                 # Look for opposing paths at distance num_step+1
                 elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                     conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                     for ca in conflicting_agent[0]:
                         if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict:
                             potential_conflict = tot_dist
+                        if self.env.dones[ca] and tot_dist < potential_conflict:
+                            potential_conflict = tot_dist

             if position in self.location_has_target and position != agent.target:
                 if tot_dist < other_target_encountered:
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 2af0d8b962fc5b31e2e3c7607c06b8613c15a47a..abe623ae173a593e265cff7d4d88eb323e16b08e 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -334,6 +334,7 @@ class RailEnv(Environment):
                 if np.equal(agent.position, agent.target).all():
                     self.dones[i_agent] = True
+                    agent.moving = False
                 else:
                     self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
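For reviewers, below is a minimal standalone sketch of the conflict check that the observations.py hunks touch, with the new done-agent condition included. It is illustration only, not part of the patch: the function name earliest_potential_conflict, its parameters, and the tiny example inputs are assumptions made for the example; in flatland the equivalent logic lives inside TreeObsForRailEnv and reads self.predicted_pos, self.predicted_dir and self.env.dones.

import numpy as np

def earliest_potential_conflict(handle, predicted_pos, predicted_dir, dones,
                                own_positions, own_directions):
    """Return the earliest prediction step at which another agent is expected
    to occupy the same cell as agent `handle` while heading in a different
    direction, or while already done (parked on its target)."""
    potential_conflict = np.inf
    for tot_dist, (int_position, direction) in enumerate(zip(own_positions, own_directions)):
        # other agents predicted to sit on the same cell at this step
        if int_position in np.delete(predicted_pos[tot_dist], handle, 0):
            conflicting_agent = np.where(predicted_pos[tot_dist] == int_position)
            for ca in conflicting_agent[0]:
                if ca == handle:
                    continue
                # opposing direction on the same cell
                if direction != predicted_dir[tot_dist][ca] and tot_dist < potential_conflict:
                    potential_conflict = tot_dist
                # the condition added by this patch: a done agent blocks the cell
                if dones[ca] and tot_dist < potential_conflict:
                    potential_conflict = tot_dist
    return potential_conflict

if __name__ == "__main__":
    # two agents, three prediction steps; agent 1 is already done on cell 7
    predicted_pos = np.array([[5, 7], [6, 7], [7, 7]])
    predicted_dir = np.array([[1, 1], [1, 1], [1, 1]])
    dones = [False, True]
    print(earliest_potential_conflict(0, predicted_pos, predicted_dir, dones,
                                      own_positions=[5, 6, 7], own_directions=[1, 1, 1]))
    # -> 2: picked up only by the new done-agent check, since both directions are equal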