diff --git a/tests/test_file_load.py b/tests/test_file_load.py index 6bc326047edb4730bd988c32c92c3a2460543eb4..af5644f3ee81d72641449e7184c078830f65bdc8 100644 --- a/tests/test_file_load.py +++ b/tests/test_file_load.py @@ -1,6 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import numpy as np + from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -9,6 +11,7 @@ from tests.simple_rail import make_simple_rail def test_load_pkl(): + file_name = "test_pkl.pkl" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], @@ -16,10 +19,10 @@ def test_load_pkl(): number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - env.save("test_pkl.pkl") + env.save(file_name) # initialize agents_static - obs_0 = env.reset(False, False) - file_name = "test_pkl.pkl" + rails_initial = env.rail.grid + agents_initial = env.agents env = RailEnv(width=1, height=1, @@ -28,6 +31,10 @@ def test_load_pkl(): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), file_name=file_name ) - obs_1 = env.reset(False, False) - assert obs_0 == obs_1 + rails_loaded = env.rail.grid + agents_loaded = env.agents + + assert np.all(np.array_equal(rails_initial, rails_loaded)) + assert agents_initial == agents_loaded + return