diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 9b33a55c4f5f57a8f92135d95f69696ee3df0be7..355f5502992a34f8d58d4dbd80028eb4dd71cc48 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -145,7 +145,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
         agents_target = [sg[1] for sg in start_goal[:num_agents]]
         agents_direction = start_dir[:num_agents]
 
-        return grid_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
+        return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
 
@@ -193,7 +193,7 @@ def rail_from_manual_specifications_generator(rail_spec):
             rail,
             num_agents)
 
-        return rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
+        return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
 
@@ -218,6 +218,7 @@ def rail_from_file(filename):
         with open(filename, "rb") as file_in:
             load_data = file_in.read()
         data = msgpack.unpackb(load_data, use_list=False)
+
         grid = np.array(data[b"grid"])
         rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
         rail.grid = grid
@@ -227,11 +228,16 @@ def rail_from_file(filename):
         agents_position = [a.position for a in agents_static]
         agents_direction = [a.direction for a in agents_static]
         agents_target = [a.target for a in agents_static]
-        if len(data) > 3:
+        if b"distance_maps" in data.keys():
             distance_maps = data[b"distance_maps"]
-            return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps
+            if len(distance_maps) > 0:
+                return rail, agents_position, agents_direction, agents_target, [1.0] * len(
+                    agents_position), distance_maps
+            else:
+                return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
         else:
             return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+
     return generator
 
 
@@ -256,7 +262,7 @@ def rail_from_grid_transition_map(rail_map):
             rail_map,
             num_agents)
 
-        return rail_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
+        return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
 
@@ -529,6 +535,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
             return_rail,
             num_agents)
 
-        return return_rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
+        return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 38aadf939c2e2ac0b11b9fb3039817fa38786ffd..80dc73adf417b43d2740b80e9b86fe5ba5ce257f 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -41,6 +41,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.tree_explored_actions = [1, 2, 3, 0]
         self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
         self.distance_map = None
+        self.distance_map_computed = False
 
     def reset(self):
         agents = self.env.agents
@@ -51,7 +52,6 @@ class TreeObsForRailEnv(ObservationBuilder):
             for i in range(nb_agents):
                 if agents[i].target != self.agents_previous_reset[i].target:
                     compute_distance_map = True
-
         # Don't compute the distance map if it was loaded
         if self.agents_previous_reset is None and self.distance_map is not None:
             self.location_has_target = {tuple(agent.target): 1 for agent in agents}
@@ -64,6 +64,8 @@ class TreeObsForRailEnv(ObservationBuilder):
 
     def _compute_distance_map(self):
         agents = self.env.agents
+        # For testing only --> To assert if a distance map need to be recomputed.
+        self.distance_map_computed = True
         nb_agents = len(agents)
         self.distance_map = np.inf * np.ones(shape=(nb_agents,
                                                     self.env.height,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index a0aa2e64933c2f9e8d53a5bcd08bb9f8845484e6..e098160d7a14f9b7cc07b7f365e6af573c3d7d56 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -480,8 +480,12 @@ class RailEnv(Environment):
 
     def save(self, filename):
         if hasattr(self.obs_builder, 'distance_map'):
-            with open(filename, "wb") as file_out:
-                file_out.write(self.get_full_state_dist_msg())
+            if len(self.obs_builder.distance_map) > 0:
+                with open(filename, "wb") as file_out:
+                    file_out.write(self.get_full_state_dist_msg())
+            else:
+                with open(filename, "wb") as file_out:
+                    file_out.write(self.get_full_state_msg())
         else:
             with open(filename, "wb") as file_out:
                 file_out.write(self.get_full_state_msg())
diff --git a/notebooks/Scene_Editor.ipynb b/notebooks/Scene_Editor.ipynb
index 877f82001dfeef3732f824e06899f62f68ba85d4..852fddb422fd2d8e6a6c51df993a183e45511982 100644
--- a/notebooks/Scene_Editor.ipynb
+++ b/notebooks/Scene_Editor.ipynb
@@ -11,7 +11,20 @@
    "cell_type": "code",
    "execution_count": 1,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<style>.container { width:95% !important; }</style>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
     "from IPython.core.display import display, HTML\n",
     "display(HTML(\"<style>.container { width:95% !important; }</style>\"))"
@@ -57,7 +70,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "bd784815032c4d89803fa93399def9cf",
+       "model_id": "5a05a5f0569846dfa58fc6cfc222619a",
        "version_major": 2,
        "version_minor": 0
       },
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 31dff253126bec041241aaba44f33dd9c494f2a1..f97b071e6b33c099efa5af36766e159e57716443 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -8,7 +8,7 @@ from flatland.envs.generators import rail_from_grid_transition_map, rail_from_fi
 from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
-from tests.simple_rail import make_simple_rail
+from flatland.utils.simple_rail import make_simple_rail
 
 
 def test_empty_rail_generator():
@@ -122,7 +122,7 @@ def tests_rail_from_file():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
     env.save(file_name)
-
+    dist_map_shape = np.shape(env.obs_builder.distance_map)
     # initialize agents_static
     rails_initial = env.rail.grid
     agents_initial = env.agents
@@ -138,6 +138,10 @@ def tests_rail_from_file():
 
     assert np.all(np.array_equal(rails_initial, rails_loaded))
     assert agents_initial == agents_loaded
+
+    # Check that distance map was not recomputed
+    assert env.obs_builder.distance_map_computed is False
+    assert np.shape(env.obs_builder.distance_map) == dist_map_shape
     assert env.obs_builder.distance_map is not None
 
     # Test to save and load file without distance map.
@@ -173,7 +177,6 @@ def tests_rail_from_file():
 
     # Test to save with distance map and load without
 
-    # initialize agents_static
     env3 = RailEnv(width=1,
                    height=1,
                    rail_generator=rail_from_file(file_name),
@@ -201,6 +204,12 @@ def tests_rail_from_file():
     rails_loaded_4 = env4.rail.grid
     agents_loaded_4 = env4.agents
 
+    # Check that no distance map was saved
+    assert not hasattr(env2.obs_builder, "distance_map")
     assert np.all(np.array_equal(rails_initial_2, rails_loaded_4))
     assert agents_initial_2 == agents_loaded_4
-    assert env.obs_builder.distance_map is not None
+
+    # Check that distance map was generated with correct shape
+    assert env4.obs_builder.distance_map_computed is True
+    assert env4.obs_builder.distance_map is not None
+    assert np.shape(env4.obs_builder.distance_map) == dist_map_shape