From 2a12614c26d5a11486526d674f76bbd8cfec3338 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Tue, 16 Jul 2019 10:48:29 -0400 Subject: [PATCH] initial improvement to test for generators --- tests/tests_generators.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 31dff253..8270685e 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -8,7 +8,7 @@ from flatland.envs.generators import rail_from_grid_transition_map, rail_from_fi from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from tests.simple_rail import make_simple_rail +from flatland.utils.simple_rail import make_simple_rail def test_empty_rail_generator(): @@ -122,7 +122,7 @@ def tests_rail_from_file(): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) env.save(file_name) - + dist_map_shape = np.shape(env.obs_builder.distance_map) # initialize agents_static rails_initial = env.rail.grid agents_initial = env.agents @@ -138,6 +138,7 @@ def tests_rail_from_file(): assert np.all(np.array_equal(rails_initial, rails_loaded)) assert agents_initial == agents_loaded + assert np.shape(env.obs_builder.distance_map) == dist_map_shape assert env.obs_builder.distance_map is not None # Test to save and load file without distance map. @@ -173,7 +174,6 @@ def tests_rail_from_file(): # Test to save with distance map and load without - # initialize agents_static env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), @@ -201,6 +201,11 @@ def tests_rail_from_file(): rails_loaded_4 = env4.rail.grid agents_loaded_4 = env4.agents + # Check that no distance map was saved + assert not hasattr(env2.obs_builder, "distance_map") assert np.all(np.array_equal(rails_initial_2, rails_loaded_4)) assert agents_initial_2 == agents_loaded_4 - assert env.obs_builder.distance_map is not None + + # Check that distance map was generated with correct shape + assert env4.obs_builder.distance_map is not None + assert np.shape(env4.obs_builder.distance_map) == dist_map_shape -- GitLab