diff --git a/docs/intro_observationbuilder.rst b/docs/intro_observationbuilder.rst
index 64e953da23870d4653707c20c222c296db9b71f6..4386f9e07df07693d2a4e72fa1920e4275972c67 100644
--- a/docs/intro_observationbuilder.rst
+++ b/docs/intro_observationbuilder.rst
@@ -110,7 +110,7 @@ Note that this simple strategy fails when multiple agents are present, as each a
                 for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                     if possible_transitions[direction]:
                         new_position = self._new_position(agent.position, direction)
-                        min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
+                        min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                     else:
                         min_distances.append(np.inf)
 
diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
index 18e96a2b6f5c0fb2bd557989c8382a376cc71fc5..25238d42f00694c7acfefa1b832b0d672a22a6b6 100644
--- a/examples/custom_observation_example.py
+++ b/examples/custom_observation_example.py
@@ -81,7 +81,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
                     new_position = get_new_position(agent.position, direction)
-                    min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
+                    min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
 
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 1f0f89dee63026c08e6ab2b0adb7e33d069ae652..8aef94c23311fc693c229924953164afb5fec8ab 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -49,7 +49,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
                     new_position = get_new_position(agent.position, direction)
-                    min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
+                    min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
 
diff --git a/flatland/core/env.py b/flatland/core/env.py
index f1f1b270820b5ccaa9e0644eede703a62a40aad9..1bc5b6f3eba4ee4713bd3c8d6b88440006c215a5 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -45,8 +45,6 @@ class Environment:
     def __init__(self):
         self.action_space = ()
         self.observation_space = ()
-        self.distance_map_computed = False
-        self.distance_map = None
         pass
 
     def reset(self):
diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..278202067d822037861f7fea41603b476c127e78
--- /dev/null
+++ b/flatland/envs/distance_map.py
@@ -0,0 +1,124 @@
+from collections import deque
+from typing import List
+
+import numpy as np
+
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_utils import EnvAgent
+
+
+class DistanceMap:
+    def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int):
+        self.env_height = env_height
+        self.env_width = env_width
+        self.distance_map = np.inf * np.ones(shape=(len(agents),
+                                                    self.env_height,
+                                                    self.env_width,
+                                                    4))
+        self.distance_map_computed = False
+
+    """
+    Set the distance map
+    """
+    def set(self, distance_map: np.array):
+        self.distance_map = distance_map
+
+    """
+    Get the distance map
+    """
+    def get(self) -> np.array:
+        return self.distance_map
+
+    """
+    Compute the distance map
+    """
+    def compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        self.distance_map_computed = True
+        self.distance_map = np.inf * np.ones(shape=(len(agents),
+                                                    self.env_height,
+                                                    self.env_width,
+                                                    4))
+        for i, agent in enumerate(agents):
+            self._distance_map_walker(rail, agent.target, i)
+
+    def _distance_map_walker(self, rail: GridTransitionMap, position, target_nr: int):
+        """
+        Utility function to compute distance maps from each cell in the rail network (and each possible
+        orientation within it) to each agent's target cell.
+        """
+        # Returns max distance to target, from the farthest away node, while filling in distance_map
+        self.distance_map[target_nr, position[0], position[1], :] = 0
+
+        # Fill in the (up to) 4 neighboring nodes
+        # direction is the direction of movement, meaning that at least a possible orientation of an agent
+        # in cell (row,col) allows a movement in direction `direction'
+        nodes_queue = deque(self._get_and_update_neighbors(rail, position, target_nr, 0, enforce_target_direction=-1))
+
+        # BFS from target `position' to all the reachable nodes in the grid
+        # Stop the search if the target position is re-visited, in any direction
+        visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
+                   (position[0], position[1], 3)}
+
+        max_distance = 0
+
+        while nodes_queue:
+            node = nodes_queue.popleft()
+
+            node_id = (node[0], node[1], node[2])
+
+            if node_id not in visited:
+                visited.add(node_id)
+
+                # From the list of possible neighbors that have at least a path to the current node, only keep those
+                # whose new orientation in the current cell would allow a transition to direction node[2]
+                valid_neighbors = self._get_and_update_neighbors(rail, (node[0], node[1]), target_nr, node[3], node[2])
+
+                for n in valid_neighbors:
+                    nodes_queue.append(n)
+
+                if len(valid_neighbors) > 0:
+                    max_distance = max(max_distance, node[3] + 1)
+
+        return max_distance
+
+    def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance, enforce_target_direction=-1):
+        """
+        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
+        minimum distances from each target cell.
+        """
+        neighbors = []
+
+        possible_directions = [0, 1, 2, 3]
+        if enforce_target_direction >= 0:
+            # The agent must land into the current cell with orientation `enforce_target_direction'.
+            # This is only possible if the agent has arrived from the cell in the opposite direction!
+            possible_directions = [(enforce_target_direction + 2) % 4]
+
+        for neigh_direction in possible_directions:
+            new_cell = get_new_position(position, neigh_direction)
+
+            if new_cell[0] >= 0 and new_cell[0] < self.env_height and new_cell[1] >= 0 and new_cell[1] < self.env_width:
+
+                desired_movement_from_new_cell = (neigh_direction + 2) % 4
+
+                # Check all possible transitions in new_cell
+                for agent_orientation in range(4):
+                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
+                    is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
+                                                            desired_movement_from_new_cell)
+                    # is_valid = True
+
+                    if is_valid:
+                        """
+                        # TODO: check that it works with deadends! -- still bugged!
+                        movement = desired_movement_from_new_cell
+                        if isNextCellDeadEnd:
+                            movement = (desired_movement_from_new_cell+2) % 4
+                        """
+                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
+                                           current_distance + 1)
+                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
+                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
+
+        return neighbors
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 3f398c74468ccd69c721ac8b17e75c625221fb07..ce0ce9b1985c026539f45ed29331bd4b69cb37f7 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -50,7 +50,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if agents[i].target != self.agents_previous_reset[i].target:
                     compute_distance_map = True
         # Don't compute the distance map if it was loaded
-        if self.agents_previous_reset is None and self.env.distance_map is not None:
+        if self.agents_previous_reset is None and self.env.distance_map.get() is not None:
             self.location_has_target = {tuple(agent.target): 1 for agent in agents}
             compute_distance_map = False
 
@@ -167,7 +167,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Root node - current position
         # Here information about the agent itself is stored
-        observation = [0, 0, 0, 0, 0, 0, self.env.distance_map[(handle, *agent.position, agent.direction)], 0, 0,
+        observation = [0, 0, 0, 0, 0, 0, self.env.distance_map.get()[(handle, *agent.position, agent.direction)], 0, 0,
                        agent.malfunction_data['malfunction'], agent.speed_data['speed']]
 
         visited = set()
@@ -397,7 +397,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                            potential_conflict,
                            unusable_switch,
                            np.inf,
-                           self.env.distance_map[handle, position[0], position[1], direction],
+                           self.env.distance_map.get()[handle, position[0], position[1], direction],
                            other_agent_same_direction,
                            other_agent_opposite_direction,
                            malfunctioning_agent,
@@ -411,7 +411,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                            potential_conflict,
                            unusable_switch,
                            tot_dist,
-                           self.env.distance_map[handle, position[0], position[1], direction],
+                           self.env.distance_map.get()[handle, position[0], position[1], direction],
                            other_agent_same_direction,
                            other_agent_opposite_direction,
                            malfunctioning_agent,
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index dc5a13b3cf51411e7f519bf2e3ee67c30c43df2e..c77b57871fcf30b231132c364c115a38d6de3889 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -148,7 +148,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                     for direction in range(4):
                         if cell_transitions[direction] == 1:
                             neighbour_cell = get_new_position(agent.position, direction)
-                            target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
+                            target_dist = distance_map.get()[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
                             if target_dist < min_dist or no_dist_found:
                                 min_dist = target_dist
                                 new_direction = direction
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 8265dffa5b5a6998a740202d1d0a7d7431b1c518..0dfd4535a91899835affb4f00e2444b42a3e77af 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -5,7 +5,6 @@ Definition of the RailEnv environment.
 import warnings
 from enum import IntEnum
 from typing import List
-from collections import deque
 
 import msgpack
 import msgpack_numpy as m
@@ -15,6 +14,7 @@ from flatland.core.env import Environment
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
+from flatland.envs.distance_map import DistanceMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
 from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
@@ -171,6 +171,7 @@ class RailEnv(Environment):
         self.agents: List[EnvAgent] = [None] * number_of_agents  # live agents
         self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents  # static agent information
         self.num_resets = 0
+        self.distance_map = DistanceMap(self.agents, self.height, self.width)
 
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
@@ -235,7 +236,7 @@ class RailEnv(Environment):
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
         if optionals and 'distance_map' in optionals:
-            self.distance_map = optionals['distance_map']
+            self.distance_map.set(optionals['distance_map'])
 
         if regen_rail or self.rail is None:
             self.rail = rail
@@ -576,7 +577,7 @@ class RailEnv(Environment):
         self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
         self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
         if "distance_map" in data.keys():
-            self.distance_map = data["distance_map"]
+            self.distance_map.set(data["distance_map"])
         # setup with loaded data
         self.height, self.width = self.rail.grid.shape
         self.rail.height = self.height
@@ -590,7 +591,7 @@ class RailEnv(Environment):
         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
         msgpack.packb(agent_static_data, use_bin_type=True)
-        distance_map_data = self.distance_map
+        distance_map_data = self.distance_map.get()
         msgpack.packb(distance_map_data, use_bin_type=True)
         msg_data = {
             "grid": grid_data,
@@ -601,8 +602,8 @@ class RailEnv(Environment):
         return msgpack.packb(msg_data, use_bin_type=True)
 
     def save(self, filename):
-        if self.distance_map is not None:
-            if len(self.distance_map) > 0:
+        if self.distance_map.get() is not None:
+            if len(self.distance_map.get()) > 0:
                 with open(filename, "wb") as file_out:
                     file_out.write(self.get_full_state_dist_msg())
             else:
@@ -626,95 +627,7 @@ class RailEnv(Environment):
         self.set_full_state_msg(load_data)
 
     def compute_distance_map(self):
-        agents = self.agents
-        # For testing only --> To assert if a distance map need to be recomputed.
-        self.distance_map_computed = True
-        nb_agents = len(agents)
-        self.distance_map = np.inf * np.ones(shape=(nb_agents,
-                                                    self.height,
-                                                    self.width,
-                                                    4))
-        max_dist = np.zeros(nb_agents)
-        max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
+        self.distance_map.compute(self.agents, self.rail)
         # Update local lookup table for all agents' target locations
-        self.obs_builder.location_has_target = {tuple(agent.target): 1 for agent in agents}
+        self.obs_builder.location_has_target = {tuple(agent.target): 1 for agent in self.agents}
 
-    def _distance_map_walker(self, position, target_nr):
-        """
-        Utility function to compute distance maps from each cell in the rail network (and each possible
-        orientation within it) to each agent's target cell.
-        """
-        # Returns max distance to target, from the farthest away node, while filling in distance_map
-        self.distance_map[target_nr, position[0], position[1], :] = 0
-
-        # Fill in the (up to) 4 neighboring nodes
-        # direction is the direction of movement, meaning that at least a possible orientation of an agent
-        # in cell (row,col) allows a movement in direction `direction'
-        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
-
-        # BFS from target `position' to all the reachable nodes in the grid
-        # Stop the search if the target position is re-visited, in any direction
-        visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
-                   (position[0], position[1], 3)}
-
-        max_distance = 0
-
-        while nodes_queue:
-            node = nodes_queue.popleft()
-
-            node_id = (node[0], node[1], node[2])
-
-            if node_id not in visited:
-                visited.add(node_id)
-
-                # From the list of possible neighbors that have at least a path to the current node, only keep those
-                # whose new orientation in the current cell would allow a transition to direction node[2]
-                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
-
-                for n in valid_neighbors:
-                    nodes_queue.append(n)
-
-                if len(valid_neighbors) > 0:
-                    max_distance = max(max_distance, node[3] + 1)
-
-        return max_distance
-
-    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
-        """
-        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
-        minimum distances from each target cell.
-        """
-        neighbors = []
-
-        possible_directions = [0, 1, 2, 3]
-        if enforce_target_direction >= 0:
-            # The agent must land into the current cell with orientation `enforce_target_direction'.
-            # This is only possible if the agent has arrived from the cell in the opposite direction!
-            possible_directions = [(enforce_target_direction + 2) % 4]
-
-        for neigh_direction in possible_directions:
-            new_cell = get_new_position(position, neigh_direction)
-
-            if new_cell[0] >= 0 and new_cell[0] < self.height and new_cell[1] >= 0 and new_cell[1] < self.width:
-
-                desired_movement_from_new_cell = (neigh_direction + 2) % 4
-
-                # Check all possible transitions in new_cell
-                for agent_orientation in range(4):
-                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
-                    is_valid = self.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
-                                                            desired_movement_from_new_cell)
-
-                    if is_valid:
-                        """
-                        # TODO: check that it works with deadends! -- still bugged!
-                        movement = desired_movement_from_new_cell
-                        if isNextCellDeadEnd:
-                            movement = (desired_movement_from_new_cell+2) % 4
-                        """
-                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
-                                           current_distance + 1)
-                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
-                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
-
-        return neighbors
diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py
index 2653aaec9db483903988a09aa4f69314ba7967c4..3bed89b8ce0947c86593e2f1680ef6082f321d84 100644
--- a/tests/test_distance_map.py
+++ b/tests/test_distance_map.py
@@ -44,7 +44,7 @@ def test_walker():
     # reset to set agents from agents_static
     env.reset(False, False)
 
-    print(env.distance_map[(0, *[0, 1], 1)])
-    assert env.distance_map[(0, *[0, 1], 1)] == 3
-    print(env.distance_map[(0, *[0, 2], 3)])
-    assert env.distance_map[(0, *[0, 2], 1)] == 2
+    print(env.distance_map.get()[(0, *[0, 1], 1)])
+    assert env.distance_map.get()[(0, *[0, 1], 1)] == 3
+    print(env.distance_map.get()[(0, *[0, 2], 3)])
+    assert env.distance_map.get()[(0, *[0, 2], 1)] == 2
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 5d35bfb1f45edbc04748c6a8bbab0ac3e0fbea2a..46000de429092d3fe4effe87382d1e12bc2c3401 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -63,7 +63,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
                     is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation),
                                                                    desired_movement_from_new_cell)
                     if is_valid:
-                        distance_to_target = obs_builder.env.distance_map[
+                        distance_to_target = obs_builder.env.distance_map.get()[
                             (agent.handle, *agent.position, exit_direction)]
                         print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position,
                                                                                       agent.direction,
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 0221bf6d31c9505607250e43ca49c4e34787ed70..c31494673e63a17dc07eb6d89eeb581c640b1e13 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -137,7 +137,7 @@ def test_shortest_path_predictor(rendering=False):
         input("Continue?")
 
     # compute the observations and predictions
-    distance_map = env.distance_map
+    distance_map = env.distance_map.get()
     assert distance_map[0, agent.position[0], agent.position[
         1], agent.direction] == 5.0, "found {} instead of {}".format(
         distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index d1f487ce5e68066ba9db4f791a84ec8bc13c8416..81b61381ed67d927cac44f4c9733d8a040903ef5 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -42,7 +42,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
                     new_position = get_new_position(agent.position, direction)
-                    min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
+                    min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
 
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 43e6e720b0b993b17e11877d57478b3dbdeee6a0..8918a585e64e073a1af3efe6b1cc45ea2818d102 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -129,7 +129,7 @@ def tests_rail_from_file():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
     env.save(file_name)
-    dist_map_shape = np.shape(env.distance_map)
+    dist_map_shape = np.shape(env.distance_map.get())
     # initialize agents_static
     rails_initial = env.rail.grid
     agents_initial = env.agents
@@ -148,9 +148,9 @@ def tests_rail_from_file():
     assert agents_initial == agents_loaded
 
     # Check that distance map was not recomputed
-    assert env.distance_map_computed is False
-    assert np.shape(env.distance_map) == dist_map_shape
-    assert env.distance_map is not None
+    assert env.distance_map.distance_map_computed is False
+    assert np.shape(env.distance_map.get()) == dist_map_shape
+    assert env.distance_map.get() is not None
 
     # Test to save and load file without distance map.
 
@@ -222,6 +222,6 @@ def tests_rail_from_file():
     assert agents_initial_2 == agents_loaded_4
 
     # Check that distance map was generated with correct shape
-    assert env4.distance_map_computed is True
-    assert env4.distance_map is not None
-    assert np.shape(env4.distance_map) == dist_map_shape
+    assert env4.distance_map.distance_map_computed is True
+    assert env4.distance_map.get() is not None
+    assert np.shape(env4.distance_map.get()) == dist_map_shape