diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index 41fd5a313bfb14cb1f792cc264540307104b284a..1661ef65a9a33f3b44a098caaf83317919722398 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -7,8 +7,8 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
-random.seed(10)
-np.random.seed(10)
+random.seed(1)
+np.random.seed(1)
 
 env = RailEnv(width=7,
               height=7,
@@ -19,7 +19,7 @@ env = RailEnv(width=7,
 # Print the observation vector for agent 0
 obs, all_rewards, done, _ = env.step({0: 0})
 for i in range(env.get_num_agents()):
-    env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
+    env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=7)
 
 env_renderer = RenderTool(env, gl="PIL")
 env_renderer.renderEnv(show=True, frames=True)
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 5cc3f26d0374cb07a673d05771cbbf64a66f0fc0..2214544adb9620ab1eefe6067f88c6bd3be8205d 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -31,19 +31,32 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
 
+        self.agents_previous_reset = None
+
     def reset(self):
         agents = self.env.agents
         nAgents = len(agents)
-        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
-                                                    self.env.height,
-                                                    self.env.width,
-                                                    4))
-        self.max_dist = np.zeros(nAgents)
-
-        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
 
-        # Update local lookup table for all agents' target locations
-        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
+        compute_distance_map = True
+        if self.agents_previous_reset is not None:
+            if nAgents == len(self.agents_previous_reset):
+                compute_distance_map = False
+                for i in range(nAgents):
+                    if agents[i].target != self.agents_previous_reset[i].target:
+                        compute_distance_map = True
+        self.agents_previous_reset = agents
+
+        if compute_distance_map:
+            self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
+                                                        self.env.height,
+                                                        self.env.width,
+                                                        4))
+            self.max_dist = np.zeros(nAgents)
+
+            self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
+
+            # Update local lookup table for all agents' target locations
+            self.location_has_target = {tuple(agent.target): 1 for agent in agents}
 
     def _distance_map_walker(self, position, target_nr):
         """