diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 029851a8de2b5d0884ff1e189058a6ddf08ab340..03879eb4cbf1072dace4dafacd36c826e6334ab2 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -80,7 +80,7 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
-                 pkl_load=None
+                 file_name=None
                  ):
         """
         Environment init.
@@ -111,7 +111,7 @@ class RailEnv(Environment):
         obs_builder_object: ObservationBuilder object
             ObservationBuilder-derived object that takes builds observation
             vectors for each agent.
-        pkl_load: you can load a pickle file.
+        file_name: you can load a pickle file.
         """
 
         self.rail_generator = rail_generator
@@ -121,6 +121,8 @@ class RailEnv(Environment):
 
         self.rewards = [0] * number_of_agents
         self.done = False
+        self.obs_builder = obs_builder_object
+        self.obs_builder._set_env(self)
 
         self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
 
@@ -131,13 +133,10 @@ class RailEnv(Environment):
         self.agents = [None] * number_of_agents  # live agents
         self.agents_static = [None] * number_of_agents  # static agent information
         self.num_resets = 0
-        if pkl_load:
-            self.loaded_data = pkl_load
+        if file_name:
+            self.loaded_file = file_name
         else:
-            self.loaded_data = None
-
-        self.obs_builder = obs_builder_object
-        self.obs_builder._set_env(self)
+            self.loaded_file = None
 
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
@@ -182,8 +181,8 @@ class RailEnv(Environment):
         if replace_agents:
             self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
 
-        if self.loaded_data:
-            self.load_pkl(self.loaded_data)
+        if self.loaded_file:
+            self.load(self.loaded_file)
 
         self.restart_agents()
 
diff --git a/tests/test_file_load.py b/tests/test_file_load.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bc326047edb4730bd988c32c92c3a2460543eb4
--- /dev/null
+++ b/tests/test_file_load.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from tests.simple_rail import make_simple_rail
+
+
+def test_load_pkl():
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=3,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+    env.save("test_pkl.pkl")
+    # initialize agents_static
+    obs_0 = env.reset(False, False)
+    file_name = "test_pkl.pkl"
+
+    env = RailEnv(width=1,
+                  height=1,
+                  rail_generator=empty_rail_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  file_name=file_name
+                  )
+    obs_1 = env.reset(False, False)
+    assert obs_0 == obs_1
+    return