Commit ffa48c3b authored by Erik Nygren's avatar Erik Nygren
Browse files

removed bug where agent would see its own prediction as conflict

parent 46a42274
Pipeline #1091 failed with stages
in 9 minutes and 33 seconds
......@@ -16,7 +16,7 @@ env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
obs_builder_object=TreeObservation,
number_of_agents=1)
number_of_agents=2)
# Import your own Agent or use RLlib to train agents on Flatland
......
......@@ -334,18 +334,24 @@ class TreeObsForRailEnv(ObservationBuilder):
int_position = coordinate_to_position(self.env.width, [position])
pre_step = max(0, num_steps - 1)
post_step = min(self.max_prediction_depth - 1, num_steps + 1)
# Look for opposing paths at distance num_step
if int_position in np.delete(self.predicted_pos[num_steps], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position)[0][0]
if direction != self.predicted_dir[num_steps][conflicting_agent]:
potential_conflict = 1
conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[num_steps][ca[0]]:
potential_conflict = 1
# Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[pre_step], handle) == int_position)[0][0]
if direction != self.predicted_dir[pre_step][conflicting_agent]:
potential_conflict = 1
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca[0]]:
potential_conflict = 1
# Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)[0][0]
if direction != self.predicted_dir[post_step][conflicting_agent]:
potential_conflict = 1
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca[0]]:
potential_conflict = 1
if position in self.location_has_target and position != agent.target:
if num_steps < other_target_encountered:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment