Skip to content
Snippets Groups Projects
Commit 245c9cc2 authored by Erik Nygren's avatar Erik Nygren
Browse files

At switches now conflicts are only detected if agent has option to choose head on colliding path

parent 2d2dd61d
No related branches found
No related tags found
No related merge requests found
......@@ -74,6 +74,7 @@ class ObservationBuilder:
direction[agent.direction] = 1
return direction
class DummyObservationBuilder(ObservationBuilder):
"""
DummyObservationBuilder class which returns dummy observations
......
......@@ -270,6 +270,7 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
visited = set()
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# If only one transition is possible, the tree is oriented with this transition as the forward branch.
......@@ -289,6 +290,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# add cells filled with infinity if no transition is possible
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
self.env.dev_obs_dict[handle] = visited
return observation
def _num_cells_to_fill_in(self, remaining_depth):
......@@ -362,11 +364,12 @@ class TreeObsForRailEnv(ObservationBuilder):
pre_step = max(0, tot_dist - 1)
post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
# Look for conflicting paths at distance num_step
# Look for conflicting paths at distance tot_dist
if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict:
if direction != self.predicted_dir[tot_dist][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[tot_dist][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
......@@ -375,7 +378,8 @@ class TreeObsForRailEnv(ObservationBuilder):
elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict:
if direction != self.predicted_dir[pre_step][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[pre_step][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
......@@ -384,7 +388,8 @@ class TreeObsForRailEnv(ObservationBuilder):
elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict:
if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[post_step][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
......@@ -566,6 +571,9 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.predictor:
self.predictor._set_env(self.env)
def _reverse_dir(self, direction):
return int((direction + 2) % 4)
class GlobalObsForRailEnv(ObservationBuilder):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment