Skip to content
Snippets Groups Projects
Commit f1258872 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated load and save function. Now also distance maps are stored.

Additional package msgpack-numpy needed for ndarray.
This saves tons of time when loading precomputed files.
Updated test to test loading and saving with and without distance maps
parent 4d538a41
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
random_rail_generator, empty_rail_generator random_rail_generator, empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from tests.simple_rail import make_simple_rail from tests.simple_rail import make_simple_rail
...@@ -109,8 +109,12 @@ def test_rail_from_grid_transition_map(): ...@@ -109,8 +109,12 @@ def test_rail_from_grid_transition_map():
def tests_rail_from_file(): def tests_rail_from_file():
file_name = "test_pkl.pkl" file_name = "test_with_distance_map.pkl"
# Test to save and load file with distance map.
rail, rail_map = make_simple_rail() rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
...@@ -118,6 +122,7 @@ def tests_rail_from_file(): ...@@ -118,6 +122,7 @@ def tests_rail_from_file():
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
) )
env.save(file_name) env.save(file_name)
# initialize agents_static # initialize agents_static
rails_initial = env.rail.grid rails_initial = env.rail.grid
agents_initial = env.agents agents_initial = env.agents
...@@ -133,4 +138,69 @@ def tests_rail_from_file(): ...@@ -133,4 +138,69 @@ def tests_rail_from_file():
assert np.all(np.array_equal(rails_initial, rails_loaded)) assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded assert agents_initial == agents_loaded
assert env.obs_builder.distance_map is not None
# Test to save and load file without distance map.
file_name_2 = "test_without_distance_map.pkl"
env2 = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=3,
obs_builder_object=GlobalObsForRailEnv(),
)
env2.save(file_name_2)
# initialize agents_static
rails_initial_2 = env2.rail.grid
agents_initial_2 = env2.agents
env2 = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name_2),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
)
rails_loaded_2 = env2.rail.grid
agents_loaded_2 = env2.agents
assert np.all(np.array_equal(rails_initial_2, rails_loaded_2))
assert agents_initial_2 == agents_loaded_2
assert not hasattr(env2.obs_builder, "distance_map")
# Test to save with distance map and load without
# initialize agents_static
env3 = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
)
rails_loaded_3 = env3.rail.grid
agents_loaded_3 = env3.agents
assert np.all(np.array_equal(rails_initial, rails_loaded_3))
assert agents_initial == agents_loaded_3
assert not hasattr(env2.obs_builder, "distance_map")
# Test to save without distance map and load with generating distance map
# initialize agents_static
env4 = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name_2),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2),
)
rails_loaded_4 = env4.rail.grid
agents_loaded_4 = env4.agents
assert np.all(np.array_equal(rails_initial_2, rails_loaded_4))
assert agents_initial_2 == agents_loaded_4
assert env.obs_builder.distance_map is not None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment