diff --git a/changelog.md b/changelog.md
index 91e1ab6181ad1c7cd5b28e9ee06711ecd837324a..cad8ee000506dea3abf05dedb27a43aaaf0bf8b7 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,9 +1,17 @@
 Changelog
 ==========
 
-Changes since Flatland 0.3
+Changes since Flatland 2.0.0
 --------------------------
 
+### Changes in `Environment`
+- moving of member variable `distance_map_computed` to new class `DistanceMap`
+
+### Changes in rail generator and `RailEnv`
+- renaming of `distance_maps` into `distance_map`
+
+Changes since Flatland 1.0.0
+--------------------------
 ### Changes in stock predictors
 The stock `ShortestPathPredictorForRailEnv` now respects the different agent speeds and updates their prediction accordingly.
 
@@ -25,12 +33,12 @@ The stock `ShortestPathPredictorForRailEnv` now respects the different agent spe
 ### Changes in level generation
 
 
-- Separation of `schedule_generator` from `rail_generator`: 
+- Separation of `schedule_generator` from `rail_generator`:
   - Renaming of `flatland/envs/generators.py` to `flatland/envs/rail_generators.py`
   - `rail_generator` now only returns the grid and optionally hints (a python dictionary); the hints are currently use for distance_map and communication of start and goal position in complex rail generator.
   - `schedule_generator` takes a `GridTransitionMap` and the number of agents and optionally the `agents_hints` field of the hints dictionary.
-  - Inrodcution of types hints: 
-``` 
+  - Inrodcution of types hints:
+```
 RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
 AgentPosition = Tuple[int, int]
diff --git a/docs/intro_observationbuilder.rst b/docs/intro_observationbuilder.rst
index c94a2ce18e80c9b6fbdf34b824dfc2062374c82d..4386f9e07df07693d2a4e72fa1920e4275972c67 100644
--- a/docs/intro_observationbuilder.rst
+++ b/docs/intro_observationbuilder.rst
@@ -63,7 +63,7 @@ cells and orientations to the target of each agent, i.e. a distance map for each
 
 In this example we exploit these distance maps by implementing an observation builder that shows the current shortest path for each agent as a one-hot observation vector of length 3, whose components represent the possible directions an agent can take (LEFT, FORWARD, RIGHT). All values of the observation vector are set to :code:`0` except for the shortest direction where it is set to :code:`1`.
 
-Using this observation with highly engineered features indicating the agent's shortest path, an agent can then learn to take the corresponding action at each time-step; or we could even hardcode the optimal policy. 
+Using this observation with highly engineered features indicating the agent's shortest path, an agent can then learn to take the corresponding action at each time-step; or we could even hardcode the optimal policy.
 Note that this simple strategy fails when multiple agents are present, as each agent would only attempt its greedy solution, which is not usually `Pareto-optimal <https://en.wikipedia.org/wiki/Pareto_efficiency>`_ in this context.
 
 .. _TreeObsForRailEnv: https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
@@ -71,7 +71,7 @@ Note that this simple strategy fails when multiple agents are present, as each a
 .. code-block:: python
 
     from flatland.envs.observations import TreeObsForRailEnv
-    
+
     class SingleAgentNavigationObs(TreeObsForRailEnv):
         """
         We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
@@ -84,7 +84,7 @@ Note that this simple strategy fails when multiple agents are present, as each a
         """
         def __init__(self):
             super().__init__(max_depth=0)
-            # We set max_depth=0 in because we only need to look at the current 
+            # We set max_depth=0 in because we only need to look at the current
             # position of the agent to decide what direction is shortest.
             self.observation_space = [3]
 
@@ -110,7 +110,7 @@ Note that this simple strategy fails when multiple agents are present, as each a
                 for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                     if possible_transitions[direction]:
                         new_position = self._new_position(agent.position, direction)
-                        min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
+                        min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                     else:
                         min_distances.append(np.inf)
 
@@ -175,28 +175,28 @@ In contrast to the previous examples we also implement the :code:`def get_many(s
 .. _example: https://gitlab.aicrowd.com/flatland/flatland/blob/master/examples/custom_observation_example.py#L110
 
 .. code-block:: python
-    
+
     class ObservePredictions(TreeObsForRailEnv):
         """
         We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation.
-    
+
         We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
         the minimum distances from each grid node to each agent's target.
-    
+
         This is necessary so that we can pass the distance map to the ShortestPathPredictor
-    
+
         Here we also want to highlight how you can visualize your observation
         """
-    
+
         def __init__(self, predictor):
             super().__init__(max_depth=0)
             self.observation_space = [10]
             self.predictor = predictor
-    
+
         def reset(self):
             # Recompute the distance map, if the environment has changed.
             super().reset()
-    
+
         def get_many(self, handles=None):
             '''
             Because we do not want to call the predictor seperately for every agent we implement the get_many function
@@ -204,9 +204,9 @@ In contrast to the previous examples we also implement the :code:`def get_many(s
             :param handles:
             :return:
             '''
-    
-            self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
-    
+
+            self.predictions = self.predictor.get()
+
             self.predicted_pos = {}
             for t in range(len(self.predictions[0])):
                 pos_list = []
@@ -215,47 +215,47 @@ In contrast to the previous examples we also implement the :code:`def get_many(s
                 # We transform (x,y) coodrinates to a single integer number for simpler comparison
                 self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
             observations = {}
-    
+
             # Collect all the different observation for all the agents
             for h in handles:
                 observations[h] = self.get(h)
             return observations
-    
+
         def get(self, handle):
             '''
             Lets write a simple observation which just indicates whether or not the own predicted path
             overlaps with other predicted paths at any time. This is useless for the task of navigation but might
             help when looking for conflicts. A more complex implementation can be found in the TreeObsForRailEnv class
-    
+
             Each agent recieves an observation of length 10, where each element represents a prediction step and its value
             is:
              - 0 if no overlap is happening
              - 1 where n i the number of other paths crossing the predicted cell
-    
+
             :param handle: handeled as an index of an agent
             :return: Observation of handle
             '''
-    
+
             observation = np.zeros(10)
-    
+
             # We are going to track what cells where considered while building the obervation and make them accesible
             # For rendering
-    
+
             visited = set()
             for _idx in range(10):
                 # Check if any of the other prediction overlap with agents own predictions
                 x_coord = self.predictions[handle][_idx][1]
                 y_coord = self.predictions[handle][_idx][2]
-    
+
                 # We add every observed cell to the observation rendering
                 visited.add((x_coord, y_coord))
                 if self.predicted_pos[_idx][handle] in np.delete(self.predicted_pos[_idx], handle, 0):
                     # We detect if another agent is predicting to pass through the same cell at the same predicted time
                     observation[handle] = 1
-    
+
             # This variable will be access by the renderer to visualize the observation
             self.env.dev_obs_dict[handle] = visited
-    
+
             return observation
 
 We can then use this new observation builder and the renderer to visualize the observation of each agent.
@@ -265,23 +265,23 @@ We can then use this new observation builder and the renderer to visualize the o
 
     # Initiate the Predictor
     CustomPredictor = ShortestPathPredictorForRailEnv(10)
-    
+
     # Pass the Predictor to the observation builder
     CustomObsBuilder = ObservePredictions(CustomPredictor)
-    
+
     # Initiate Environment
     env = RailEnv(width=10,
                   height=10,
                   rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
                   number_of_agents=3,
                   obs_builder_object=CustomObsBuilder)
-    
+
     obs = env.reset()
     env_renderer = RenderTool(env, gl="PILSVG")
-    
+
     # We render the initial step and show the obsered cells as colored boxes
     env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
-    
+
     action_dict = {}
     for step in range(100):
         for a in range(env.get_num_agents()):
@@ -321,7 +321,7 @@ These two objects can be used for example to detect switches that are usable by
 
     cell_transitions = self.env.rail.get_transitions(*position, direction)
     transition_bit = bin(self.env.rail.get_full_transitions(*position))
-    
+
     total_transitions = transition_bit.count("1")
     num_transitions = np.count_nonzero(cell_transitions)
 
@@ -357,7 +357,7 @@ Beyond the basic agent information we can also access more details about the age
 
 Similar to the speed data you can also access individual data about the malfunctions of an agent. All data is available through :code:`agent.malfunction_data` with:
 
-- Indication how long the agent is still malfunctioning :code:`'malfunction'` by an integer counting down at each time step. 0 means the agent is ok and can move. 
+- Indication how long the agent is still malfunctioning :code:`'malfunction'` by an integer counting down at each time step. 0 means the agent is ok and can move.
 - Possion rate at which malfunctions happen for this agent :code:`'malfunction_rate'`
 - Number of steps untill next malfunction will occur :code:`'next_malfunction'`
 - Number of malfunctions an agent have occured for this agent so far :code:`nr_malfunctions'`
diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
index b16a4e3e5a6378418b120f691248097fcdd82cb8..317372da390693dd51b53d411c4d5615582183b0 100644
--- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py
+++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
@@ -5,6 +5,7 @@ import time
 
 import numpy as np
 
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -49,8 +50,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self._new_position(agent.position, direction)
-                    min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
+                    new_position = get_new_position(agent.position, direction)
+                    min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
 
diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index 48a4084ea5395f4406a7884e03e3d84a110fc37e..9238a2af4137e37e9d79bc3c1aaade2bb987403e 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -47,7 +47,7 @@ class ObservePredictions(TreeObsForRailEnv):
         :return:
         '''
 
-        self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
+        self.predictions = self.predictor.get()
 
         self.predicted_pos = {}
         for t in range(len(self.predictions[0])):
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 50ea74b84ac9851e88e48bcd32b914e69bc7dd34..8aef94c23311fc693c229924953164afb5fec8ab 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -3,6 +3,7 @@ import time
 
 import numpy as np
 
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -47,8 +48,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self._new_position(agent.position, direction)
-                    min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
+                    new_position = get_new_position(agent.position, direction)
+                    min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
 
diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index 5ce69a8110236128b9a982e2540bc79357c1ba2d..13eb38140fb25730d1817e2db7a17c350a2260c7 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -28,7 +28,7 @@ class PredictionBuilder:
         """
         pass
 
-    def get(self, custom_args=None, handle=0):
+    def get(self, handle=0):
         """
         Called whenever get_many in the observation build is called.
 
diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..940193198a18a63f669d6a50a04cbd8f67740a32
--- /dev/null
+++ b/flatland/envs/distance_map.py
@@ -0,0 +1,150 @@
+from collections import deque
+from typing import List, Optional
+
+import numpy as np
+
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_utils import EnvAgent
+
+
+class DistanceMap:
+    def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int):
+        self.env_height = env_height
+        self.env_width = env_width
+        self.distance_map = None
+        self.agents_previous_computation = None
+        self.reset_was_called = False
+        self.agents: List[EnvAgent] = agents
+        self.rail: Optional[GridTransitionMap] = None
+
+    """
+    Set the distance map
+    """
+    def set(self, distance_map: np.ndarray):
+        self.distance_map = distance_map
+
+    """
+    Get the distance map
+    """
+    def get(self) -> np.ndarray:
+
+        if self.reset_was_called:
+            self.reset_was_called = False
+
+            nb_agents = len(self.agents)
+            compute_distance_map = True
+            if self.agents_previous_computation is not None and nb_agents == len(self.agents_previous_computation):
+                compute_distance_map = False
+                for i in range(nb_agents):
+                    if self.agents[i].target != self.agents_previous_computation[i].target:
+                        compute_distance_map = True
+            # Don't compute the distance map if it was loaded
+            if self.agents_previous_computation is None and self.distance_map is not None:
+                compute_distance_map = False
+
+            if compute_distance_map:
+                self._compute(self.agents, self.rail)
+
+        elif self.distance_map is None:
+            self._compute(self.agents, self.rail)
+
+        return self.distance_map
+
+    """
+    Reset the distance map
+    """
+    def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        self.reset_was_called = True
+        self.agents = agents
+        self.rail = rail
+
+    def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        self.agents_previous_computation = self.agents
+        self.distance_map = np.inf * np.ones(shape=(len(agents),
+                                                    self.env_height,
+                                                    self.env_width,
+                                                    4))
+        for i, agent in enumerate(agents):
+            self._distance_map_walker(rail, agent.target, i)
+
+    def _distance_map_walker(self, rail: GridTransitionMap, position, target_nr: int):
+        """
+        Utility function to compute distance maps from each cell in the rail network (and each possible
+        orientation within it) to each agent's target cell.
+        """
+        # Returns max distance to target, from the farthest away node, while filling in distance_map
+        self.distance_map[target_nr, position[0], position[1], :] = 0
+
+        # Fill in the (up to) 4 neighboring nodes
+        # direction is the direction of movement, meaning that at least a possible orientation of an agent
+        # in cell (row,col) allows a movement in direction `direction'
+        nodes_queue = deque(self._get_and_update_neighbors(rail, position, target_nr, 0, enforce_target_direction=-1))
+
+        # BFS from target `position' to all the reachable nodes in the grid
+        # Stop the search if the target position is re-visited, in any direction
+        visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
+                   (position[0], position[1], 3)}
+
+        max_distance = 0
+
+        while nodes_queue:
+            node = nodes_queue.popleft()
+
+            node_id = (node[0], node[1], node[2])
+
+            if node_id not in visited:
+                visited.add(node_id)
+
+                # From the list of possible neighbors that have at least a path to the current node, only keep those
+                # whose new orientation in the current cell would allow a transition to direction node[2]
+                valid_neighbors = self._get_and_update_neighbors(rail, (node[0], node[1]), target_nr, node[3], node[2])
+
+                for n in valid_neighbors:
+                    nodes_queue.append(n)
+
+                if len(valid_neighbors) > 0:
+                    max_distance = max(max_distance, node[3] + 1)
+
+        return max_distance
+
+    def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance, enforce_target_direction=-1):
+        """
+        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
+        minimum distances from each target cell.
+        """
+        neighbors = []
+
+        possible_directions = [0, 1, 2, 3]
+        if enforce_target_direction >= 0:
+            # The agent must land into the current cell with orientation `enforce_target_direction'.
+            # This is only possible if the agent has arrived from the cell in the opposite direction!
+            possible_directions = [(enforce_target_direction + 2) % 4]
+
+        for neigh_direction in possible_directions:
+            new_cell = get_new_position(position, neigh_direction)
+
+            if new_cell[0] >= 0 and new_cell[0] < self.env_height and new_cell[1] >= 0 and new_cell[1] < self.env_width:
+
+                desired_movement_from_new_cell = (neigh_direction + 2) % 4
+
+                # Check all possible transitions in new_cell
+                for agent_orientation in range(4):
+                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
+                    is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
+                                                            desired_movement_from_new_cell)
+                    # is_valid = True
+
+                    if is_valid:
+                        """
+                        # TODO: check that it works with deadends! -- still bugged!
+                        movement = desired_movement_from_new_cell
+                        if isNextCellDeadEnd:
+                            movement = (desired_movement_from_new_cell+2) % 4
+                        """
+                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
+                                           current_distance + 1)
+                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
+                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
+
+        return neighbors
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index d8308ddc9b769d7c2862273990e1d6dd36b410f4..85cd5fdc798fde4f1a5f2b75edeefce2dfa2104a 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -2,12 +2,12 @@
 Collection of environment-specific ObservationBuilder.
 """
 import pprint
-from collections import deque
 
 import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
 from flatland.utils.ordered_set import OrderedSet
 
@@ -37,137 +37,12 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
         self.predictor = predictor
-        self.agents_previous_reset = None
+        self.location_has_target = None
         self.tree_explored_actions = [1, 2, 3, 0]
         self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
-        self.distance_map = None
-        self.distance_map_computed = False
 
     def reset(self):
-        agents = self.env.agents
-        nb_agents = len(agents)
-        compute_distance_map = True
-        if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
-            compute_distance_map = False
-            for i in range(nb_agents):
-                if agents[i].target != self.agents_previous_reset[i].target:
-                    compute_distance_map = True
-        # Don't compute the distance map if it was loaded
-        if self.agents_previous_reset is None and self.distance_map is not None:
-            self.location_has_target = {tuple(agent.target): 1 for agent in agents}
-            compute_distance_map = False
-
-        if compute_distance_map:
-            self._compute_distance_map()
-
-        self.agents_previous_reset = agents
-
-    def _compute_distance_map(self):
-        agents = self.env.agents
-        # For testing only --> To assert if a distance map need to be recomputed.
-        self.distance_map_computed = True
-        nb_agents = len(agents)
-        self.distance_map = np.inf * np.ones(shape=(nb_agents,
-                                                    self.env.height,
-                                                    self.env.width,
-                                                    4))
-        self.max_dist = np.zeros(nb_agents)
-        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
-        # Update local lookup table for all agents' target locations
-        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
-
-    def _distance_map_walker(self, position, target_nr):
-        """
-        Utility function to compute distance maps from each cell in the rail network (and each possible
-        orientation within it) to each agent's target cell.
-        """
-        # Returns max distance to target, from the farthest away node, while filling in distance_map
-        self.distance_map[target_nr, position[0], position[1], :] = 0
-
-        # Fill in the (up to) 4 neighboring nodes
-        # direction is the direction of movement, meaning that at least a possible orientation of an agent
-        # in cell (row,col) allows a movement in direction `direction'
-        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
-
-        # BFS from target `position' to all the reachable nodes in the grid
-        # Stop the search if the target position is re-visited, in any direction
-        visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
-                   (position[0], position[1], 3)}
-
-        max_distance = 0
-
-        while nodes_queue:
-            node = nodes_queue.popleft()
-
-            node_id = (node[0], node[1], node[2])
-
-            if node_id not in visited:
-                visited.add(node_id)
-
-                # From the list of possible neighbors that have at least a path to the current node, only keep those
-                # whose new orientation in the current cell would allow a transition to direction node[2]
-                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
-
-                for n in valid_neighbors:
-                    nodes_queue.append(n)
-
-                if len(valid_neighbors) > 0:
-                    max_distance = max(max_distance, node[3] + 1)
-
-        return max_distance
-
-    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
-        """
-        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
-        minimum distances from each target cell.
-        """
-        neighbors = []
-
-        possible_directions = [0, 1, 2, 3]
-        if enforce_target_direction >= 0:
-            # The agent must land into the current cell with orientation `enforce_target_direction'.
-            # This is only possible if the agent has arrived from the cell in the opposite direction!
-            possible_directions = [(enforce_target_direction + 2) % 4]
-
-        for neigh_direction in possible_directions:
-            new_cell = self._new_position(position, neigh_direction)
-
-            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
-
-                desired_movement_from_new_cell = (neigh_direction + 2) % 4
-
-                # Check all possible transitions in new_cell
-                for agent_orientation in range(4):
-                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
-                    is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
-                                                            desired_movement_from_new_cell)
-
-                    if is_valid:
-                        """
-                        # TODO: check that it works with deadends! -- still bugged!
-                        movement = desired_movement_from_new_cell
-                        if isNextCellDeadEnd:
-                            movement = (desired_movement_from_new_cell+2) % 4
-                        """
-                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
-                                           current_distance + 1)
-                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
-                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
-
-        return neighbors
-
-    def _new_position(self, position, movement):
-        """
-        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
-        """
-        if movement == Grid4TransitionsEnum.NORTH:
-            return (position[0] - 1, position[1])
-        elif movement == Grid4TransitionsEnum.EAST:
-            return (position[0], position[1] + 1)
-        elif movement == Grid4TransitionsEnum.SOUTH:
-            return (position[0] + 1, position[1])
-        elif movement == Grid4TransitionsEnum.WEST:
-            return (position[0], position[1] - 1)
+        self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
 
     def get_many(self, handles=None):
         """
@@ -181,7 +56,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             self.max_prediction_depth = 0
             self.predicted_pos = {}
             self.predicted_dir = {}
-            self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
+            self.predictions = self.predictor.get()
             if self.predictions:
 
                 for t in range(len(self.predictions[0])):
@@ -277,7 +152,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Root node - current position
         # Here information about the agent itself is stored
-        observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0,
+        observation = [0, 0, 0, 0, 0, 0, self.env.distance_map.get()[(handle, *agent.position, agent.direction)], 0, 0,
                        agent.malfunction_data['malfunction'], agent.speed_data['speed']]
 
         visited = OrderedSet()
@@ -292,7 +167,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
-                new_cell = self._new_position(agent.position, branch_direction)
+                new_cell = get_new_position(agent.position, branch_direction)
                 branch_observation, branch_visited = \
                     self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                 observation = observation + branch_observation
@@ -465,7 +340,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     exploring = True
                     # convert one-hot encoding to 0,1,2,3
                     direction = np.argmax(cell_transitions)
-                    position = self._new_position(position, direction)
+                    position = get_new_position(position, direction)
                     num_steps += 1
                     tot_dist += 1
             elif num_transitions > 0:
@@ -507,7 +382,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                            potential_conflict,
                            unusable_switch,
                            np.inf,
-                           self.distance_map[handle, position[0], position[1], direction],
+                           self.env.distance_map.get()[handle, position[0], position[1], direction],
                            other_agent_same_direction,
                            other_agent_opposite_direction,
                            malfunctioning_agent,
@@ -521,7 +396,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                            potential_conflict,
                            unusable_switch,
                            tot_dist,
-                           self.distance_map[handle, position[0], position[1], direction],
+                           self.env.distance_map.get()[handle, position[0], position[1], direction],
                            other_agent_same_direction,
                            other_agent_opposite_direction,
                            malfunctioning_agent,
@@ -538,7 +413,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                  (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
-                new_cell = self._new_position(position, (branch_direction + 2) % 4)
+                new_cell = get_new_position(position, (branch_direction + 2) % 4)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           (branch_direction + 2) % 4,
@@ -548,7 +423,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if len(branch_visited) != 0:
                     visited |= branch_visited
             elif last_is_switch and possible_transitions[branch_direction]:
-                new_cell = self._new_position(position, branch_direction)
+                new_cell = get_new_position(position, branch_direction)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           branch_direction,
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 1f40fd8b5042e100f075de2b9cf42cc1d836f48d..7f03b5be43d98ba3b4d87d933c267457fd133ddd 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -18,14 +18,12 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, custom_args=None, handle=None):
+    def get(self, handle=None):
         """
         Called whenever get_many in the observation build is called.
 
         Parameters
         -------
-        custom_args: dict
-            Not used in this dummy implementation.
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
@@ -91,15 +89,13 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         # Initialize with depth 20
         self.max_depth = max_depth
 
-    def get(self, custom_args=None, handle=None):
+    def get(self, handle=None):
         """
         Called whenever get_many in the observation build is called.
         Requires distance_map to extract the shortest path.
 
         Parameters
         -------
-        custom_args: dict
-            - distance_map : dict
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
@@ -117,8 +113,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         agents = self.env.agents
         if handle:
             agents = [self.env.agents[handle]]
-        assert custom_args is not None
-        distance_map = custom_args.get('distance_map')
+        distance_map = self.env.distance_map
         assert distance_map is not None
 
         prediction_dict = {}
@@ -154,7 +149,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                     for direction in range(4):
                         if cell_transitions[direction] == 1:
                             neighbour_cell = get_new_position(agent.position, direction)
-                            target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
+                            target_dist = distance_map.get()[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
                             if target_dist < min_dist or no_dist_found:
                                 min_dist = target_dist
                                 new_direction = direction
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 270837dca21e8fbb9b9af2c4d0e64dc339c50d32..df50b813175f17384478ef743284d2231a6eee10 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -15,6 +15,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
+from flatland.envs.distance_map import DistanceMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
 from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
@@ -150,6 +151,7 @@ class RailEnv(Environment):
         file_name: you can load a pickle file. from previously saved *.pkl file
 
         """
+        super().__init__()
 
         self.rail_generator: RailGenerator = rail_generator
         self.schedule_generator: ScheduleGenerator = schedule_generator
@@ -176,6 +178,7 @@ class RailEnv(Environment):
         self.agents: List[EnvAgent] = [None] * number_of_agents  # live agents
         self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents  # static agent information
         self.num_resets = 0
+        self.distance_map = DistanceMap(self.agents, self.height, self.width)
 
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
@@ -239,8 +242,8 @@ class RailEnv(Environment):
         # TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
-        if optionals and 'distance_maps' in optionals:
-            self.obs_builder.distance_map = optionals['distance_maps']
+        if optionals and 'distance_map' in optionals:
+            self.distance_map.set(optionals['distance_map'])
 
         if regen_rail or self.rail is None:
             self.rail = rail
@@ -291,6 +294,7 @@ class RailEnv(Environment):
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
         self.observation_space = self.obs_builder.observation_space  # <-- change on reset?
+        self.distance_map.reset(self.agents, self.rail)
 
         # Return the new observation vectors for each agent
         return self._get_observations()
@@ -628,8 +632,8 @@ class RailEnv(Environment):
         # agents are always reset as not moving
         self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
         self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
-        if hasattr(self.obs_builder, 'distance_map') and "distance_maps" in data.keys():
-            self.obs_builder.distance_map = data["distance_maps"]
+        if "distance_map" in data.keys():
+            self.distance_map.set(data["distance_map"])
         # setup with loaded data
         self.height, self.width = self.rail.grid.shape
         self.rail.height = self.height
@@ -643,25 +647,19 @@ class RailEnv(Environment):
         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
         msgpack.packb(agent_static_data, use_bin_type=True)
-        if hasattr(self.obs_builder, 'distance_map'):
-            distance_map_data = self.obs_builder.distance_map
-            msgpack.packb(distance_map_data, use_bin_type=True)
-            msg_data = {
-                "grid": grid_data,
-                "agents_static": agent_static_data,
-                "agents": agent_data,
-                "distance_maps": distance_map_data}
-        else:
-            msg_data = {
-                "grid": grid_data,
-                "agents_static": agent_static_data,
-                "agents": agent_data}
+        distance_map_data = self.distance_map.get()
+        msgpack.packb(distance_map_data, use_bin_type=True)
+        msg_data = {
+            "grid": grid_data,
+            "agents_static": agent_static_data,
+            "agents": agent_data,
+            "distance_map": distance_map_data}
 
         return msgpack.packb(msg_data, use_bin_type=True)
 
     def save(self, filename):
-        if hasattr(self.obs_builder, 'distance_map'):
-            if len(self.obs_builder.distance_map) > 0:
+        if self.distance_map.get() is not None:
+            if len(self.distance_map.get()) > 0:
                 with open(filename, "wb") as file_out:
                     file_out.write(self.get_full_state_dist_msg())
             else:
@@ -672,14 +670,9 @@ class RailEnv(Environment):
                 file_out.write(self.get_full_state_msg())
 
     def load(self, filename):
-        if hasattr(self.obs_builder, 'distance_map'):
-            with open(filename, "rb") as file_in:
-                load_data = file_in.read()
-                self.set_full_state_dist_msg(load_data)
-        else:
-            with open(filename, "rb") as file_in:
-                load_data = file_in.read()
-                self.set_full_state_msg(load_data)
+        with open(filename, "rb") as file_in:
+            load_data = file_in.read()
+            self.set_full_state_dist_msg(load_data)
 
     def load_pkl(self, pkl_data):
         self.set_full_state_msg(pkl_data)
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index a16fb6018a6354665a44c1b44cafd6975bb4e680..e5f0a8e8dcbe81660727e4eb04f5d4a0f636b5d4 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -226,10 +226,10 @@ def rail_from_file(filename) -> RailGenerator:
         grid = np.array(data[b"grid"])
         rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
         rail.grid = grid
-        if b"distance_maps" in data.keys():
-            distance_maps = data[b"distance_maps"]
-            if len(distance_maps) > 0:
-                return rail, {'distance_maps': distance_maps}
+        if b"distance_map" in data.keys():
+            distance_map = data[b"distance_map"]
+            if len(distance_map) > 0:
+                return rail, {'distance_map': distance_map}
         return [rail, None]
 
     return generator
diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py
index e5e89f76428bb881d0f72aa60aada97ab02167a5..3bed89b8ce0947c86593e2f1680ef6082f321d84 100644
--- a/tests/test_distance_map.py
+++ b/tests/test_distance_map.py
@@ -43,9 +43,8 @@ def test_walker():
 
     # reset to set agents from agents_static
     env.reset(False, False)
-    obs_builder: TreeObsForRailEnv = env.obs_builder
 
-    print(obs_builder.distance_map[(0, *[0, 1], 1)])
-    assert obs_builder.distance_map[(0, *[0, 1], 1)] == 3
-    print(obs_builder.distance_map[(0, *[0, 2], 3)])
-    assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2
+    print(env.distance_map.get()[(0, *[0, 1], 1)])
+    assert env.distance_map.get()[(0, *[0, 1], 1)] == 3
+    print(env.distance_map.get()[(0, *[0, 2], 3)])
+    assert env.distance_map.get()[(0, *[0, 2], 1)] == 2
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index c96e8db00fe721f42667aed4833d034a47f19156..46000de429092d3fe4effe87382d1e12bc2c3401 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -4,6 +4,7 @@
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -51,7 +52,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
         shortest_distance = np.inf
 
         for exit_direction in range(4):
-            neighbour = obs_builder._new_position(agent.position, exit_direction)
+            neighbour = get_new_position(agent.position, exit_direction)
 
             if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width:
                 desired_movement_from_new_cell = (exit_direction + 2) % 4
@@ -62,7 +63,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
                     is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation),
                                                                    desired_movement_from_new_cell)
                     if is_valid:
-                        distance_to_target = obs_builder.distance_map[
+                        distance_to_target = obs_builder.env.distance_map.get()[
                             (agent.handle, *agent.position, exit_direction)]
                         print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position,
                                                                                       agent.direction,
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 09f7e5e67a15c55b5070ac8679e43ecc9a14b9da..c31494673e63a17dc07eb6d89eeb581c640b1e13 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -137,7 +137,7 @@ def test_shortest_path_predictor(rendering=False):
         input("Continue?")
 
     # compute the observations and predictions
-    distance_map = env.obs_builder.distance_map
+    distance_map = env.distance_map.get()
     assert distance_map[0, agent.position[0], agent.position[
         1], agent.direction] == 5.0, "found {} instead of {}".format(
         distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 8b0480c887a53ade155c28aa6199db3d32f19603..4c925789e6560077d637e2a594c736df8850d00a 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -129,7 +129,7 @@ def tests_rail_from_file():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
     env.save(file_name)
-    dist_map_shape = np.shape(env.obs_builder.distance_map)
+    dist_map_shape = np.shape(env.distance_map.get())
     # initialize agents_static
     rails_initial = env.rail.grid
     agents_initial = env.agents
@@ -148,9 +148,8 @@ def tests_rail_from_file():
     assert agents_initial == agents_loaded
 
     # Check that distance map was not recomputed
-    assert env.obs_builder.distance_map_computed is False
-    assert np.shape(env.obs_builder.distance_map) == dist_map_shape
-    assert env.obs_builder.distance_map is not None
+    assert np.shape(env.distance_map.get()) == dist_map_shape
+    assert env.distance_map.get() is not None
 
     # Test to save and load file without distance map.
 
@@ -222,6 +221,5 @@ def tests_rail_from_file():
     assert agents_initial_2 == agents_loaded_4
 
     # Check that distance map was generated with correct shape
-    assert env4.obs_builder.distance_map_computed is True
-    assert env4.obs_builder.distance_map is not None
-    assert np.shape(env4.obs_builder.distance_map) == dist_map_shape
+    assert env4.distance_map.get() is not None
+    assert np.shape(env4.distance_map.get()) == dist_map_shape