diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index b1f46ec638d8a865a2ce36a9715d89b452c782f2..8fc70ce37390ed2bde29808f97de5522d877cdf2 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -172,18 +172,21 @@ class TreeObsForRailEnv(ObservationBuilder):
         if handles is None:
             handles = []
         if self.predictor:
+            self.max_prediction_depth = 0
             self.predicted_pos = {}
             self.predicted_dir = {}
             self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
-            for t in range(len(self.predictions[0])):
-                pos_list = []
-                dir_list = []
-                for a in handles:
-                    pos_list.append(self.predictions[a][t][1:3])
-                    dir_list.append(self.predictions[a][t][3])
-                self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
-                self.predicted_dir.update({t: dir_list})
-            self.max_prediction_depth = len(self.predicted_pos)
+            if self.predictions:
+
+                for t in range(len(self.predictions[0])):
+                    pos_list = []
+                    dir_list = []
+                    for a in handles:
+                        pos_list.append(self.predictions[a][t][1:3])
+                        dir_list.append(self.predictions[a][t][3])
+                    self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
+                    self.predicted_dir.update({t: dir_list})
+                self.max_prediction_depth = len(self.predicted_pos)
         observations = {}
         for h in handles:
             observations[h] = self.get(h)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 149e32bd66bb0c0e1ac96bbfe213705b4de63253..03879eb4cbf1072dace4dafacd36c826e6334ab2 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -80,6 +80,7 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
+                 file_name=None
                  ):
         """
         Environment init.
@@ -110,6 +111,7 @@ class RailEnv(Environment):
         obs_builder_object: ObservationBuilder object
             ObservationBuilder-derived object that takes builds observation
             vectors for each agent.
+        file_name: you can load a pickle file.
         """
 
         self.rail_generator = rail_generator
@@ -117,14 +119,10 @@ class RailEnv(Environment):
         self.width = width
         self.height = height
 
-        self.obs_builder = obs_builder_object
-        self.obs_builder._set_env(self)
-
-        self.action_space = [1]
-        self.observation_space = self.obs_builder.observation_space  # updated on resets?
-
         self.rewards = [0] * number_of_agents
         self.done = False
+        self.obs_builder = obs_builder_object
+        self.obs_builder._set_env(self)
 
         self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
 
@@ -135,6 +133,14 @@ class RailEnv(Environment):
         self.agents = [None] * number_of_agents  # live agents
         self.agents_static = [None] * number_of_agents  # static agent information
         self.num_resets = 0
+        if file_name:
+            self.loaded_file = file_name
+        else:
+            self.loaded_file = None
+
+        self.action_space = [1]
+        self.observation_space = self.obs_builder.observation_space  # updated on resets?
+
         self.reset()
         self.num_resets = 0  # yes, set it to zero again!
 
@@ -175,6 +181,9 @@ class RailEnv(Environment):
         if replace_agents:
             self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
 
+        if self.loaded_file:
+            self.load(self.loaded_file)
+
         self.restart_agents()
 
         for i_agent in range(self.get_num_agents()):
@@ -425,6 +434,9 @@ class RailEnv(Environment):
             load_data = file_in.read()
             self.set_full_state_msg(load_data)
 
+    def load_pkl(self, pkl_data):
+        self.set_full_state_msg(pkl_data)
+
     def load_resource(self, package, resource):
         from importlib_resources import read_binary
         load_data = read_binary(package, resource)
diff --git a/tests/test_file_load.py b/tests/test_file_load.py
new file mode 100644
index 0000000000000000000000000000000000000000..af5644f3ee81d72641449e7184c078830f65bdc8
--- /dev/null
+++ b/tests/test_file_load.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import numpy as np
+
+from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from tests.simple_rail import make_simple_rail
+
+
+def test_load_pkl():
+    file_name = "test_pkl.pkl"
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=3,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+    env.save(file_name)
+    # initialize agents_static
+    rails_initial = env.rail.grid
+    agents_initial = env.agents
+
+    env = RailEnv(width=1,
+                  height=1,
+                  rail_generator=empty_rail_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  file_name=file_name
+                  )
+    rails_loaded = env.rail.grid
+    agents_loaded = env.agents
+
+    assert np.all(np.array_equal(rails_initial, rails_loaded))
+    assert agents_initial == agents_loaded
+
+    return