From c63e78aebd33ec11607d8d9118a80f2a53a39510 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Tue, 17 Sep 2019 16:12:28 +0200
Subject: [PATCH] Refactoring: move logic to compute new distance_map into get
 method

---
 flatland/envs/distance_map.py | 54 +++++++++++++++++++++--------------
 flatland/envs/rail_env.py     |  2 +-
 tests/tests_generators.py     |  2 --
 3 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py
index d155042e..94019319 100644
--- a/flatland/envs/distance_map.py
+++ b/flatland/envs/distance_map.py
@@ -1,5 +1,5 @@
 from collections import deque
-from typing import List
+from typing import List, Optional
 
 import numpy as np
 
@@ -13,8 +13,10 @@ class DistanceMap:
         self.env_height = env_height
         self.env_width = env_width
         self.distance_map = None
-        self.distance_map_computed = False
-        self.agents_previous_reset = None
+        self.agents_previous_computation = None
+        self.reset_was_called = False
+        self.agents: List[EnvAgent] = agents
+        self.rail: Optional[GridTransitionMap] = None
 
     """
     Set the distance map
@@ -26,31 +28,39 @@ class DistanceMap:
     Get the distance map
     """
     def get(self) -> np.ndarray:
-        return self.distance_map
 
-    """
-    Compute the distance map
-    """
-    def compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        if self.reset_was_called:
+            self.reset_was_called = False
+
+            nb_agents = len(self.agents)
+            compute_distance_map = True
+            if self.agents_previous_computation is not None and nb_agents == len(self.agents_previous_computation):
+                compute_distance_map = False
+                for i in range(nb_agents):
+                    if self.agents[i].target != self.agents_previous_computation[i].target:
+                        compute_distance_map = True
+            # Don't compute the distance map if it was loaded
+            if self.agents_previous_computation is None and self.distance_map is not None:
+                compute_distance_map = False
 
-        nb_agents = len(agents)
-        compute_distance_map = True
-        if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
-            compute_distance_map = False
-            for i in range(nb_agents):
-                if agents[i].target != self.agents_previous_reset[i].target:
-                    compute_distance_map = True
-        # Don't compute the distance map if it was loaded
-        if self.agents_previous_reset is None and self.distance_map is not None:
-            compute_distance_map = False
+            if compute_distance_map:
+                self._compute(self.agents, self.rail)
 
-        if compute_distance_map:
-            self._compute(agents, rail)
+        elif self.distance_map is None:
+            self._compute(self.agents, self.rail)
 
-        self.agents_previous_reset = agents
+        return self.distance_map
+
+    """
+    Reset the distance map
+    """
+    def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        self.reset_was_called = True
+        self.agents = agents
+        self.rail = rail
 
     def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
-        self.distance_map_computed = True
+        self.agents_previous_computation = self.agents
         self.distance_map = np.inf * np.ones(shape=(len(agents),
                                                     self.env_height,
                                                     self.env_width,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 2d358a29..f5881836 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -276,7 +276,7 @@ class RailEnv(Environment):
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
         self.observation_space = self.obs_builder.observation_space  # <-- change on reset?
-        self.distance_map.compute(self.agents, self.rail)
+        self.distance_map.reset(self.agents, self.rail)
 
         # Return the new observation vectors for each agent
         return self._get_observations()
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 8918a585..4c925789 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -148,7 +148,6 @@ def tests_rail_from_file():
     assert agents_initial == agents_loaded
 
     # Check that distance map was not recomputed
-    assert env.distance_map.distance_map_computed is False
     assert np.shape(env.distance_map.get()) == dist_map_shape
     assert env.distance_map.get() is not None
 
@@ -222,6 +221,5 @@ def tests_rail_from_file():
     assert agents_initial_2 == agents_loaded_4
 
     # Check that distance map was generated with correct shape
-    assert env4.distance_map.distance_map_computed is True
     assert env4.distance_map.get() is not None
     assert np.shape(env4.distance_map.get()) == dist_map_shape
-- 
GitLab