diff --git a/tests/test_file_load.py b/tests/test_file_load.py
index 6bc326047edb4730bd988c32c92c3a2460543eb4..af5644f3ee81d72641449e7184c078830f65bdc8 100644
--- a/tests/test_file_load.py
+++ b/tests/test_file_load.py
@@ -1,6 +1,8 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+import numpy as np
+
 from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -9,6 +11,7 @@ from tests.simple_rail import make_simple_rail
 
 
 def test_load_pkl():
+    file_name = "test_pkl.pkl"
     rail, rail_map = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
@@ -16,10 +19,10 @@ def test_load_pkl():
                   number_of_agents=3,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
-    env.save("test_pkl.pkl")
+    env.save(file_name)
     # initialize agents_static
-    obs_0 = env.reset(False, False)
-    file_name = "test_pkl.pkl"
+    rails_initial = env.rail.grid
+    agents_initial = env.agents
 
     env = RailEnv(width=1,
                   height=1,
@@ -28,6 +31,10 @@ def test_load_pkl():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   file_name=file_name
                   )
-    obs_1 = env.reset(False, False)
-    assert obs_0 == obs_1
+    rails_loaded = env.rail.grid
+    agents_loaded = env.agents
+
+    assert np.all(np.array_equal(rails_initial, rails_loaded))
+    assert agents_initial == agents_loaded
+
     return