Skip to content
Snippets Groups Projects
Commit ff2ea11d authored by Erik Nygren's avatar Erik Nygren
Browse files

updated how distance from agent is measured to detect conflicts!

parent 920032da
No related branches found
No related tags found
No related merge requests found
...@@ -273,7 +273,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -273,7 +273,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if possible_transitions[branch_direction]: if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction) new_cell = self._new_position(agent.position, branch_direction)
branch_observation, branch_visited = \ branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) self._explore_branch(handle, new_cell, branch_direction, root_observation, 0, 1)
observation = observation + branch_observation observation = observation + branch_observation
visited = visited.union(branch_visited) visited = visited.union(branch_visited)
else: else:
...@@ -286,7 +286,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -286,7 +286,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.env.dev_obs_dict[handle] = visited self.env.dev_obs_dict[handle] = visited
return observation return observation
def _explore_branch(self, handle, position, direction, root_observation, depth): def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
""" """
Utility function to compute tree-based observations. Utility function to compute tree-based observations.
""" """
...@@ -332,26 +332,31 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -332,26 +332,31 @@ class TreeObsForRailEnv(ObservationBuilder):
# Register possible conflict # Register possible conflict
if self.predictor and num_steps < self.max_prediction_depth: if self.predictor and num_steps < self.max_prediction_depth:
int_position = coordinate_to_position(self.env.width, [position]) int_position = coordinate_to_position(self.env.width, [position])
pre_step = max(0, num_steps - 1) if tot_dist < self.max_prediction_depth:
post_step = min(self.max_prediction_depth - 1, num_steps + 1) pre_step = max(0, tot_dist - 1)
# Look for opposing paths at distance num_step post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
if int_position in np.delete(self.predicted_pos[num_steps], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position) # Look for opposing paths at distance num_step
for ca in conflicting_agent: if int_position in np.delete(self.predicted_pos[tot_dist], handle):
if direction != self.predicted_dir[num_steps][ca[0]]: conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position)
potential_conflict = 1 for ca in conflicting_agent:
# Look for opposing paths at distance num_step-1 if direction != self.predicted_dir[tot_dist][ca[0]]:
elif int_position in np.delete(self.predicted_pos[pre_step], handle): potential_conflict = 1
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) # print("Potential Conflict",position,handle,ca[0],tot_dist,depth)
for ca in conflicting_agent: # Look for opposing paths at distance num_step-1
if direction != self.predicted_dir[pre_step][ca[0]]: elif int_position in np.delete(self.predicted_pos[pre_step], handle):
potential_conflict = 1 conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
# Look for opposing paths at distance num_step+1 for ca in conflicting_agent:
elif int_position in np.delete(self.predicted_pos[post_step], handle): if direction != self.predicted_dir[pre_step][ca[0]]:
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position) potential_conflict = 1
for ca in conflicting_agent: # print("Potential Conflict", position,handle,ca[0],pre_step,depth)
if direction != self.predicted_dir[post_step][ca[0]]: # Look for opposing paths at distance num_step+1
potential_conflict = 1 elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca[0]]:
potential_conflict = 1
# print("Potential Conflict", position,handle,ca[0],post_step,depth)
if position in self.location_has_target and position != agent.target: if position in self.location_has_target and position != agent.target:
if num_steps < other_target_encountered: if num_steps < other_target_encountered:
...@@ -395,7 +400,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -395,7 +400,7 @@ class TreeObsForRailEnv(ObservationBuilder):
direction = np.argmax(cell_transitions) direction = np.argmax(cell_transitions)
position = self._new_position(position, direction) position = self._new_position(position, direction)
num_steps += 1 num_steps += 1
tot_dist += 1
elif num_transitions > 0: elif num_transitions > 0:
# Switch detected # Switch detected
last_isSwitch = True last_isSwitch = True
...@@ -499,7 +504,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -499,7 +504,7 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_observation, branch_visited = self._explore_branch(handle, branch_observation, branch_visited = self._explore_branch(handle,
new_cell, new_cell,
(branch_direction + 2) % 4, (branch_direction + 2) % 4,
new_root_observation, new_root_observation, tot_dist + 1,
depth + 1) depth + 1)
observation = observation + branch_observation observation = observation + branch_observation
if len(branch_visited) != 0: if len(branch_visited) != 0:
...@@ -509,7 +514,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -509,7 +514,7 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_observation, branch_visited = self._explore_branch(handle, branch_observation, branch_visited = self._explore_branch(handle,
new_cell, new_cell,
branch_direction, branch_direction,
new_root_observation, new_root_observation, tot_dist + 1,
depth + 1) depth + 1)
observation = observation + branch_observation observation = observation + branch_observation
if len(branch_visited) != 0: if len(branch_visited) != 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment