Commit ff2ea11d authored by Erik Nygren's avatar Erik Nygren
Browse files

updated how distance from agent is measured to detect conflicts!

parent 920032da
Pipeline #1104 failed with stages
in 9 minutes and 23 seconds
......@@ -273,7 +273,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
self._explore_branch(handle, new_cell, branch_direction, root_observation, 0, 1)
observation = observation + branch_observation
visited = visited.union(branch_visited)
else:
......@@ -286,7 +286,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.env.dev_obs_dict[handle] = visited
return observation
def _explore_branch(self, handle, position, direction, root_observation, depth):
def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
"""
Utility function to compute tree-based observations.
"""
......@@ -332,26 +332,31 @@ class TreeObsForRailEnv(ObservationBuilder):
# Register possible conflict
if self.predictor and num_steps < self.max_prediction_depth:
int_position = coordinate_to_position(self.env.width, [position])
pre_step = max(0, num_steps - 1)
post_step = min(self.max_prediction_depth - 1, num_steps + 1)
# Look for opposing paths at distance num_step
if int_position in np.delete(self.predicted_pos[num_steps], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[num_steps][ca[0]]:
potential_conflict = 1
# Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca[0]]:
potential_conflict = 1
# Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca[0]]:
potential_conflict = 1
if tot_dist < self.max_prediction_depth:
pre_step = max(0, tot_dist - 1)
post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
# Look for opposing paths at distance num_step
if int_position in np.delete(self.predicted_pos[tot_dist], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[tot_dist][ca[0]]:
potential_conflict = 1
# print("Potential Conflict",position,handle,ca[0],tot_dist,depth)
# Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca[0]]:
potential_conflict = 1
# print("Potential Conflict", position,handle,ca[0],pre_step,depth)
# Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca[0]]:
potential_conflict = 1
# print("Potential Conflict", position,handle,ca[0],post_step,depth)
if position in self.location_has_target and position != agent.target:
if num_steps < other_target_encountered:
......@@ -395,7 +400,7 @@ class TreeObsForRailEnv(ObservationBuilder):
direction = np.argmax(cell_transitions)
position = self._new_position(position, direction)
num_steps += 1
tot_dist += 1
elif num_transitions > 0:
# Switch detected
last_isSwitch = True
......@@ -499,7 +504,7 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
(branch_direction + 2) % 4,
new_root_observation,
new_root_observation, tot_dist + 1,
depth + 1)
observation = observation + branch_observation
if len(branch_visited) != 0:
......@@ -509,7 +514,7 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
branch_direction,
new_root_observation,
new_root_observation, tot_dist + 1,
depth + 1)
observation = observation + branch_observation
if len(branch_visited) != 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment