Skip to content
Snippets Groups Projects
Commit 380e3193 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch '128_enhancing_tree_obs' into 'master'

At switches now conflicts are only detected if agent has option to choose head on colliding path

Closes #128

See merge request flatland/flatland!129
parents 2cf8c676 245c9cc2
No related branches found
No related tags found
No related merge requests found
...@@ -74,6 +74,7 @@ class ObservationBuilder: ...@@ -74,6 +74,7 @@ class ObservationBuilder:
direction[agent.direction] = 1 direction[agent.direction] = 1
return direction return direction
class DummyObservationBuilder(ObservationBuilder): class DummyObservationBuilder(ObservationBuilder):
""" """
DummyObservationBuilder class which returns dummy observations DummyObservationBuilder class which returns dummy observations
......
...@@ -270,6 +270,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -270,6 +270,7 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0] observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
visited = set() visited = set()
# Start from the current orientation, and see which transitions are available; # Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation # organize them as [left, forward, right, back], relative to the current orientation
# If only one transition is possible, the tree is oriented with this transition as the forward branch. # If only one transition is possible, the tree is oriented with this transition as the forward branch.
...@@ -289,6 +290,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -289,6 +290,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# add cells filled with infinity if no transition is possible # add cells filled with infinity if no transition is possible
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
self.env.dev_obs_dict[handle] = visited self.env.dev_obs_dict[handle] = visited
return observation return observation
def _num_cells_to_fill_in(self, remaining_depth): def _num_cells_to_fill_in(self, remaining_depth):
...@@ -362,11 +364,12 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -362,11 +364,12 @@ class TreeObsForRailEnv(ObservationBuilder):
pre_step = max(0, tot_dist - 1) pre_step = max(0, tot_dist - 1)
post_step = min(self.max_prediction_depth - 1, tot_dist + 1) post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
# Look for conflicting paths at distance num_step # Look for conflicting paths at distance tot_dist
if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0): if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position) conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
for ca in conflicting_agent[0]: for ca in conflicting_agent[0]:
if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict: if direction != self.predicted_dir[tot_dist][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[tot_dist][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict: if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist potential_conflict = tot_dist
...@@ -375,7 +378,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -375,7 +378,8 @@ class TreeObsForRailEnv(ObservationBuilder):
elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0): elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent[0]: for ca in conflicting_agent[0]:
if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict: if direction != self.predicted_dir[pre_step][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[pre_step][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict: if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist potential_conflict = tot_dist
...@@ -384,7 +388,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -384,7 +388,8 @@ class TreeObsForRailEnv(ObservationBuilder):
elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent[0]: for ca in conflicting_agent[0]:
if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict: if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[post_step][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict: if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist potential_conflict = tot_dist
...@@ -566,6 +571,9 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -566,6 +571,9 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.predictor: if self.predictor:
self.predictor._set_env(self.env) self.predictor._set_env(self.env)
def _reverse_dir(self, direction):
return int((direction + 2) % 4)
class GlobalObsForRailEnv(ObservationBuilder): class GlobalObsForRailEnv(ObservationBuilder):
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment