diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index d6c6da2c4fd622e8bfb92fca6a5a19c38e4ae973..bee91c117239358e8b8ef98caccb6329fa834cc1 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -9,9 +9,9 @@ from flatland.envs.agent_utils import EnvAgentStatic
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
-from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.rail_generators import complex_rail_generator, rail_from_file
 from flatland.envs.rail_generators import rail_from_grid_transition_map
-from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator
+from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, schedule_from_file
 from flatland.utils.simple_rail import make_simple_rail
 
 """Tests for `flatland` package."""
@@ -228,3 +228,78 @@ def test_get_entry_directions():
 
     # nowhere
     _assert((0, 0), [False, False, False, False])
+
+def test_rail_env_reset():
+    file_name = "test_rail_env_reset.pkl"
+
+    # Test to save and load file with distance map.
+
+    rail, rail_map = make_simple_rail()
+
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=3,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+    env.reset()
+    env.save(file_name)
+    dist_map_shape = np.shape(env.distance_map.get())
+    # initialize agents_static
+    rails_initial = env.rail.grid
+    agents_initial = env.agents
+
+    env2 = RailEnv(width=1,
+                  height=1,
+                  rail_generator=rail_from_file(file_name),
+                  schedule_generator=schedule_from_file(file_name),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+    env2.reset(False, False, False)
+    rails_loaded = env2.rail.grid
+    agents_loaded = env2.agents
+
+    assert np.all(np.array_equal(rails_initial, rails_loaded))
+    assert agents_initial == agents_loaded
+
+    # Check that distance map was not recomputed
+    assert np.shape(env2.distance_map.get()) == dist_map_shape
+    assert env2.distance_map.get() is not None
+
+    env3 = RailEnv(width=1,
+                  height=1,
+                  rail_generator=rail_from_file(file_name),
+                  schedule_generator=schedule_from_file(file_name),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+    env3.reset(False, True, False)
+    rails_loaded = env3.rail.grid
+    agents_loaded = env3.agents
+
+    assert np.all(np.array_equal(rails_initial, rails_loaded))
+    assert agents_initial == agents_loaded
+
+    # Check that distance map was not recomputed
+    assert np.shape(env3.distance_map.get()) == dist_map_shape
+    assert env3.distance_map.get() is not None
+
+    env4 = RailEnv(width=1,
+                  height=1,
+                  rail_generator=rail_from_file(file_name),
+                  schedule_generator=schedule_from_file(file_name),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+    env4.reset(True, False, False)
+    rails_loaded = env4.rail.grid
+    agents_loaded = env4.agents
+
+    assert np.all(np.array_equal(rails_initial, rails_loaded))
+    assert agents_initial == agents_loaded
+
+    # Check that distance map was not recomputed
+    assert np.shape(env4.distance_map.get()) == dist_map_shape
+    assert env4.distance_map.get() is not None