From 2b40ef62209d3ca0bc94ac85bb41a58a19c5779a Mon Sep 17 00:00:00 2001
From: Giacomo Spigler <spiglerg@gmail.com>
Date: Tue, 11 Jun 2019 11:24:01 +0200
Subject: [PATCH] fixed issue #60

---
 examples/simple_example_3.py  |  6 +++---
 flatland/envs/observations.py | 31 ++++++++++++++++++++++---------
 2 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index 41fd5a31..1661ef65 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -7,8 +7,8 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
-random.seed(10)
-np.random.seed(10)
+random.seed(1)
+np.random.seed(1)
 
 env = RailEnv(width=7,
               height=7,
@@ -19,7 +19,7 @@ env = RailEnv(width=7,
 # Print the observation vector for agent 0
 obs, all_rewards, done, _ = env.step({0: 0})
 for i in range(env.get_num_agents()):
-    env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
+    env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=7)
 
 env_renderer = RenderTool(env, gl="PIL")
 env_renderer.renderEnv(show=True, frames=True)
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 5cc3f26d..2214544a 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -31,19 +31,32 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
 
+        self.agents_previous_reset = None
+
     def reset(self):
         agents = self.env.agents
         nAgents = len(agents)
-        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
-                                                    self.env.height,
-                                                    self.env.width,
-                                                    4))
-        self.max_dist = np.zeros(nAgents)
-
-        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
 
-        # Update local lookup table for all agents' target locations
-        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
+        compute_distance_map = True
+        if self.agents_previous_reset is not None:
+            if nAgents == len(self.agents_previous_reset):
+                compute_distance_map = False
+                for i in range(nAgents):
+                    if agents[i].target != self.agents_previous_reset[i].target:
+                        compute_distance_map = True
+        self.agents_previous_reset = agents
+
+        if compute_distance_map:
+            self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
+                                                        self.env.height,
+                                                        self.env.width,
+                                                        4))
+            self.max_dist = np.zeros(nAgents)
+
+            self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
+
+            # Update local lookup table for all agents' target locations
+            self.location_has_target = {tuple(agent.target): 1 for agent in agents}
 
     def _distance_map_walker(self, position, target_nr):
         """
-- 
GitLab