Commit 46a42274 authored by Erik Nygren's avatar Erik Nygren
Browse files

added direction to the conflict detection

parent 28a67a9d
Pipeline #1090 failed with stages
in 9 minutes and 17 seconds
......@@ -16,7 +16,7 @@ env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
obs_builder_object=TreeObservation,
number_of_agents=2)
number_of_agents=1)
# Import your own Agent or use RLlib to train agents on Flatland
......
......@@ -331,17 +331,21 @@ class TreeObsForRailEnv(ObservationBuilder):
# Register possible conflict
if self.predictor and num_steps < self.max_prediction_depth:
if coordinate_to_position(self.env.width, [position]) in np.delete(self.predicted_pos[num_steps],
handle):
potential_conflict = 1
if coordinate_to_position(self.env.width, [position]) in np.delete(
self.predicted_pos[max(0, num_steps - 1)],
handle):
potential_conflict = 1
if coordinate_to_position(self.env.width, [position]) in np.delete(
self.predicted_pos[min(self.max_prediction_depth - 1, num_steps + 1)],
handle):
potential_conflict = 1
int_position = coordinate_to_position(self.env.width, [position])
pre_step = max(0, num_steps - 1)
post_step = min(self.max_prediction_depth - 1, num_steps + 1)
if int_position in np.delete(self.predicted_pos[num_steps], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position)[0][0]
if direction != self.predicted_dir[num_steps][conflicting_agent]:
potential_conflict = 1
elif int_position in np.delete(self.predicted_pos[pre_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[pre_step], handle) == int_position)[0][0]
if direction != self.predicted_dir[pre_step][conflicting_agent]:
potential_conflict = 1
elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)[0][0]
if direction != self.predicted_dir[post_step][conflicting_agent]:
potential_conflict = 1
if position in self.location_has_target and position != agent.target:
if num_steps < other_target_encountered:
......
......@@ -135,7 +135,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
if target_dist < min_dist:
min_dist = target_dist
new_direction = direct
new_position = self._new_position(agent.position, new_direction)
new_position = self._new_position(agent.position, new_direction)
agent.position = new_position
agent.direction = new_direction
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment