From ff2ea11da7abccddd42b5dd7920e7875dc3304af Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sat, 15 Jun 2019 13:06:26 +0200 Subject: [PATCH] updated how distance from agent is measured to detect conflicts! --- flatland/envs/observations.py | 55 +++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 931fabc7..007f6939 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -273,7 +273,7 @@ class TreeObsForRailEnv(ObservationBuilder): if possible_transitions[branch_direction]: new_cell = self._new_position(agent.position, branch_direction) branch_observation, branch_visited = \ - self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) + self._explore_branch(handle, new_cell, branch_direction, root_observation, 0, 1) observation = observation + branch_observation visited = visited.union(branch_visited) else: @@ -286,7 +286,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.env.dev_obs_dict[handle] = visited return observation - def _explore_branch(self, handle, position, direction, root_observation, depth): + def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth): """ Utility function to compute tree-based observations. """ @@ -332,26 +332,31 @@ class TreeObsForRailEnv(ObservationBuilder): # Register possible conflict if self.predictor and num_steps < self.max_prediction_depth: int_position = coordinate_to_position(self.env.width, [position]) - pre_step = max(0, num_steps - 1) - post_step = min(self.max_prediction_depth - 1, num_steps + 1) - # Look for opposing paths at distance num_step - if int_position in np.delete(self.predicted_pos[num_steps], handle): - conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position) - for ca in conflicting_agent: - if direction != self.predicted_dir[num_steps][ca[0]]: - potential_conflict = 1 - # Look for opposing paths at distance num_step-1 - elif int_position in np.delete(self.predicted_pos[pre_step], handle): - conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) - for ca in conflicting_agent: - if direction != self.predicted_dir[pre_step][ca[0]]: - potential_conflict = 1 - # Look for opposing paths at distance num_step+1 - elif int_position in np.delete(self.predicted_pos[post_step], handle): - conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position) - for ca in conflicting_agent: - if direction != self.predicted_dir[post_step][ca[0]]: - potential_conflict = 1 + if tot_dist < self.max_prediction_depth: + pre_step = max(0, tot_dist - 1) + post_step = min(self.max_prediction_depth - 1, tot_dist + 1) + + # Look for opposing paths at distance num_step + if int_position in np.delete(self.predicted_pos[tot_dist], handle): + conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position) + for ca in conflicting_agent: + if direction != self.predicted_dir[tot_dist][ca[0]]: + potential_conflict = 1 + # print("Potential Conflict",position,handle,ca[0],tot_dist,depth) + # Look for opposing paths at distance num_step-1 + elif int_position in np.delete(self.predicted_pos[pre_step], handle): + conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) + for ca in conflicting_agent: + if direction != self.predicted_dir[pre_step][ca[0]]: + potential_conflict = 1 + # print("Potential Conflict", position,handle,ca[0],pre_step,depth) + # Look for opposing paths at distance num_step+1 + elif int_position in np.delete(self.predicted_pos[post_step], handle): + conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position) + for ca in conflicting_agent: + if direction != self.predicted_dir[post_step][ca[0]]: + potential_conflict = 1 + # print("Potential Conflict", position,handle,ca[0],post_step,depth) if position in self.location_has_target and position != agent.target: if num_steps < other_target_encountered: @@ -395,7 +400,7 @@ class TreeObsForRailEnv(ObservationBuilder): direction = np.argmax(cell_transitions) position = self._new_position(position, direction) num_steps += 1 - + tot_dist += 1 elif num_transitions > 0: # Switch detected last_isSwitch = True @@ -499,7 +504,7 @@ class TreeObsForRailEnv(ObservationBuilder): branch_observation, branch_visited = self._explore_branch(handle, new_cell, (branch_direction + 2) % 4, - new_root_observation, + new_root_observation, tot_dist + 1, depth + 1) observation = observation + branch_observation if len(branch_visited) != 0: @@ -509,7 +514,7 @@ class TreeObsForRailEnv(ObservationBuilder): branch_observation, branch_visited = self._explore_branch(handle, new_cell, branch_direction, - new_root_observation, + new_root_observation, tot_dist + 1, depth + 1) observation = observation + branch_observation if len(branch_visited) != 0: -- GitLab