Skip to content
Snippets Groups Projects
Commit ff2ea11d authored by Erik Nygren's avatar Erik Nygren
Browse files

updated how distance from agent is measured to detect conflicts!

parent 920032da
No related branches found
No related tags found
No related merge requests found
......@@ -273,7 +273,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
self._explore_branch(handle, new_cell, branch_direction, root_observation, 0, 1)
observation = observation + branch_observation
visited = visited.union(branch_visited)
else:
......@@ -286,7 +286,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.env.dev_obs_dict[handle] = visited
return observation
def _explore_branch(self, handle, position, direction, root_observation, depth):
def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
"""
Utility function to compute tree-based observations.
"""
......@@ -332,26 +332,31 @@ class TreeObsForRailEnv(ObservationBuilder):
# Register possible conflict
if self.predictor and num_steps < self.max_prediction_depth:
int_position = coordinate_to_position(self.env.width, [position])
pre_step = max(0, num_steps - 1)
post_step = min(self.max_prediction_depth - 1, num_steps + 1)
# Look for opposing paths at distance num_step
if int_position in np.delete(self.predicted_pos[num_steps], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[num_steps][ca[0]]:
potential_conflict = 1
# Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca[0]]:
potential_conflict = 1
# Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca[0]]:
potential_conflict = 1
if tot_dist < self.max_prediction_depth:
pre_step = max(0, tot_dist - 1)
post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
# Look for opposing paths at distance num_step
if int_position in np.delete(self.predicted_pos[tot_dist], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[tot_dist][ca[0]]:
potential_conflict = 1
# print("Potential Conflict",position,handle,ca[0],tot_dist,depth)
# Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca[0]]:
potential_conflict = 1
# print("Potential Conflict", position,handle,ca[0],pre_step,depth)
# Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca[0]]:
potential_conflict = 1
# print("Potential Conflict", position,handle,ca[0],post_step,depth)
if position in self.location_has_target and position != agent.target:
if num_steps < other_target_encountered:
......@@ -395,7 +400,7 @@ class TreeObsForRailEnv(ObservationBuilder):
direction = np.argmax(cell_transitions)
position = self._new_position(position, direction)
num_steps += 1
tot_dist += 1
elif num_transitions > 0:
# Switch detected
last_isSwitch = True
......@@ -499,7 +504,7 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
(branch_direction + 2) % 4,
new_root_observation,
new_root_observation, tot_dist + 1,
depth + 1)
observation = observation + branch_observation
if len(branch_visited) != 0:
......@@ -509,7 +514,7 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
branch_direction,
new_root_observation,
new_root_observation, tot_dist + 1,
depth + 1)
observation = observation + branch_observation
if len(branch_visited) != 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment