From 30e36369bec7605fa2fb9c7bb88b9f3faab067c5 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Tue, 16 Jul 2019 11:12:43 -0400
Subject: [PATCH] updated how generator check for distance map data. Updated
 test for generators to check that distance map is not recomputed when loaded
 correctly

---
 flatland/envs/generators.py   | 3 ++-
 flatland/envs/observations.py | 4 ++++
 tests/tests_generators.py     | 4 ++++
 3 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 9b33a55c..e208a46d 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -218,6 +218,7 @@ def rail_from_file(filename):
         with open(filename, "rb") as file_in:
             load_data = file_in.read()
         data = msgpack.unpackb(load_data, use_list=False)
+
         grid = np.array(data[b"grid"])
         rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
         rail.grid = grid
@@ -227,7 +228,7 @@ def rail_from_file(filename):
         agents_position = [a.position for a in agents_static]
         agents_direction = [a.direction for a in agents_static]
         agents_target = [a.target for a in agents_static]
-        if len(data) > 3:
+        if b"distance_maps" in data.keys():
             distance_maps = data[b"distance_maps"]
             return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps
         else:
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 38aadf93..3e07c05e 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -41,6 +41,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.tree_explored_actions = [1, 2, 3, 0]
         self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
         self.distance_map = None
+        self.distance_map_computed = False
 
     def reset(self):
         agents = self.env.agents
@@ -64,6 +65,9 @@ class TreeObsForRailEnv(ObservationBuilder):
 
     def _compute_distance_map(self):
         agents = self.env.agents
+
+        # For testing only --> To assert if a distance map need to be recomputed.
+        self.distance_map_computed = True
         nb_agents = len(agents)
         self.distance_map = np.inf * np.ones(shape=(nb_agents,
                                                     self.env.height,
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 8270685e..f97b071e 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -138,6 +138,9 @@ def tests_rail_from_file():
 
     assert np.all(np.array_equal(rails_initial, rails_loaded))
     assert agents_initial == agents_loaded
+
+    # Check that distance map was not recomputed
+    assert env.obs_builder.distance_map_computed is False
     assert np.shape(env.obs_builder.distance_map) == dist_map_shape
     assert env.obs_builder.distance_map is not None
 
@@ -207,5 +210,6 @@ def tests_rail_from_file():
     assert agents_initial_2 == agents_loaded_4
 
     # Check that distance map was generated with correct shape
+    assert env4.obs_builder.distance_map_computed is True
     assert env4.obs_builder.distance_map is not None
     assert np.shape(env4.obs_builder.distance_map) == dist_map_shape
-- 
GitLab