From 2a12614c26d5a11486526d674f76bbd8cfec3338 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Tue, 16 Jul 2019 10:48:29 -0400
Subject: [PATCH] initial improvement to test for generators

---
 tests/tests_generators.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 31dff253..8270685e 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -8,7 +8,7 @@ from flatland.envs.generators import rail_from_grid_transition_map, rail_from_fi
 from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
-from tests.simple_rail import make_simple_rail
+from flatland.utils.simple_rail import make_simple_rail
 
 
 def test_empty_rail_generator():
@@ -122,7 +122,7 @@ def tests_rail_from_file():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
     env.save(file_name)
-
+    dist_map_shape = np.shape(env.obs_builder.distance_map)
     # initialize agents_static
     rails_initial = env.rail.grid
     agents_initial = env.agents
@@ -138,6 +138,7 @@ def tests_rail_from_file():
 
     assert np.all(np.array_equal(rails_initial, rails_loaded))
     assert agents_initial == agents_loaded
+    assert np.shape(env.obs_builder.distance_map) == dist_map_shape
     assert env.obs_builder.distance_map is not None
 
     # Test to save and load file without distance map.
@@ -173,7 +174,6 @@ def tests_rail_from_file():
 
     # Test to save with distance map and load without
 
-    # initialize agents_static
     env3 = RailEnv(width=1,
                    height=1,
                    rail_generator=rail_from_file(file_name),
@@ -201,6 +201,11 @@ def tests_rail_from_file():
     rails_loaded_4 = env4.rail.grid
     agents_loaded_4 = env4.agents
 
+    # Check that no distance map was saved
+    assert not hasattr(env2.obs_builder, "distance_map")
     assert np.all(np.array_equal(rails_initial_2, rails_loaded_4))
     assert agents_initial_2 == agents_loaded_4
-    assert env.obs_builder.distance_map is not None
+
+    # Check that distance map was generated with correct shape
+    assert env4.obs_builder.distance_map is not None
+    assert np.shape(env4.obs_builder.distance_map) == dist_map_shape
-- 
GitLab