From 245c9cc28044bd2ee3f6140e70ad7b25544bb182 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Thu, 25 Jul 2019 15:29:07 -0400 Subject: [PATCH] At switches now conflicts are only detected if agent has option to choose head on colliding path --- flatland/core/env_observation_builder.py | 1 + flatland/envs/observations.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 060785f5..4acdf16f 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -74,6 +74,7 @@ class ObservationBuilder: direction[agent.direction] = 1 return direction + class DummyObservationBuilder(ObservationBuilder): """ DummyObservationBuilder class which returns dummy observations diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 498d98ee..8b0c94f1 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -270,6 +270,7 @@ class TreeObsForRailEnv(ObservationBuilder): observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0] visited = set() + # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation # If only one transition is possible, the tree is oriented with this transition as the forward branch. @@ -289,6 +290,7 @@ class TreeObsForRailEnv(ObservationBuilder): # add cells filled with infinity if no transition is possible observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) self.env.dev_obs_dict[handle] = visited + return observation def _num_cells_to_fill_in(self, remaining_depth): @@ -362,11 +364,12 @@ class TreeObsForRailEnv(ObservationBuilder): pre_step = max(0, tot_dist - 1) post_step = min(self.max_prediction_depth - 1, tot_dist + 1) - # Look for conflicting paths at distance num_step + # Look for conflicting paths at distance tot_dist if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0): conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position) for ca in conflicting_agent[0]: - if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict: + if direction != self.predicted_dir[tot_dist][ca] and cell_transitions[self._reverse_dir( + self.predicted_dir[tot_dist][ca])] == 1 and tot_dist < potential_conflict: potential_conflict = tot_dist if self.env.dones[ca] and tot_dist < potential_conflict: potential_conflict = tot_dist @@ -375,7 +378,8 @@ class TreeObsForRailEnv(ObservationBuilder): elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0): conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) for ca in conflicting_agent[0]: - if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict: + if direction != self.predicted_dir[pre_step][ca] and cell_transitions[self._reverse_dir( + self.predicted_dir[pre_step][ca])] == 1 and tot_dist < potential_conflict: potential_conflict = tot_dist if self.env.dones[ca] and tot_dist < potential_conflict: potential_conflict = tot_dist @@ -384,7 +388,8 @@ class TreeObsForRailEnv(ObservationBuilder): elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) for ca in conflicting_agent[0]: - if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict: + if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir( + self.predicted_dir[post_step][ca])] == 1 and tot_dist < potential_conflict: potential_conflict = tot_dist if self.env.dones[ca] and tot_dist < potential_conflict: potential_conflict = tot_dist @@ -566,6 +571,9 @@ class TreeObsForRailEnv(ObservationBuilder): if self.predictor: self.predictor._set_env(self.env) + def _reverse_dir(self, direction): + return int((direction + 2) % 4) + class GlobalObsForRailEnv(ObservationBuilder): """ -- GitLab