From f1258872b377122d6facbdb4dd3dbc3994550dd4 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sun, 14 Jul 2019 13:00:48 -0400 Subject: [PATCH] updated load and save function. Now also distance maps are stored. Additional package msgpack-numpy needed for ndarray. This saves tons of time when loading precomputed files. Updated test to test loading and saving with and without distance maps --- tests/tests_generators.py | 74 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 2 deletions(-) diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 449b8329..31dff253 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -5,7 +5,7 @@ import numpy as np from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ random_rail_generator, empty_rail_generator -from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from tests.simple_rail import make_simple_rail @@ -109,8 +109,12 @@ def test_rail_from_grid_transition_map(): def tests_rail_from_file(): - file_name = "test_pkl.pkl" + file_name = "test_with_distance_map.pkl" + + # Test to save and load file with distance map. + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), @@ -118,6 +122,7 @@ def tests_rail_from_file(): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) env.save(file_name) + # initialize agents_static rails_initial = env.rail.grid agents_initial = env.agents @@ -133,4 +138,69 @@ def tests_rail_from_file(): assert np.all(np.array_equal(rails_initial, rails_loaded)) assert agents_initial == agents_loaded + assert env.obs_builder.distance_map is not None + + # Test to save and load file without distance map. + + file_name_2 = "test_without_distance_map.pkl" + + env2 = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + number_of_agents=3, + obs_builder_object=GlobalObsForRailEnv(), + ) + + env2.save(file_name_2) + + # initialize agents_static + rails_initial_2 = env2.rail.grid + agents_initial_2 = env2.agents + env2 = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name_2), + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv(), + ) + + rails_loaded_2 = env2.rail.grid + agents_loaded_2 = env2.agents + + assert np.all(np.array_equal(rails_initial_2, rails_loaded_2)) + assert agents_initial_2 == agents_loaded_2 + assert not hasattr(env2.obs_builder, "distance_map") + + # Test to save with distance map and load without + + # initialize agents_static + env3 = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name), + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv(), + ) + + rails_loaded_3 = env3.rail.grid + agents_loaded_3 = env3.agents + + assert np.all(np.array_equal(rails_initial, rails_loaded_3)) + assert agents_initial == agents_loaded_3 + assert not hasattr(env2.obs_builder, "distance_map") + + # Test to save without distance map and load with generating distance map + + # initialize agents_static + env4 = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name_2), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2), + ) + + rails_loaded_4 = env4.rail.grid + agents_loaded_4 = env4.agents + + assert np.all(np.array_equal(rails_initial_2, rails_loaded_4)) + assert agents_initial_2 == agents_loaded_4 + assert env.obs_builder.distance_map is not None -- GitLab