diff --git a/examples/training_example.py b/examples/training_example.py index 218f133efe4e79064ae94b24c61aa2d6f2ce2d09..ad8396a57c705dd99ab679c0cb76b688c9e99c75 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -16,7 +16,7 @@ env = RailEnv(width=20, height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), obs_builder_object=TreeObservation, - number_of_agents=2) + number_of_agents=1) # Import your own Agent or use RLlib to train agents on Flatland diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 439d538bb431a11a6cbfb3bfbe4b7d3adc93c984..b4925024c7717e7673b52a593f9a735fc9f001e1 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -331,17 +331,21 @@ class TreeObsForRailEnv(ObservationBuilder): # Register possible conflict if self.predictor and num_steps < self.max_prediction_depth: - if coordinate_to_position(self.env.width, [position]) in np.delete(self.predicted_pos[num_steps], - handle): - potential_conflict = 1 - if coordinate_to_position(self.env.width, [position]) in np.delete( - self.predicted_pos[max(0, num_steps - 1)], - handle): - potential_conflict = 1 - if coordinate_to_position(self.env.width, [position]) in np.delete( - self.predicted_pos[min(self.max_prediction_depth - 1, num_steps + 1)], - handle): - potential_conflict = 1 + int_position = coordinate_to_position(self.env.width, [position]) + pre_step = max(0, num_steps - 1) + post_step = min(self.max_prediction_depth - 1, num_steps + 1) + if int_position in np.delete(self.predicted_pos[num_steps], handle): + conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position)[0][0] + if direction != self.predicted_dir[num_steps][conflicting_agent]: + potential_conflict = 1 + elif int_position in np.delete(self.predicted_pos[pre_step], handle): + conflicting_agent = np.where(np.delete(self.predicted_pos[pre_step], handle) == int_position)[0][0] + if direction != self.predicted_dir[pre_step][conflicting_agent]: + potential_conflict = 1 + elif int_position in np.delete(self.predicted_pos[post_step], handle): + conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)[0][0] + if direction != self.predicted_dir[post_step][conflicting_agent]: + potential_conflict = 1 if position in self.location_has_target and position != agent.target: if num_steps < other_target_encountered: diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 366acfebeb8c5ba4f56138c9bab0bcbb904a7b11..f5f559a1fb83f71ac9ca2ac514bc3d9f6a4b024f 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -135,7 +135,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): if target_dist < min_dist: min_dist = target_dist new_direction = direct - new_position = self._new_position(agent.position, new_direction) + new_position = self._new_position(agent.position, new_direction) agent.position = new_position agent.direction = new_direction