diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index d6c6da2c4fd622e8bfb92fca6a5a19c38e4ae973..bee91c117239358e8b8ef98caccb6329fa834cc1 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -9,9 +9,9 @@ from flatland.envs.agent_utils import EnvAgentStatic from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.rail_generators import complex_rail_generator, rail_from_file from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator +from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, schedule_from_file from flatland.utils.simple_rail import make_simple_rail """Tests for `flatland` package.""" @@ -228,3 +228,78 @@ def test_get_entry_directions(): # nowhere _assert((0, 0), [False, False, False, False]) + +def test_rail_env_reset(): + file_name = "test_rail_env_reset.pkl" + + # Test to save and load file with distance map. + + rail, rail_map = make_simple_rail() + + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=3, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + env.reset() + env.save(file_name) + dist_map_shape = np.shape(env.distance_map.get()) + # initialize agents_static + rails_initial = env.rail.grid + agents_initial = env.agents + + env2 = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name), + schedule_generator=schedule_from_file(file_name), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + env2.reset(False, False, False) + rails_loaded = env2.rail.grid + agents_loaded = env2.agents + + assert np.all(np.array_equal(rails_initial, rails_loaded)) + assert agents_initial == agents_loaded + + # Check that distance map was not recomputed + assert np.shape(env2.distance_map.get()) == dist_map_shape + assert env2.distance_map.get() is not None + + env3 = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name), + schedule_generator=schedule_from_file(file_name), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + env3.reset(False, True, False) + rails_loaded = env3.rail.grid + agents_loaded = env3.agents + + assert np.all(np.array_equal(rails_initial, rails_loaded)) + assert agents_initial == agents_loaded + + # Check that distance map was not recomputed + assert np.shape(env3.distance_map.get()) == dist_map_shape + assert env3.distance_map.get() is not None + + env4 = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name), + schedule_generator=schedule_from_file(file_name), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + env4.reset(True, False, False) + rails_loaded = env4.rail.grid + agents_loaded = env4.agents + + assert np.all(np.array_equal(rails_initial, rails_loaded)) + assert agents_initial == agents_loaded + + # Check that distance map was not recomputed + assert np.shape(env4.distance_map.get()) == dist_map_shape + assert env4.distance_map.get() is not None