diff --git a/changelog.md b/changelog.md
index 91e1ab6181ad1c7cd5b28e9ee06711ecd837324a..2b6123322f06ac13d5a8c503de7fdb5010714397 100644
--- a/changelog.md
+++ b/changelog.md
@@ -4,6 +4,12 @@ Changelog
 Changes since Flatland 0.3
 --------------------------
 
+### Changes in `Environment`
+- moving of member variable `distance_map_computed` to new class `DistanceMap`
+
+### Changes in rail generator and `RailEnv`
+- renaming of `distance_maps` into `distance_map`
+
 ### Changes in stock predictors
 The stock `ShortestPathPredictorForRailEnv` now respects the different agent speeds and updates their prediction accordingly.
 
@@ -25,12 +31,12 @@ The stock `ShortestPathPredictorForRailEnv` now respects the different agent spe
 ### Changes in level generation
 
 
-- Separation of `schedule_generator` from `rail_generator`: 
+- Separation of `schedule_generator` from `rail_generator`:
   - Renaming of `flatland/envs/generators.py` to `flatland/envs/rail_generators.py`
   - `rail_generator` now only returns the grid and optionally hints (a python dictionary); the hints are currently use for distance_map and communication of start and goal position in complex rail generator.
   - `schedule_generator` takes a `GridTransitionMap` and the number of agents and optionally the `agents_hints` field of the hints dictionary.
-  - Inrodcution of types hints: 
-``` 
+  - Inrodcution of types hints:
+```
 RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
 AgentPosition = Tuple[int, int]
diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py
index 278202067d822037861f7fea41603b476c127e78..d155042e7c045780872974c86c0655d59178b930 100644
--- a/flatland/envs/distance_map.py
+++ b/flatland/envs/distance_map.py
@@ -12,28 +12,44 @@ class DistanceMap:
     def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int):
         self.env_height = env_height
         self.env_width = env_width
-        self.distance_map = np.inf * np.ones(shape=(len(agents),
-                                                    self.env_height,
-                                                    self.env_width,
-                                                    4))
+        self.distance_map = None
         self.distance_map_computed = False
+        self.agents_previous_reset = None
 
     """
     Set the distance map
     """
-    def set(self, distance_map: np.array):
+    def set(self, distance_map: np.ndarray):
         self.distance_map = distance_map
 
     """
     Get the distance map
     """
-    def get(self) -> np.array:
+    def get(self) -> np.ndarray:
         return self.distance_map
 
     """
     Compute the distance map
     """
     def compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
+
+        nb_agents = len(agents)
+        compute_distance_map = True
+        if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
+            compute_distance_map = False
+            for i in range(nb_agents):
+                if agents[i].target != self.agents_previous_reset[i].target:
+                    compute_distance_map = True
+        # Don't compute the distance map if it was loaded
+        if self.agents_previous_reset is None and self.distance_map is not None:
+            compute_distance_map = False
+
+        if compute_distance_map:
+            self._compute(agents, rail)
+
+        self.agents_previous_reset = agents
+
+    def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
         self.distance_map_computed = True
         self.distance_map = np.inf * np.ones(shape=(len(agents),
                                                     self.env_height,
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index ce0ce9b1985c026539f45ed29331bd4b69cb37f7..a833fc01949d4184d5ca2442c6bb429d697318f3 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -36,28 +36,12 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
         self.predictor = predictor
-        self.agents_previous_reset = None
+        self.location_has_target = None
         self.tree_explored_actions = [1, 2, 3, 0]
         self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
 
     def reset(self):
-        agents = self.env.agents
-        nb_agents = len(agents)
-        compute_distance_map = True
-        if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
-            compute_distance_map = False
-            for i in range(nb_agents):
-                if agents[i].target != self.agents_previous_reset[i].target:
-                    compute_distance_map = True
-        # Don't compute the distance map if it was loaded
-        if self.agents_previous_reset is None and self.env.distance_map.get() is not None:
-            self.location_has_target = {tuple(agent.target): 1 for agent in agents}
-            compute_distance_map = False
-
-        if compute_distance_map:
-            self.env.compute_distance_map()
-
-        self.agents_previous_reset = agents
+        self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
 
     def get_many(self, handles=None):
         """
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 0dfd4535a91899835affb4f00e2444b42a3e77af..2d358a29fa322539ed408c0a5c84cc3c6da986ac 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -276,6 +276,7 @@ class RailEnv(Environment):
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
         self.observation_space = self.obs_builder.observation_space  # <-- change on reset?
+        self.distance_map.compute(self.agents, self.rail)
 
         # Return the new observation vectors for each agent
         return self._get_observations()
@@ -625,9 +626,3 @@ class RailEnv(Environment):
         from importlib_resources import read_binary
         load_data = read_binary(package, resource)
         self.set_full_state_msg(load_data)
-
-    def compute_distance_map(self):
-        self.distance_map.compute(self.agents, self.rail)
-        # Update local lookup table for all agents' target locations
-        self.obs_builder.location_has_target = {tuple(agent.target): 1 for agent in self.agents}
-