From ffa48c3b035181a6ca565ac261b4e83a58502112 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Fri, 14 Jun 2019 01:14:02 +0200 Subject: [PATCH] removed bug where agent would see its own prediction as conflict --- examples/training_example.py | 2 +- flatland/envs/observations.py | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/examples/training_example.py b/examples/training_example.py index ad8396a..218f133 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -16,7 +16,7 @@ env = RailEnv(width=20, height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), obs_builder_object=TreeObservation, - number_of_agents=1) + number_of_agents=2) # Import your own Agent or use RLlib to train agents on Flatland diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index b492502..931fabc 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -334,18 +334,24 @@ class TreeObsForRailEnv(ObservationBuilder): int_position = coordinate_to_position(self.env.width, [position]) pre_step = max(0, num_steps - 1) post_step = min(self.max_prediction_depth - 1, num_steps + 1) + # Look for opposing paths at distance num_step if int_position in np.delete(self.predicted_pos[num_steps], handle): - conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position)[0][0] - if direction != self.predicted_dir[num_steps][conflicting_agent]: - potential_conflict = 1 + conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position) + for ca in conflicting_agent: + if direction != self.predicted_dir[num_steps][ca[0]]: + potential_conflict = 1 + # Look for opposing paths at distance num_step-1 elif int_position in np.delete(self.predicted_pos[pre_step], handle): - conflicting_agent = np.where(np.delete(self.predicted_pos[pre_step], handle) == int_position)[0][0] - if direction != self.predicted_dir[pre_step][conflicting_agent]: - potential_conflict = 1 + conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) + for ca in conflicting_agent: + if direction != self.predicted_dir[pre_step][ca[0]]: + potential_conflict = 1 + # Look for opposing paths at distance num_step+1 elif int_position in np.delete(self.predicted_pos[post_step], handle): - conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)[0][0] - if direction != self.predicted_dir[post_step][conflicting_agent]: - potential_conflict = 1 + conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position) + for ca in conflicting_agent: + if direction != self.predicted_dir[post_step][ca[0]]: + potential_conflict = 1 if position in self.location_has_target and position != agent.target: if num_steps < other_target_encountered: -- GitLab