From f1258872b377122d6facbdb4dd3dbc3994550dd4 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 14 Jul 2019 13:00:48 -0400
Subject: [PATCH] updated load and save function. Now also distance maps are
 stored. Additional package msgpack-numpy needed for ndarray. This saves tons
 of time when loading precomputed files. Updated test to test loading and
 saving with and without distance maps

---
 tests/tests_generators.py | 74 +++++++++++++++++++++++++++++++++++++--
 1 file changed, 72 insertions(+), 2 deletions(-)

diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 449b8329..31dff253 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -5,7 +5,7 @@ import numpy as np
 
 from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
     random_rail_generator, empty_rail_generator
-from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from tests.simple_rail import make_simple_rail
@@ -109,8 +109,12 @@ def test_rail_from_grid_transition_map():
 
 
 def tests_rail_from_file():
-    file_name = "test_pkl.pkl"
+    file_name = "test_with_distance_map.pkl"
+
+    # Test to save and load file with distance map.
+
     rail, rail_map = make_simple_rail()
+
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
@@ -118,6 +122,7 @@ def tests_rail_from_file():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
     env.save(file_name)
+
     # initialize agents_static
     rails_initial = env.rail.grid
     agents_initial = env.agents
@@ -133,4 +138,69 @@ def tests_rail_from_file():
 
     assert np.all(np.array_equal(rails_initial, rails_loaded))
     assert agents_initial == agents_loaded
+    assert env.obs_builder.distance_map is not None
+
+    # Test to save and load file without distance map.
+
+    file_name_2 = "test_without_distance_map.pkl"
+
+    env2 = RailEnv(width=rail_map.shape[1],
+                   height=rail_map.shape[0],
+                   rail_generator=rail_from_grid_transition_map(rail),
+                   number_of_agents=3,
+                   obs_builder_object=GlobalObsForRailEnv(),
+                   )
+
+    env2.save(file_name_2)
+
+    # initialize agents_static
+    rails_initial_2 = env2.rail.grid
+    agents_initial_2 = env2.agents
 
+    env2 = RailEnv(width=1,
+                   height=1,
+                   rail_generator=rail_from_file(file_name_2),
+                   number_of_agents=1,
+                   obs_builder_object=GlobalObsForRailEnv(),
+                   )
+
+    rails_loaded_2 = env2.rail.grid
+    agents_loaded_2 = env2.agents
+
+    assert np.all(np.array_equal(rails_initial_2, rails_loaded_2))
+    assert agents_initial_2 == agents_loaded_2
+    assert not hasattr(env2.obs_builder, "distance_map")
+
+    # Test to save with distance map and load without
+
+    # initialize agents_static
+    env3 = RailEnv(width=1,
+                   height=1,
+                   rail_generator=rail_from_file(file_name),
+                   number_of_agents=1,
+                   obs_builder_object=GlobalObsForRailEnv(),
+                   )
+
+    rails_loaded_3 = env3.rail.grid
+    agents_loaded_3 = env3.agents
+
+    assert np.all(np.array_equal(rails_initial, rails_loaded_3))
+    assert agents_initial == agents_loaded_3
+    assert not hasattr(env2.obs_builder, "distance_map")
+
+    # Test to save without distance map and load with generating distance map
+
+    # initialize agents_static
+    env4 = RailEnv(width=1,
+                   height=1,
+                   rail_generator=rail_from_file(file_name_2),
+                   number_of_agents=1,
+                   obs_builder_object=TreeObsForRailEnv(max_depth=2),
+                   )
+
+    rails_loaded_4 = env4.rail.grid
+    agents_loaded_4 = env4.agents
+
+    assert np.all(np.array_equal(rails_initial_2, rails_loaded_4))
+    assert agents_initial_2 == agents_loaded_4
+    assert env.obs_builder.distance_map is not None
-- 
GitLab