Skip to content
Snippets Groups Projects
Commit 7234b44a authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

fix convergence issue

parent e73b89ea
No related branches found
No related tags found
No related merge requests found
......@@ -383,8 +383,9 @@ class TreeObsForRailEnv(ObservationBuilder):
elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[pre_step][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[pre_step][ca])] == 1 and tot_dist < potential_conflict:
if direction != self.predicted_dir[pre_step][ca] \
and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
......@@ -394,7 +395,8 @@ class TreeObsForRailEnv(ObservationBuilder):
conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[post_step][ca])] == 1 and tot_dist < potential_conflict:
self.predicted_dir[post_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment