diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py
index d155042e7c045780872974c86c0655d59178b930..940193198a18a63f669d6a50a04cbd8f67740a32 100644
--- a/flatland/envs/distance_map.py
+++ b/flatland/envs/distance_map.py
@@ -1,5 +1,5 @@
 from collections import deque
-from typing import List
+from typing import List, Optional
 
 import numpy as np
 
@@ -13,8 +13,10 @@ class DistanceMap:
         self.env_height = env_height
         self.env_width = env_width
         self.distance_map = None
-        self.distance_map_computed = False
-        self.agents_previous_reset = None
+        self.agents_previous_computation = None
+        self.reset_was_called = False
+        self.agents: List[EnvAgent] = agents
+        self.rail: Optional[GridTransitionMap] = None
 
     """
     Set the distance map
@@ -26,31 +28,39 @@ class DistanceMap:
     Get the distance map
     """
     def get(self) -> np.ndarray:
-        return self.distance_map
 
-    """
-    Compute the distance map
-    """
-    def compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        if self.reset_was_called:
+            self.reset_was_called = False
+
+            nb_agents = len(self.agents)
+            compute_distance_map = True
+            if self.agents_previous_computation is not None and nb_agents == len(self.agents_previous_computation):
+                compute_distance_map = False
+                for i in range(nb_agents):
+                    if self.agents[i].target != self.agents_previous_computation[i].target:
+                        compute_distance_map = True
+            # Don't compute the distance map if it was loaded
+            if self.agents_previous_computation is None and self.distance_map is not None:
+                compute_distance_map = False
 
-        nb_agents = len(agents)
-        compute_distance_map = True
-        if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
-            compute_distance_map = False
-            for i in range(nb_agents):
-                if agents[i].target != self.agents_previous_reset[i].target:
-                    compute_distance_map = True
-        # Don't compute the distance map if it was loaded
-        if self.agents_previous_reset is None and self.distance_map is not None:
-            compute_distance_map = False
+            if compute_distance_map:
+                self._compute(self.agents, self.rail)
 
-        if compute_distance_map:
-            self._compute(agents, rail)
+        elif self.distance_map is None:
+            self._compute(self.agents, self.rail)
 
-        self.agents_previous_reset = agents
+        return self.distance_map
+
+    """
+    Reset the distance map
+    """
+    def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        self.reset_was_called = True
+        self.agents = agents
+        self.rail = rail
 
     def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
-        self.distance_map_computed = True
+        self.agents_previous_computation = self.agents
         self.distance_map = np.inf * np.ones(shape=(len(agents),
                                                     self.env_height,
                                                     self.env_width,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 2d358a29fa322539ed408c0a5c84cc3c6da986ac..f5881836072ec09a2fe4ee9a70482566867362c4 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -276,7 +276,7 @@ class RailEnv(Environment):
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
         self.observation_space = self.obs_builder.observation_space  # <-- change on reset?
-        self.distance_map.compute(self.agents, self.rail)
+        self.distance_map.reset(self.agents, self.rail)
 
         # Return the new observation vectors for each agent
         return self._get_observations()
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 8918a585e64e073a1af3efe6b1cc45ea2818d102..4c925789e6560077d637e2a594c736df8850d00a 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -148,7 +148,6 @@ def tests_rail_from_file():
     assert agents_initial == agents_loaded
 
     # Check that distance map was not recomputed
-    assert env.distance_map.distance_map_computed is False
     assert np.shape(env.distance_map.get()) == dist_map_shape
     assert env.distance_map.get() is not None
 
@@ -222,6 +221,5 @@ def tests_rail_from_file():
     assert agents_initial_2 == agents_loaded_4
 
     # Check that distance map was generated with correct shape
-    assert env4.distance_map.distance_map_computed is True
     assert env4.distance_map.get() is not None
     assert np.shape(env4.distance_map.get()) == dist_map_shape