From c53739a97639afc831a594c2d3c5facc10f062f0 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Sun, 15 Nov 2020 20:07:22 +0100 Subject: [PATCH] refactored file name --- run.py | 2 +- utils/fast_tree_obs.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/run.py b/run.py index 143f44a..048dcde 100644 --- a/run.py +++ b/run.py @@ -27,7 +27,7 @@ VERBOSE = True # Checkpoint to use (remember to push it!) # checkpoint = "./checkpoints/201112143850-5400.pth" # 21.220418678677177 DEPTH=2 AGENTS=10 -checkpoint = "./checkpoints/201113211844-6200.pth" # 19.690047767961005 DEPTH=2 AGENTS=20 +checkpoint = "./checkpoints/201113211844-6700.pth" # 19.690047767961005 DEPTH=2 AGENTS=20 # Use last action cache diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py index 0666ef4..c388d2a 100755 --- a/utils/fast_tree_obs.py +++ b/utils/fast_tree_obs.py @@ -1,3 +1,5 @@ +from typing import List, Optional + import numpy as np from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4_utils import get_new_position @@ -232,6 +234,12 @@ class FastTreeObs(ObservationBuilder): return has_opp_agent, has_same_agent, has_target, visited + def get_many(self, handles: Optional[List[int]] = None): + self.dead_lock_avoidance_agent.start_step() + observations = super().get_many(handles) + self.dead_lock_avoidance_agent.end_step() + return observations + def get(self, handle): # all values are [0,1] # observation[0] : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path @@ -262,9 +270,6 @@ class FastTreeObs(ObservationBuilder): # observation[25] : If there is a switch on the path which agent can not use -> 1 # observation[26] : If there the dead-lock avoidance agent predicts a deadlock -> 1 - if handle == 0: - self.dead_lock_avoidance_agent.start_step() - observation = np.zeros(self.observation_dim) visited = [] agent = self.env.agents[handle] -- GitLab