diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 23970a9059f15ce89b59b1cbcb9b862d248d1cd5..a7920c7da1c330494f2f37e298dd9f691378f115 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -116,7 +116,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs = env.reset()
+    obs, _ = env.reset()
     final_obs = obs.copy()
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
@@ -148,7 +148,7 @@ for trials in range(1, n_trials + 1):
             action_dict.update({a: action})
 
         # Environment step
-        next_obs, all_rewards, done, _ = env.step(action_dict)
+        (next_obs,_), all_rewards, done, _ = env.step(action_dict)
 
         for a in range(env.get_num_agents()):
             data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 8e4be0ba49d8405aa420d2bc8e4854f6300e3837..4fad2c01c4d353b6ba458a94d6fb10085213eb06 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -219,7 +219,8 @@ class TreeObsForRailEnv(ObservationBuilder):
             if possible_transitions[branch_direction]:
                 new_cell = self._new_position(agent.position, branch_direction)
 
-                branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
+                branch_observation, visited = self._explore_branch(handle, new_cell, branch_direction, root_observation,
+                                                                   1)
                 observation = observation + branch_observation
             else:
                 num_cells_to_fill_in = 0
@@ -228,7 +229,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     num_cells_to_fill_in += pow4
                     pow4 *= 4
                 observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
-        return observation
+        return observation, visited
 
     def _explore_branch(self, handle, position, direction, root_observation, depth):
         """
@@ -236,7 +237,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
         # [Recursive branch opened]
         if depth >= self.max_depth + 1:
-            return []
+            return [], []
 
         # Continue along direction until next switch or
         # until no transitions are possible along the current direction (i.e., dead-ends)
@@ -377,22 +378,24 @@ class TreeObsForRailEnv(ObservationBuilder):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
                 new_cell = self._new_position(position, (branch_direction + 2) % 4)
-                branch_observation = self._explore_branch(handle,
-                                                          new_cell,
-                                                          (branch_direction + 2) % 4,
-                                                          new_root_observation,
-                                                          depth + 1)
+                branch_observation, branch_visited = self._explore_branch(handle,
+                                                                          new_cell,
+                                                                          (branch_direction + 2) % 4,
+                                                                          new_root_observation,
+                                                                          depth + 1)
                 observation = observation + branch_observation
-
+                if len(branch_visited) != 0:
+                    visited = visited.union(branch_visited)  # union() returns a new set; rebind to keep it
             elif last_isSwitch and possible_transitions[branch_direction]:
                 new_cell = self._new_position(position, branch_direction)
-                branch_observation = self._explore_branch(handle,
-                                                          new_cell,
-                                                          branch_direction,
-                                                          new_root_observation,
-                                                          depth + 1)
+                branch_observation, branch_visited = self._explore_branch(handle,
+                                                                          new_cell,
+                                                                          branch_direction,
+                                                                          new_root_observation,
+                                                                          depth + 1)
                 observation = observation + branch_observation
-
+                if len(branch_visited) != 0:
+                    visited = visited.union(branch_visited)  # union() returns a new set; rebind to keep it
             else:
                 num_cells_to_fill_in = 0
                 pow4 = 1
@@ -401,7 +404,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     pow4 *= 4
                 observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
 
-        return observation
+        return observation, visited
 
     def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
         """
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index de3401a4bd98a1f114f25d629e7e0f13a7c0337c..749e5e01aa94e8ddb59d33d5787e3bd55d05f9eb 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -315,10 +315,11 @@ class RailEnv(Environment):
 
     def _get_observations(self):
         self.obs_dict = {}
+        self.debug_obs_dict = {}  # second value of each obs_builder.get() result, keyed by agent index
         # for handle in self.agents_handles:
         for iAgent in range(self.get_num_agents()):
-            self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
-        return self.obs_dict
+            self.obs_dict[iAgent], self.debug_obs_dict[iAgent] = self.obs_builder.get(iAgent)
+        return self.obs_dict, self.debug_obs_dict  # NOTE(review): now returns a 2-tuple of dicts, not one dict
 
     def render(self):
         # TODO: