diff --git a/changelog.md b/changelog.md
index 91e1ab6181ad1c7cd5b28e9ee06711ecd837324a..cad8ee000506dea3abf05dedb27a43aaaf0bf8b7 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,9 +1,17 @@
 Changelog
 ==========
 
-Changes since Flatland 0.3
+Changes since Flatland 2.0.0
 --------------------------
 
+### Changes in `Environment`
+- moving of member variable `distance_map_computed` to new class `DistanceMap`
+
+### Changes in rail generator and `RailEnv`
+- renaming of `distance_maps` into `distance_map`
+
+Changes since Flatland 1.0.0
+--------------------------
 ### Changes in stock predictors
 The stock `ShortestPathPredictorForRailEnv` now respects the different agent speeds and updates their prediction accordingly.
 
@@ -25,12 +33,12 @@ The stock `ShortestPathPredictorForRailEnv` now respects the different agent spe
 ### Changes in level generation
 
 
-- Separation of `schedule_generator` from `rail_generator`: 
+- Separation of `schedule_generator` from `rail_generator`:
   - Renaming of `flatland/envs/generators.py` to `flatland/envs/rail_generators.py`
   - `rail_generator` now only returns the grid and optionally hints (a python dictionary); the hints are currently use for distance_map and communication of start and goal position in complex rail generator.
   - `schedule_generator` takes a `GridTransitionMap` and the number of agents and optionally the `agents_hints` field of the hints dictionary.
-  - Inrodcution of types hints: 
-``` 
+  - Introduction of type hints:
+```
 RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
 AgentPosition = Tuple[int, int]
diff --git a/docs/intro_observationbuilder.rst b/docs/intro_observationbuilder.rst
index c94a2ce18e80c9b6fbdf34b824dfc2062374c82d..4386f9e07df07693d2a4e72fa1920e4275972c67 100644
--- a/docs/intro_observationbuilder.rst
+++ b/docs/intro_observationbuilder.rst
@@ -63,7 +63,7 @@ cells and orientations to the target of each agent, i.e. a distance map for each
 
 In this example we exploit these distance maps by implementing an observation builder that shows the current shortest path for each agent as a one-hot observation vector of length 3, whose components represent the possible directions an agent can take (LEFT, FORWARD, RIGHT). All values of the observation vector are set to :code:`0` except for the shortest direction where it is set to :code:`1`.
 
-Using this observation with highly engineered features indicating the agent's shortest path, an agent can then learn to take the corresponding action at each time-step; or we could even hardcode the optimal policy. 
+Using this observation with highly engineered features indicating the agent's shortest path, an agent can then learn to take the corresponding action at each time-step; or we could even hardcode the optimal policy.
 Note that this simple strategy fails when multiple agents are present, as each agent would only attempt its greedy solution, which is not usually `Pareto-optimal <https://en.wikipedia.org/wiki/Pareto_efficiency>`_ in this context.
 
 .. _TreeObsForRailEnv: https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
@@ -71,7 +71,7 @@ Note that this simple strategy fails when multiple agents are present, as each a
 .. code-block:: python
 
     from flatland.envs.observations import TreeObsForRailEnv
-    
+
     class SingleAgentNavigationObs(TreeObsForRailEnv):
         """
         We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
@@ -84,7 +84,7 @@ Note that this simple strategy fails when multiple agents are present, as each a
         """
         def __init__(self):
             super().__init__(max_depth=0)
-            # We set max_depth=0 in because we only need to look at the current 
+            # We set max_depth=0 because we only need to look at the current
             # position of the agent to decide what direction is shortest.
             self.observation_space = [3]
 
@@ -110,7 +110,7 @@ Note that this simple strategy fails when multiple agents are present, as each a
                 for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                     if possible_transitions[direction]:
                         new_position = self._new_position(agent.position, direction)
-                        min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
+                        min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                     else:
                         min_distances.append(np.inf)
 
@@ -175,28 +175,28 @@ In contrast to the previous examples we also implement the :code:`def get_many(s
 .. _example: https://gitlab.aicrowd.com/flatland/flatland/blob/master/examples/custom_observation_example.py#L110
 
 .. code-block:: python
-    
+
     class ObservePredictions(TreeObsForRailEnv):
         """
         We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation.
-    
+
         We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
         the minimum distances from each grid node to each agent's target.
-    
+
         This is necessary so that we can pass the distance map to the ShortestPathPredictor
-    
+
         Here we also want to highlight how you can visualize your observation
         """
-    
+
         def __init__(self, predictor):
             super().__init__(max_depth=0)
             self.observation_space = [10]
             self.predictor = predictor
-    
+
         def reset(self):
             # Recompute the distance map, if the environment has changed.
             super().reset()
-    
+
         def get_many(self, handles=None):
             '''
             Because we do not want to call the predictor seperately for every agent we implement the get_many function
@@ -204,9 +204,9 @@ In contrast to the previous examples we also implement the :code:`def get_many(s
             :param handles:
             :return:
             '''
-    
-            self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
-    
+
+            self.predictions = self.predictor.get()
+
             self.predicted_pos = {}
             for t in range(len(self.predictions[0])):
                 pos_list = []
@@ -215,47 +215,47 @@ In contrast to the previous examples we also implement the :code:`def get_many(s
                 # We transform (x,y) coodrinates to a single integer number for simpler comparison
                 self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
             observations = {}
-    
+
             # Collect all the different observation for all the agents
             for h in handles:
                 observations[h] = self.get(h)
             return observations
-    
+
         def get(self, handle):
             '''
             Lets write a simple observation which just indicates whether or not the own predicted path
             overlaps with other predicted paths at any time. This is useless for the task of navigation but might
             help when looking for conflicts. A more complex implementation can be found in the TreeObsForRailEnv class
-    
+
             Each agent recieves an observation of length 10, where each element represents a prediction step and its value
             is:
              - 0 if no overlap is happening
              - 1 where n i the number of other paths crossing the predicted cell
-    
+
             :param handle: handeled as an index of an agent
             :return: Observation of handle
             '''
-    
+
             observation = np.zeros(10)
-    
+
             # We are going to track what cells where considered while building the obervation and make them accesible
             # For rendering
-    
+
             visited = set()
             for _idx in range(10):
                 # Check if any of the other prediction overlap with agents own predictions
                 x_coord = self.predictions[handle][_idx][1]
                 y_coord = self.predictions[handle][_idx][2]
-    
+
                 # We add every observed cell to the observation rendering
                 visited.add((x_coord, y_coord))
                 if self.predicted_pos[_idx][handle] in np.delete(self.predicted_pos[_idx], handle, 0):
                     # We detect if another agent is predicting to pass through the same cell at the same predicted time
                     observation[handle] = 1
-    
+
             # This variable will be access by the renderer to visualize the observation
             self.env.dev_obs_dict[handle] = visited
-    
+
             return observation
 
 We can then use this new observation builder and the renderer to visualize the observation of each agent.
@@ -265,23 +265,23 @@ We can then use this new observation builder and the renderer to visualize the o
 
     # Initiate the Predictor
     CustomPredictor = ShortestPathPredictorForRailEnv(10)
-    
+
     # Pass the Predictor to the observation builder
     CustomObsBuilder = ObservePredictions(CustomPredictor)
-    
+
     # Initiate Environment
     env = RailEnv(width=10,
                   height=10,
                   rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
                   number_of_agents=3,
                   obs_builder_object=CustomObsBuilder)
-    
+
     obs = env.reset()
     env_renderer = RenderTool(env, gl="PILSVG")
-    
+
     # We render the initial step and show the obsered cells as colored boxes
     env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
-    
+
     action_dict = {}
     for step in range(100):
         for a in range(env.get_num_agents()):
@@ -321,7 +321,7 @@ These two objects can be used for example to detect switches that are usable by
 
     cell_transitions = self.env.rail.get_transitions(*position, direction)
     transition_bit = bin(self.env.rail.get_full_transitions(*position))
-    
+
     total_transitions = transition_bit.count("1")
     num_transitions = np.count_nonzero(cell_transitions)
 
@@ -357,7 +357,7 @@ Beyond the basic agent information we can also access more details about the age
 
 Similar to the speed data you can also access individual data about the malfunctions of an agent. All data is available through :code:`agent.malfunction_data` with:
 
-- Indication how long the agent is still malfunctioning :code:`'malfunction'` by an integer counting down at each time step. 0 means the agent is ok and can move. 
+- Indication of how long the agent is still malfunctioning :code:`'malfunction'` by an integer counting down at each time step. 0 means the agent is ok and can move.
 - Possion rate at which malfunctions happen for this agent :code:`'malfunction_rate'`
 - Number of steps untill next malfunction will occur :code:`'next_malfunction'`
 - Number of malfunctions an agent have occured for this agent so far :code:`nr_malfunctions'`
diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
index b16a4e3e5a6378418b120f691248097fcdd82cb8..317372da390693dd51b53d411c4d5615582183b0 100644
--- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py
+++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
@@ -5,6 +5,7 @@ import time
 
 import numpy as np
 
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -49,8 +50,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self._new_position(agent.position, direction)
-                    min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
+                    new_position = get_new_position(agent.position, direction)
+                    min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
 
diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index 00a3d6252ce6d93754c3bfd9c629ad78d45f348d..b6027184c631c79c15f967e8867b8b5f3c8ba0f6 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -46,7 +46,7 @@ class ObservePredictions(TreeObsForRailEnv):
         :return:
         '''
 
-        self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
+        self.predictions = self.predictor.get()
 
         self.predicted_pos = {}
         for t in range(len(self.predictions[0])):
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 50ea74b84ac9851e88e48bcd32b914e69bc7dd34..8aef94c23311fc693c229924953164afb5fec8ab 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -3,6 +3,7 @@ import time
 
 import numpy as np
 
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -47,8 +48,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self._new_position(agent.position, direction)
-                    min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
+                    new_position = get_new_position(agent.position, direction)
+                    min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
 
diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index 86deebe54bd0248eead6716cb5a599cbb34f4d0c..b4a2c287882de6b6cf0319bfb6b97b3d4e2ba6e5 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -28,7 +28,7 @@ class PredictionBuilder:
         """
         pass
 
-    def get(self, custom_args=None, handle=0):
+    def get(self, handle=0):
         """
         Called whenever get_many in the observation build is called.
 
diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..940193198a18a63f669d6a50a04cbd8f67740a32
--- /dev/null
+++ b/flatland/envs/distance_map.py
@@ -0,0 +1,150 @@
+from collections import deque
+from typing import List, Optional
+
+import numpy as np
+
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_utils import EnvAgent
+
+
+class DistanceMap:
+    """Distance from every rail cell and orientation to each agent's target, computed lazily in get()."""
+
+    def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int):
+        self.env_height = env_height
+        self.env_width = env_width
+        # Computed maps are indexed as [agent, row, column, orientation]; None until set() or the first computation
+        self.distance_map = None
+        # Agents of the latest computation, compared against in get() to detect changed targets
+        self.agents_previous_computation = None
+        self.reset_was_called = False
+        self.agents: List[EnvAgent] = agents
+        self.rail: Optional[GridTransitionMap] = None
+
+    def set(self, distance_map: np.ndarray):
+        """
+        Set the distance map
+        """
+        self.distance_map = distance_map
+
+    def get(self) -> np.ndarray:
+        """
+        Get the distance map, computing it first if necessary.
+
+        After reset() was called, the map is recomputed if the number of agents or any agent's target
+        differs from the previous computation; a map provided through set() before any computation is kept.
+        """
+        if self.reset_was_called:
+            self.reset_was_called = False
+
+            nb_agents = len(self.agents)
+            compute_distance_map = True
+            if self.agents_previous_computation is not None and nb_agents == len(self.agents_previous_computation):
+                # Same number of agents as before: recompute only if at least one target changed
+                compute_distance_map = any(agent.target != previous.target for agent, previous
+                                           in zip(self.agents, self.agents_previous_computation))
+            # Don't compute the distance map if it was loaded
+            if self.agents_previous_computation is None and self.distance_map is not None:
+                compute_distance_map = False
+
+            if compute_distance_map:
+                self._compute(self.agents, self.rail)
+
+        elif self.distance_map is None:
+            self._compute(self.agents, self.rail)
+
+        return self.distance_map
+
+    def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        """
+        Store the new agents and rail; the next call to get() decides whether to recompute the map.
+        """
+        self.reset_was_called = True
+        self.agents = agents
+        self.rail = rail
+
+    def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        """Compute the distance map from scratch, running one walk per agent target over `rail`."""
+        self.agents_previous_computation = self.agents
+        self.distance_map = np.inf * np.ones(shape=(len(agents),
+                                                    self.env_height,
+                                                    self.env_width,
+                                                    4))
+        for i, agent in enumerate(agents):
+            self._distance_map_walker(rail, agent.target, i)
+
+    def _distance_map_walker(self, rail: GridTransitionMap, position, target_nr: int):
+        """
+        Utility function to compute distance maps from each cell in the rail network (and each possible
+        orientation within it) to each agent's target cell.
+        """
+        # Returns max distance to target, from the farthest away node, while filling in distance_map
+        self.distance_map[target_nr, position[0], position[1], :] = 0
+
+        # Fill in the (up to) 4 neighboring nodes
+        # direction is the direction of movement, meaning that at least a possible orientation of an agent
+        # in cell (row,col) allows a movement in direction `direction'
+        nodes_queue = deque(self._get_and_update_neighbors(rail, position, target_nr, 0, enforce_target_direction=-1))
+
+        # BFS from target `position' to all the reachable nodes in the grid
+        # Stop the search if the target position is re-visited, in any direction
+        visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
+                   (position[0], position[1], 3)}
+
+        max_distance = 0
+
+        while nodes_queue:
+            node = nodes_queue.popleft()
+            node_id = (node[0], node[1], node[2])
+
+            if node_id not in visited:
+                visited.add(node_id)
+
+                # From the list of possible neighbors that have at least a path to the current node, only keep those
+                # whose new orientation in the current cell would allow a transition to direction node[2]
+                valid_neighbors = self._get_and_update_neighbors(rail, (node[0], node[1]), target_nr, node[3], node[2])
+
+                nodes_queue.extend(valid_neighbors)
+
+                if len(valid_neighbors) > 0:
+                    max_distance = max(max_distance, node[3] + 1)
+
+        return max_distance
+
+    def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance,
+                                  enforce_target_direction=-1):
+        """
+        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
+        minimum distances from each target cell.
+        """
+        neighbors = []
+
+        possible_directions = [0, 1, 2, 3]
+        if enforce_target_direction >= 0:
+            # The agent must land into the current cell with orientation `enforce_target_direction'.
+            # This is only possible if the agent has arrived from the cell in the opposite direction!
+            possible_directions = [(enforce_target_direction + 2) % 4]
+
+        for neigh_direction in possible_directions:
+            new_cell = get_new_position(position, neigh_direction)
+
+            if 0 <= new_cell[0] < self.env_height and 0 <= new_cell[1] < self.env_width:
+                desired_movement_from_new_cell = (neigh_direction + 2) % 4
+
+                # Check all possible transitions in new_cell
+                for agent_orientation in range(4):
+                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
+                    is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
+                                                   desired_movement_from_new_cell)
+
+                    if is_valid:
+                        # TODO: check that it works with deadends! -- still bugged!
+                        # (idea: when the next cell is a dead end, use the opposite movement,
+                        #  i.e. (desired_movement_from_new_cell + 2) % 4)
+                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
+                                           current_distance + 1)
+                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
+                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
+
+        return neighbors
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index ca9f4633a0b3d7a445ba54f753527d3a2bcc0cd3..58e56d7e68b8352b58aacdd025c67613221a81a8 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -2,12 +2,12 @@
 Collection of environment-specific ObservationBuilder.
 """
 import pprint
-from collections import deque
 
 import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
 
 
@@ -36,137 +36,12 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
         self.predictor = predictor
-        self.agents_previous_reset = None
+        self.location_has_target = None
         self.tree_explored_actions = [1, 2, 3, 0]
         self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
-        self.distance_map = None
-        self.distance_map_computed = False
 
     def reset(self):
-        agents = self.env.agents
-        nb_agents = len(agents)
-        compute_distance_map = True
-        if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
-            compute_distance_map = False
-            for i in range(nb_agents):
-                if agents[i].target != self.agents_previous_reset[i].target:
-                    compute_distance_map = True
-        # Don't compute the distance map if it was loaded
-        if self.agents_previous_reset is None and self.distance_map is not None:
-            self.location_has_target = {tuple(agent.target): 1 for agent in agents}
-            compute_distance_map = False
-
-        if compute_distance_map:
-            self._compute_distance_map()
-
-        self.agents_previous_reset = agents
-
-    def _compute_distance_map(self):
-        agents = self.env.agents
-        # For testing only --> To assert if a distance map need to be recomputed.
-        self.distance_map_computed = True
-        nb_agents = len(agents)
-        self.distance_map = np.inf * np.ones(shape=(nb_agents,
-                                                    self.env.height,
-                                                    self.env.width,
-                                                    4))
-        self.max_dist = np.zeros(nb_agents)
-        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
-        # Update local lookup table for all agents' target locations
-        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
-
-    def _distance_map_walker(self, position, target_nr):
-        """
-        Utility function to compute distance maps from each cell in the rail network (and each possible
-        orientation within it) to each agent's target cell.
-        """
-        # Returns max distance to target, from the farthest away node, while filling in distance_map
-        self.distance_map[target_nr, position[0], position[1], :] = 0
-
-        # Fill in the (up to) 4 neighboring nodes
-        # direction is the direction of movement, meaning that at least a possible orientation of an agent
-        # in cell (row,col) allows a movement in direction `direction`
-        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
-
-        # BFS from target `position` to all the reachable nodes in the grid
-        # Stop the search if the target position is re-visited, in any direction
-        visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
-                   (position[0], position[1], 3)}
-
-        max_distance = 0
-
-        while nodes_queue:
-            node = nodes_queue.popleft()
-
-            node_id = (node[0], node[1], node[2])
-
-            if node_id not in visited:
-                visited.add(node_id)
-
-                # From the list of possible neighbors that have at least a path to the current node, only keep those
-                # whose new orientation in the current cell would allow a transition to direction node[2]
-                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
-
-                for n in valid_neighbors:
-                    nodes_queue.append(n)
-
-                if len(valid_neighbors) > 0:
-                    max_distance = max(max_distance, node[3] + 1)
-
-        return max_distance
-
-    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
-        """
-        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
-        minimum distances from each target cell.
-        """
-        neighbors = []
-
-        possible_directions = [0, 1, 2, 3]
-        if enforce_target_direction >= 0:
-            # The agent must land into the current cell with orientation `enforce_target_direction`.
-            # This is only possible if the agent has arrived from the cell in the opposite direction!
-            possible_directions = [(enforce_target_direction + 2) % 4]
-
-        for neigh_direction in possible_directions:
-            new_cell = self._new_position(position, neigh_direction)
-
-            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
-
-                desired_movement_from_new_cell = (neigh_direction + 2) % 4
-
-                # Check all possible transitions in new_cell
-                for agent_orientation in range(4):
-                    # Is a transition along movement `desired_movement_from_new_cell` to the current cell possible?
-                    is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
-                                                            desired_movement_from_new_cell)
-
-                    if is_valid:
-                        """
-                        # TODO: check that it works with deadends! -- still bugged!
-                        movement = desired_movement_from_new_cell
-                        if isNextCellDeadEnd:
-                            movement = (desired_movement_from_new_cell+2) % 4
-                        """
-                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
-                                           current_distance + 1)
-                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
-                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
-
-        return neighbors
-
-    def _new_position(self, position, movement):
-        """
-        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
-        """
-        if movement == Grid4TransitionsEnum.NORTH:
-            return (position[0] - 1, position[1])
-        elif movement == Grid4TransitionsEnum.EAST:
-            return (position[0], position[1] + 1)
-        elif movement == Grid4TransitionsEnum.SOUTH:
-            return (position[0] + 1, position[1])
-        elif movement == Grid4TransitionsEnum.WEST:
-            return (position[0], position[1] - 1)
+        self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
 
     def get_many(self, handles=None):
         """
@@ -180,7 +55,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             self.max_prediction_depth = 0
             self.predicted_pos = {}
             self.predicted_dir = {}
-            self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
+            self.predictions = self.predictor.get()
             if self.predictions:
 
                 for t in range(len(self.predictions[0])):
@@ -276,7 +151,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Root node - current position
         # Here information about the agent itself is stored
-        observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0,
+        observation = [0, 0, 0, 0, 0, 0, self.env.distance_map.get()[(handle, *agent.position, agent.direction)], 0, 0,
                        agent.malfunction_data['malfunction'], agent.speed_data['speed']]
 
         visited = set()
@@ -291,7 +166,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
-                new_cell = self._new_position(agent.position, branch_direction)
+                new_cell = get_new_position(agent.position, branch_direction)
                 branch_observation, branch_visited = \
                     self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                 observation = observation + branch_observation
@@ -464,7 +339,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     exploring = True
                     # convert one-hot encoding to 0,1,2,3
                     direction = np.argmax(cell_transitions)
-                    position = self._new_position(position, direction)
+                    position = get_new_position(position, direction)
                     num_steps += 1
                     tot_dist += 1
             elif num_transitions > 0:
@@ -506,7 +381,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                            potential_conflict,
                            unusable_switch,
                            np.inf,
-                           self.distance_map[handle, position[0], position[1], direction],
+                           self.env.distance_map.get()[handle, position[0], position[1], direction],
                            other_agent_same_direction,
                            other_agent_opposite_direction,
                            malfunctioning_agent,
@@ -520,7 +395,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                            potential_conflict,
                            unusable_switch,
                            tot_dist,
-                           self.distance_map[handle, position[0], position[1], direction],
+                           self.env.distance_map.get()[handle, position[0], position[1], direction],
                            other_agent_same_direction,
                            other_agent_opposite_direction,
                            malfunctioning_agent,
@@ -537,7 +412,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                  (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
-                new_cell = self._new_position(position, (branch_direction + 2) % 4)
+                new_cell = get_new_position(position, (branch_direction + 2) % 4)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           (branch_direction + 2) % 4,
@@ -547,7 +422,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if len(branch_visited) != 0:
                     visited = visited.union(branch_visited)
             elif last_is_switch and possible_transitions[branch_direction]:
-                new_cell = self._new_position(position, branch_direction)
+                new_cell = get_new_position(position, branch_direction)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           branch_direction,
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 4eac97cbc25c18dd372618fdf16b80ec09ccad1a..b5e6bbfd4da3865954a90544d2cfe7a453c5d780 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -17,7 +17,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, custom_args=None, handle=None):
+    def get(self, handle=None):
         """
         Called whenever get_many in the observation build is called.
 
@@ -90,7 +90,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         # Initialize with depth 20
         self.max_depth = max_depth
 
-    def get(self, custom_args=None, handle=None):
+    def get(self, handle=None):
         """
         Called whenever get_many in the observation build is called.
         Requires distance_map to extract the shortest path.
@@ -116,8 +116,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         agents = self.env.agents
-        if handle:
+        if handle is not None:
             agents = [self.env.agents[handle]]
-        assert custom_args is not None
-        distance_map = custom_args.get('distance_map')
+        distance_map = self.env.distance_map
         assert distance_map is not None
 
         prediction_dict = {}
@@ -153,7 +152,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                     for direction in range(4):
                         if cell_transitions[direction] == 1:
                             neighbour_cell = get_new_position(agent.position, direction)
-                            target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
+                            target_dist = distance_map.get()[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
                             if target_dist < min_dist or no_dist_found:
                                 min_dist = target_dist
                                 new_direction = direction
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index dc715abd84420c950be63181925bf858a5648f9c..249e3d18122ee9c47b1634dd9bd14e2cbf010e9c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -14,6 +14,7 @@ from flatland.core.env import Environment
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
+from flatland.envs.distance_map import DistanceMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
 from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
@@ -143,6 +144,7 @@ class RailEnv(Environment):
         file_name: you can load a pickle file. from previously saved *.pkl file
 
         """
+        super().__init__()
 
         self.rail_generator: RailGenerator = rail_generator
         self.schedule_generator: ScheduleGenerator = schedule_generator
@@ -169,6 +171,7 @@ class RailEnv(Environment):
         self.agents: List[EnvAgent] = [None] * number_of_agents  # live agents
         self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents  # static agent information
         self.num_resets = 0
+        self.distance_map = DistanceMap(self.agents, self.height, self.width)
 
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
@@ -232,8 +235,8 @@ class RailEnv(Environment):
         # TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
-        if optionals and 'distance_maps' in optionals:
-            self.obs_builder.distance_map = optionals['distance_maps']
+        if optionals and 'distance_map' in optionals:
+            self.distance_map.set(optionals['distance_map'])
 
         if regen_rail or self.rail is None:
             self.rail = rail
@@ -273,6 +276,7 @@ class RailEnv(Environment):
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
         self.observation_space = self.obs_builder.observation_space  # <-- change on reset?
+        self.distance_map.reset(self.agents, self.rail)
 
         # Return the new observation vectors for each agent
         return self._get_observations()
@@ -573,8 +577,8 @@ class RailEnv(Environment):
         # agents are always reset as not moving
         self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
         self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
-        if hasattr(self.obs_builder, 'distance_map') and "distance_maps" in data.keys():
-            self.obs_builder.distance_map = data["distance_maps"]
+        if "distance_map" in data.keys():
+            self.distance_map.set(data["distance_map"])
         # setup with loaded data
         self.height, self.width = self.rail.grid.shape
         self.rail.height = self.height
@@ -588,25 +592,19 @@ class RailEnv(Environment):
         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
         msgpack.packb(agent_static_data, use_bin_type=True)
-        if hasattr(self.obs_builder, 'distance_map'):
-            distance_map_data = self.obs_builder.distance_map
-            msgpack.packb(distance_map_data, use_bin_type=True)
-            msg_data = {
-                "grid": grid_data,
-                "agents_static": agent_static_data,
-                "agents": agent_data,
-                "distance_maps": distance_map_data}
-        else:
-            msg_data = {
-                "grid": grid_data,
-                "agents_static": agent_static_data,
-                "agents": agent_data}
+        distance_map_data = self.distance_map.get()
+        msgpack.packb(distance_map_data, use_bin_type=True)
+        msg_data = {
+            "grid": grid_data,
+            "agents_static": agent_static_data,
+            "agents": agent_data,
+            "distance_map": distance_map_data}
 
         return msgpack.packb(msg_data, use_bin_type=True)
 
     def save(self, filename):
-        if hasattr(self.obs_builder, 'distance_map'):
-            if len(self.obs_builder.distance_map) > 0:
+        if self.distance_map.get() is not None:
+            if len(self.distance_map.get()) > 0:
                 with open(filename, "wb") as file_out:
                     file_out.write(self.get_full_state_dist_msg())
             else:
@@ -617,14 +615,9 @@ class RailEnv(Environment):
                 file_out.write(self.get_full_state_msg())
 
     def load(self, filename):
-        if hasattr(self.obs_builder, 'distance_map'):
-            with open(filename, "rb") as file_in:
-                load_data = file_in.read()
-                self.set_full_state_dist_msg(load_data)
-        else:
-            with open(filename, "rb") as file_in:
-                load_data = file_in.read()
-                self.set_full_state_msg(load_data)
+        with open(filename, "rb") as file_in:
+            load_data = file_in.read()
+            self.set_full_state_dist_msg(load_data)
 
     def load_pkl(self, pkl_data):
         self.set_full_state_msg(pkl_data)
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index a6d1b9c70f1f924fb63d3636a00d756a2d22c501..64ab775a33ffba977854e65f85b70bacf15e604c 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -228,10 +228,10 @@ def rail_from_file(filename) -> RailGenerator:
         grid = np.array(data[b"grid"])
         rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
         rail.grid = grid
-        if b"distance_maps" in data.keys():
-            distance_maps = data[b"distance_maps"]
-            if len(distance_maps) > 0:
-                return rail, {'distance_maps': distance_maps}
+        if b"distance_map" in data.keys():
+            distance_map = data[b"distance_map"]
+            if len(distance_map) > 0:
+                return rail, {'distance_map': distance_map}
         return [rail, None]
 
     return generator
diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py
index e5e89f76428bb881d0f72aa60aada97ab02167a5..3bed89b8ce0947c86593e2f1680ef6082f321d84 100644
--- a/tests/test_distance_map.py
+++ b/tests/test_distance_map.py
@@ -43,9 +43,8 @@ def test_walker():
 
     # reset to set agents from agents_static
     env.reset(False, False)
-    obs_builder: TreeObsForRailEnv = env.obs_builder
 
-    print(obs_builder.distance_map[(0, *[0, 1], 1)])
-    assert obs_builder.distance_map[(0, *[0, 1], 1)] == 3
-    print(obs_builder.distance_map[(0, *[0, 2], 3)])
-    assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2
+    print(env.distance_map.get()[(0, *[0, 1], 1)])
+    assert env.distance_map.get()[(0, *[0, 1], 1)] == 3
+    print(env.distance_map.get()[(0, *[0, 2], 1)])
+    assert env.distance_map.get()[(0, *[0, 2], 1)] == 2
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 9bf7df53570aa672739d409d598447cd27d5759b..d2663916a17a70597d10e489da7aead4f8932dc4 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -4,6 +4,7 @@
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -51,7 +52,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
         shortest_distance = np.inf
 
         for exit_direction in range(4):
-            neighbour = obs_builder._new_position(agent.position, exit_direction)
+            neighbour = get_new_position(agent.position, exit_direction)
 
             if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width:
                 desired_movement_from_new_cell = (exit_direction + 2) % 4
@@ -62,7 +63,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
                     is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation),
                                                                    desired_movement_from_new_cell)
                     if is_valid:
-                        distance_to_target = obs_builder.distance_map[
+                        distance_to_target = obs_builder.env.distance_map.get()[
                             (agent.handle, *agent.position, exit_direction)]
                         print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position,
                                                                                       agent.direction,
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 09f7e5e67a15c55b5070ac8679e43ecc9a14b9da..c31494673e63a17dc07eb6d89eeb581c640b1e13 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -137,7 +137,7 @@ def test_shortest_path_predictor(rendering=False):
         input("Continue?")
 
     # compute the observations and predictions
-    distance_map = env.obs_builder.distance_map
+    distance_map = env.distance_map.get()
     assert distance_map[0, agent.position[0], agent.position[
         1], agent.direction] == 5.0, "found {} instead of {}".format(
         distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index a63e97229457d606027c88e5189d0cb680f25c9b..81b61381ed67d927cac44f4c9733d8a040903ef5 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -1,5 +1,6 @@
 import numpy as np
 
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -40,8 +41,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self._new_position(agent.position, direction)
-                    min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
+                    new_position = get_new_position(agent.position, direction)
+                    min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
 
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 8b0480c887a53ade155c28aa6199db3d32f19603..4c925789e6560077d637e2a594c736df8850d00a 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -129,7 +129,7 @@ def tests_rail_from_file():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
     env.save(file_name)
-    dist_map_shape = np.shape(env.obs_builder.distance_map)
+    dist_map_shape = np.shape(env.distance_map.get())
     # initialize agents_static
     rails_initial = env.rail.grid
     agents_initial = env.agents
@@ -148,9 +148,8 @@ def tests_rail_from_file():
     assert agents_initial == agents_loaded
 
     # Check that distance map was not recomputed
-    assert env.obs_builder.distance_map_computed is False
-    assert np.shape(env.obs_builder.distance_map) == dist_map_shape
-    assert env.obs_builder.distance_map is not None
+    assert np.shape(env.distance_map.get()) == dist_map_shape
+    assert env.distance_map.get() is not None
 
     # Test to save and load file without distance map.
 
@@ -222,6 +221,5 @@ def tests_rail_from_file():
     assert agents_initial_2 == agents_loaded_4
 
     # Check that distance map was generated with correct shape
-    assert env4.obs_builder.distance_map_computed is True
-    assert env4.obs_builder.distance_map is not None
-    assert np.shape(env4.obs_builder.distance_map) == dist_map_shape
+    assert env4.distance_map.get() is not None
+    assert np.shape(env4.distance_map.get()) == dist_map_shape