From 02e9d8d0a98ca4a986c7c9c8effbdcc604e85c6b Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Tue, 16 Jul 2019 11:49:59 -0400
Subject: [PATCH] updated how generator check for distance map data. Updated
 test for generators to check that distance map is not recomputed when loaded
 correctly

---
 flatland/envs/generators.py   | 16 +++++++++++-----
 flatland/envs/observations.py |  3 +--
 flatland/envs/rail_env.py     |  8 ++++++--
 notebooks/Scene_Editor.ipynb  | 17 +++++++++++++++--
 4 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index e208a46d..ae2968f8 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -145,7 +145,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
         agents_target = [sg[1] for sg in start_goal[:num_agents]]
         agents_direction = start_dir[:num_agents]
 
-        return grid_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
+        return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
 
@@ -193,7 +193,7 @@ def rail_from_manual_specifications_generator(rail_spec):
             rail,
             num_agents)
 
-        return rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
+        return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
 
@@ -230,9 +230,15 @@ def rail_from_file(filename):
         agents_target = [a.target for a in agents_static]
         if b"distance_maps" in data.keys():
             distance_maps = data[b"distance_maps"]
-            return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps
+            if len(distance_maps) > 0:
+                print("Loading distance map")
+                return rail, agents_position, agents_direction, agents_target, [1.0] * len(
+                    agents_position), distance_maps
+            else:
+                return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
         else:
             return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+
     return generator
 
 
@@ -257,7 +263,7 @@ def rail_from_grid_transition_map(rail_map):
             rail_map,
             num_agents)
 
-        return rail_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
+        return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
 
@@ -530,6 +536,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
             return_rail,
             num_agents)
 
-        return return_rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
+        return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 3e07c05e..d0d678b8 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -52,7 +52,6 @@ class TreeObsForRailEnv(ObservationBuilder):
             for i in range(nb_agents):
                 if agents[i].target != self.agents_previous_reset[i].target:
                     compute_distance_map = True
-
         # Don't compute the distance map if it was loaded
         if self.agents_previous_reset is None and self.distance_map is not None:
             self.location_has_target = {tuple(agent.target): 1 for agent in agents}
@@ -65,7 +64,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
     def _compute_distance_map(self):
         agents = self.env.agents
-
+        print("Computing distance map")
         # For testing only --> To assert if a distance map need to be recomputed.
         self.distance_map_computed = True
         nb_agents = len(agents)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index a0aa2e64..e098160d 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -480,8 +480,12 @@ class RailEnv(Environment):
 
     def save(self, filename):
         if hasattr(self.obs_builder, 'distance_map'):
-            with open(filename, "wb") as file_out:
-                file_out.write(self.get_full_state_dist_msg())
+            if len(self.obs_builder.distance_map) > 0:
+                with open(filename, "wb") as file_out:
+                    file_out.write(self.get_full_state_dist_msg())
+            else:
+                with open(filename, "wb") as file_out:
+                    file_out.write(self.get_full_state_msg())
         else:
             with open(filename, "wb") as file_out:
                 file_out.write(self.get_full_state_msg())
diff --git a/notebooks/Scene_Editor.ipynb b/notebooks/Scene_Editor.ipynb
index 877f8200..852fddb4 100644
--- a/notebooks/Scene_Editor.ipynb
+++ b/notebooks/Scene_Editor.ipynb
@@ -11,7 +11,20 @@
    "cell_type": "code",
    "execution_count": 1,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<style>.container { width:95% !important; }</style>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
     "from IPython.core.display import display, HTML\n",
     "display(HTML(\"<style>.container { width:95% !important; }</style>\"))"
@@ -57,7 +70,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "bd784815032c4d89803fa93399def9cf",
+       "model_id": "5a05a5f0569846dfa58fc6cfc222619a",
        "version_major": 2,
        "version_minor": 0
       },
-- 
GitLab