From 8eb248515e6ca114a6ac623e7f3c17d1b50261ba Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 12 Nov 2020 08:11:27 +0100
Subject: [PATCH] Fix checkpoint loading on CPU (pass map_location to torch.load)

---
 reinforcement_learning/dddqn_policy.py | 4 ++--
 run.py                                 | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 0350a94..b34dd36 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -128,10 +128,10 @@ class DDDQNPolicy(Policy):
     def load(self, filename):
         try:
             if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
-                self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
+                self.qnetwork_local.load_state_dict(torch.load(filename + ".local", map_location=self.device))
                 print("qnetwork_local loaded ('{}')".format(filename + ".local"))
                 if not self.evaluation_mode:
-                    self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+                    self.qnetwork_target.load_state_dict(torch.load(filename + ".target", map_location=self.device))
                     print("qnetwork_target loaded ('{}' )".format(filename + ".target"))
             else:
                 print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local",
diff --git a/run.py b/run.py
index 0b7108f..8cb2630 100644
--- a/run.py
+++ b/run.py
@@ -26,14 +26,14 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201106234900-5400.pth"  # 15.64082361736683 Depth 1
+checkpoint = "./checkpoints/201111175340-5400.pth"
 
 # Use last action cache
 USE_ACTION_CACHE = False
 USE_DEAD_LOCK_AVOIDANCE_AGENT = False
 
 # Observation parameters (must match training parameters!)
-observation_tree_depth = 1
+observation_tree_depth = 2
 observation_radius = 10
 observation_max_path_depth = 30
 
-- 
GitLab