From c53739a97639afc831a594c2d3c5facc10f062f0 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Sun, 15 Nov 2020 20:07:22 +0100
Subject: [PATCH] refactored file name

---
 run.py                 |  2 +-
 utils/fast_tree_obs.py | 11 ++++++++---
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/run.py b/run.py
index 143f44a..048dcde 100644
--- a/run.py
+++ b/run.py
@@ -27,7 +27,7 @@ VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
 # checkpoint = "./checkpoints/201112143850-5400.pth" # 21.220418678677177 DEPTH=2 AGENTS=10
-checkpoint = "./checkpoints/201113211844-6200.pth" # 19.690047767961005 DEPTH=2 AGENTS=20
+checkpoint = "./checkpoints/201113211844-6700.pth" # 19.690047767961005 DEPTH=2 AGENTS=20
 
 
 # Use last action cache
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 0666ef4..c388d2a 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -1,3 +1,5 @@
+from typing import List, Optional
+
 import numpy as np
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
@@ -232,6 +234,12 @@ class FastTreeObs(ObservationBuilder):
 
         return has_opp_agent, has_same_agent, has_target, visited
 
+    def get_many(self, handles: Optional[List[int]] = None):
+        self.dead_lock_avoidance_agent.start_step()
+        observations = super().get_many(handles)
+        self.dead_lock_avoidance_agent.end_step()
+        return observations
+
     def get(self, handle):
         # all values are [0,1]
         # observation[0]  : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
@@ -262,9 +270,6 @@ class FastTreeObs(ObservationBuilder):
         # observation[25] : If there is a switch on the path which agent can not use -> 1
         # observation[26] : If there the dead-lock avoidance agent predicts a deadlock -> 1
 
-        if handle == 0:
-            self.dead_lock_avoidance_agent.start_step()
-
         observation = np.zeros(self.observation_dim)
         visited = []
         agent = self.env.agents[handle]
-- 
GitLab