From 30e36369bec7605fa2fb9c7bb88b9f3faab067c5 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Tue, 16 Jul 2019 11:12:43 -0400 Subject: [PATCH] updated how generator check for distance map data. Updated test for generators to check that distance map is not recomputed when loaded correctly --- flatland/envs/generators.py | 3 ++- flatland/envs/observations.py | 4 ++++ tests/tests_generators.py | 4 ++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 9b33a55c..e208a46d 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -218,6 +218,7 @@ def rail_from_file(filename): with open(filename, "rb") as file_in: load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False) + grid = np.array(data[b"grid"]) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail.grid = grid @@ -227,7 +228,7 @@ def rail_from_file(filename): agents_position = [a.position for a in agents_static] agents_direction = [a.direction for a in agents_static] agents_target = [a.target for a in agents_static] - if len(data) > 3: + if b"distance_maps" in data.keys(): distance_maps = data[b"distance_maps"] return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps else: diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 38aadf93..3e07c05e 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -41,6 +41,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.tree_explored_actions = [1, 2, 3, 0] self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] self.distance_map = None + self.distance_map_computed = False def reset(self): agents = self.env.agents @@ -64,6 +65,9 @@ class TreeObsForRailEnv(ObservationBuilder): def _compute_distance_map(self): agents = self.env.agents + + # For testing only --> To assert if a distance map need to be recomputed. + self.distance_map_computed = True nb_agents = len(agents) self.distance_map = np.inf * np.ones(shape=(nb_agents, self.env.height, diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 8270685e..f97b071e 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -138,6 +138,9 @@ def tests_rail_from_file(): assert np.all(np.array_equal(rails_initial, rails_loaded)) assert agents_initial == agents_loaded + + # Check that distance map was not recomputed + assert env.obs_builder.distance_map_computed is False assert np.shape(env.obs_builder.distance_map) == dist_map_shape assert env.obs_builder.distance_map is not None @@ -207,5 +210,6 @@ def tests_rail_from_file(): assert agents_initial_2 == agents_loaded_4 # Check that distance map was generated with correct shape + assert env4.obs_builder.distance_map_computed is True assert env4.obs_builder.distance_map is not None assert np.shape(env4.obs_builder.distance_map) == dist_map_shape -- GitLab