diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index ec579c1dbd080dc53504421e6a58673e205f6725..9b33a55c4f5f57a8f92135d95f69696ee3df0be7 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -227,8 +227,11 @@ def rail_from_file(filename):
         agents_position = [a.position for a in agents_static]
         agents_direction = [a.direction for a in agents_static]
         agents_target = [a.target for a in agents_static]
-        return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
-
+        if len(data) > 3:
+            distance_maps = data[b"distance_maps"]
+            return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps
+        else:
+            return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
     return generator
 
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 3929b9e191615ba91fb0e13df6d8ae040b401a5a..2f8ffb80858809a7e601a250265e7336786003b2 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -22,8 +22,6 @@ class TreeObsForRailEnv(ObservationBuilder):
     For details about the features in the tree observation see the get() function.
     """
 
-    observation_dim = 9
-
     def __init__(self, max_depth, predictor=None):
         super().__init__()
         self.max_depth = max_depth
@@ -34,6 +32,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         for i in range(self.max_depth + 1):
             size += pow4
             pow4 *= 4
+        self.observation_dim = 9
         self.observation_space = [size * self.observation_dim]
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
@@ -41,22 +40,28 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.agents_previous_reset = None
         self.tree_explored_actions = [1, 2, 3, 0]
         self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
+        self.distance_map = None
 
     def reset(self):
         agents = self.env.agents
         nb_agents = len(agents)
-
         compute_distance_map = True
         if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
             compute_distance_map = False
             for i in range(nb_agents):
                 if agents[i].target != self.agents_previous_reset[i].target:
                     compute_distance_map = True
-        self.agents_previous_reset = agents
+
+        # Don't compute the distance map if it was loaded
+        if self.agents_previous_reset is None and self.distance_map is not None:
+            self.location_has_target = {tuple(agent.target): 1 for agent in agents}
+            compute_distance_map = False
 
         if compute_distance_map:
             self._compute_distance_map()
 
+        self.agents_previous_reset = agents
+
     def _compute_distance_map(self):
         agents = self.env.agents
         nb_agents = len(agents)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 996301a84f06f185ba9ed605ea1145f404c8b16e..f082f0801ba782bab8d4106856739fc798a3efb8 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -6,6 +6,7 @@ Definition of the RailEnv environment.
 from enum import IntEnum
 
 import msgpack
+import msgpack_numpy as m
 import numpy as np
 
 from flatland.core.env import Environment
@@ -14,6 +15,8 @@ from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
 from flatland.envs.generators import random_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 
+m.patch()
+
 
 class RailEnvActions(IntEnum):
     DO_NOTHING = 0  # implies change of direction in a dead-end!
@@ -170,6 +173,10 @@ class RailEnv(Environment):
         """
         tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
+        # Check if generator provided a distance map TODO: Make this check safer!
+        if len(tRailAgents) > 5:
+            self.obs_builder.distance_map = tRailAgents[-1]
+
         if regen_rail or self.rail is None:
             self.rail = tRailAgents[0]
             self.height, self.width = self.rail.grid.shape
@@ -418,14 +425,61 @@ class RailEnv(Environment):
         self.rail.width = self.width
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
 
+    def set_full_state_dist_msg(self, msg_data):
+        data = msgpack.unpackb(msg_data, use_list=False)
+        self.rail.grid = np.array(data[b"grid"])
+        # agents are always reset as not moving
+        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
+        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
+        if hasattr(self.obs_builder, 'distance_map'):
+            self.obs_builder.distance_map = data[b"distance_maps"]
+        # setup with loaded data
+        self.height, self.width = self.rail.grid.shape
+        self.rail.height = self.height
+        self.rail.width = self.width
+        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
+
+    def get_full_state_dist_msg(self):
+        grid_data = self.rail.grid.tolist()
+        agent_static_data = [agent.to_list() for agent in self.agents_static]
+        agent_data = [agent.to_list() for agent in self.agents]
+
+        msgpack.packb(grid_data)
+        msgpack.packb(agent_data)
+        msgpack.packb(agent_static_data)
+        if hasattr(self.obs_builder, 'distance_map'):
+            distance_map_data = self.obs_builder.distance_map
+            msgpack.packb(distance_map_data)
+            msg_data = {
+                "grid": grid_data,
+                "agents_static": agent_static_data,
+                "agents": agent_data,
+                "distance_maps": distance_map_data}
+        else:
+            msg_data = {
+                "grid": grid_data,
+                "agents_static": agent_static_data,
+                "agents": agent_data}
+
+        return msgpack.packb(msg_data, use_bin_type=True)
+
     def save(self, filename):
-        with open(filename, "wb") as file_out:
-            file_out.write(self.get_full_state_msg())
+        if hasattr(self.obs_builder, 'distance_map'):
+            with open(filename, "wb") as file_out:
+                file_out.write(self.get_full_state_dist_msg())
+        else:
+            with open(filename, "wb") as file_out:
+                file_out.write(self.get_full_state_msg())
 
     def load(self, filename):
-        with open(filename, "rb") as file_in:
-            load_data = file_in.read()
-            self.set_full_state_msg(load_data)
+        if hasattr(self.obs_builder, 'distance_map'):
+            with open(filename, "rb") as file_in:
+                load_data = file_in.read()
+                self.set_full_state_dist_msg(load_data)
+        else:
+            with open(filename, "rb") as file_in:
+                load_data = file_in.read()
+                self.set_full_state_msg(load_data)
 
     def load_pkl(self, pkl_data):
         self.set_full_state_msg(pkl_data)
diff --git a/requirements_dev.txt b/requirements_dev.txt
index ea46eb245842881f1c51ededa5284300836a36ba..edd6ee2842dfb196db90f5334c6318f801785e0e 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -3,13 +3,14 @@ tox>=3.5.2
 twine>=1.12.1
 pytest>=3.8.2
 pytest-runner>=4.2
-numpy>=1.16.4
+numpy>=1.16.2
 recordtype>=1.3
 xarray>=0.11.3
 matplotlib>=3.0.2
 Pillow>=5.4.1
 CairoSVG>=2.3.1
 msgpack>=0.6.1
+msgpack-numpy>=0.4.4.0
 svgutils>=0.3.1
 screeninfo>=0.3.1
 pyarrow>=0.13.0
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 449b83294173c9665f54c34a668579b222f0c281..31dff253126bec041241aaba44f33dd9c494f2a1 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -5,7 +5,7 @@ import numpy as np
 
 from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
     random_rail_generator, empty_rail_generator
-from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from tests.simple_rail import make_simple_rail
@@ -109,8 +109,12 @@ def test_rail_from_grid_transition_map():
 
 
 def tests_rail_from_file():
-    file_name = "test_pkl.pkl"
+    file_name = "test_with_distance_map.pkl"
+
+    # Test to save and load file with distance map.
+
     rail, rail_map = make_simple_rail()
+
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
@@ -118,6 +122,7 @@ def tests_rail_from_file():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
     env.save(file_name)
+
     # initialize agents_static
     rails_initial = env.rail.grid
     agents_initial = env.agents
@@ -133,4 +138,69 @@ def tests_rail_from_file():
 
     assert np.all(np.array_equal(rails_initial, rails_loaded))
     assert agents_initial == agents_loaded
+    assert env.obs_builder.distance_map is not None
+
+    # Test to save and load file without distance map.
+
+    file_name_2 = "test_without_distance_map.pkl"
+
+    env2 = RailEnv(width=rail_map.shape[1],
+                   height=rail_map.shape[0],
+                   rail_generator=rail_from_grid_transition_map(rail),
+                   number_of_agents=3,
+                   obs_builder_object=GlobalObsForRailEnv(),
+                   )
+
+    env2.save(file_name_2)
+
+    # initialize agents_static
+    rails_initial_2 = env2.rail.grid
+    agents_initial_2 = env2.agents
 
+    env2 = RailEnv(width=1,
+                   height=1,
+                   rail_generator=rail_from_file(file_name_2),
+                   number_of_agents=1,
+                   obs_builder_object=GlobalObsForRailEnv(),
+                   )
+
+    rails_loaded_2 = env2.rail.grid
+    agents_loaded_2 = env2.agents
+
+    assert np.all(np.array_equal(rails_initial_2, rails_loaded_2))
+    assert agents_initial_2 == agents_loaded_2
+    assert not hasattr(env2.obs_builder, "distance_map")
+
+    # Test to save with distance map and load without
+
+    # initialize agents_static
+    env3 = RailEnv(width=1,
+                   height=1,
+                   rail_generator=rail_from_file(file_name),
+                   number_of_agents=1,
+                   obs_builder_object=GlobalObsForRailEnv(),
+                   )
+
+    rails_loaded_3 = env3.rail.grid
+    agents_loaded_3 = env3.agents
+
+    assert np.all(np.array_equal(rails_initial, rails_loaded_3))
+    assert agents_initial == agents_loaded_3
+    assert not hasattr(env2.obs_builder, "distance_map")
+
+    # Test to save without distance map and load with generating distance map
+
+    # initialize agents_static
+    env4 = RailEnv(width=1,
+                   height=1,
+                   rail_generator=rail_from_file(file_name_2),
+                   number_of_agents=1,
+                   obs_builder_object=TreeObsForRailEnv(max_depth=2),
+                   )
+
+    rails_loaded_4 = env4.rail.grid
+    agents_loaded_4 = env4.agents
+
+    assert np.all(np.array_equal(rails_initial_2, rails_loaded_4))
+    assert agents_initial_2 == agents_loaded_4
+    assert env.obs_builder.distance_map is not None