diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index c3f569ae0d927a6e71803a24d921337c65d39c29..907b4a25edc0d44ff83d7a063795b55e020362b3 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1,10 +1,12 @@
+import msgpack
 import numpy as np
 
-from flatland.core.transition_map import GridTransitionMap
+from flatland.core.grid.grid4_utils import get_direction, mirror
+from flatland.core.grid.grid_utils import distance_on_rail
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_utils import EnvAgentStatic
 from flatland.envs.grid4_generators_utils import connect_rail
-from flatland.core.grid.grid_utils import distance_on_rail
-from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
 
 
@@ -195,6 +197,40 @@ def rail_from_manual_specifications_generator(rail_spec):
     return generator
 
 
+def rail_from_file(filename):
+    """
+    Utility to load pickle file
+
+    Parameters
+    -------
+    input_file : Pickle file generated by env.save() or editor
+
+    Returns
+    -------
+    function
+        Generator function that always returns a GridTransitionMap object with
+        the matrix of correct 16-bit bitmaps for each rail_spec_of_cell.
+    """
+
+    def generator(width, height, num_agents, num_resets):
+        rail_env_transitions = RailEnvTransitions()
+        with open(filename, "rb") as file_in:
+            load_data = file_in.read()
+        data = msgpack.unpackb(load_data, use_list=False)
+        grid = np.array(data[b"grid"])
+        rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
+        rail.grid = grid
+        # agents are always reset as not moving
+        agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
+        # setup with loaded data
+        agents_position = [a.position for a in agents_static]
+        agents_direction = [a.direction for a in agents_static]
+        agents_target = [a.target for a in agents_static]
+        return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+
+    return generator
+
+
 def rail_from_GridTransitionMap_generator(rail_map):
     """
     Utility to convert a rail given by a GridTransitionMap map with the correct
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 8fc70ce37390ed2bde29808f97de5522d877cdf2..3929b9e191615ba91fb0e13df6d8ae040b401a5a 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -75,7 +75,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         orientation within it) to each agent's target cell.
         """
         # Returns max distance to target, from the farthest away node, while filling in distance_map
-
         self.distance_map[target_nr, position[0], position[1], :] = 0
 
         # Fill in the (up to) 4 neighboring nodes
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 03879eb4cbf1072dace4dafacd36c826e6334ab2..34762c0bb1d4de1cb82ee82e139113977b7e7a7d 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -80,7 +80,6 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
-                 file_name=None
                  ):
         """
         Environment init.
@@ -133,10 +132,6 @@ class RailEnv(Environment):
         self.agents = [None] * number_of_agents  # live agents
         self.agents_static = [None] * number_of_agents  # static agent information
         self.num_resets = 0
-        if file_name:
-            self.loaded_file = file_name
-        else:
-            self.loaded_file = None
 
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
@@ -177,13 +172,11 @@ class RailEnv(Environment):
 
         if regen_rail or self.rail is None:
             self.rail = tRailAgents[0]
+            self.height, self.width = self.rail.grid.shape
 
         if replace_agents:
             self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
 
-        if self.loaded_file:
-            self.load(self.loaded_file)
-
         self.restart_agents()
 
         for i_agent in range(self.get_num_agents()):
diff --git a/tests/test_file_load.py b/tests/test_file_load.py
index af5644f3ee81d72641449e7184c078830f65bdc8..57fa45cb29dab07b84f43c97a45043c9dfa39979 100644
--- a/tests/test_file_load.py
+++ b/tests/test_file_load.py
@@ -3,7 +3,7 @@
 
 import numpy as np
 
-from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator
+from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_file
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -26,10 +26,9 @@ def test_load_pkl():
 
     env = RailEnv(width=1,
                   height=1,
-                  rail_generator=empty_rail_generator(),
+                  rail_generator=rail_from_file(file_name),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  file_name=file_name
                   )
     rails_loaded = env.rail.grid
     agents_loaded = env.agents