Skip to content
Snippets Groups Projects
Commit 8eb24851 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

fix loading (CPU)

parent 5c8a88cf
No related branches found
No related tags found
No related merge requests found
...@@ -128,10 +128,10 @@ class DDDQNPolicy(Policy): ...@@ -128,10 +128,10 @@ class DDDQNPolicy(Policy):
def load(self, filename): def load(self, filename):
try: try:
if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"): if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) self.qnetwork_local.load_state_dict(torch.load(filename + ".local", map_location=self.device))
print("qnetwork_local loaded ('{}')".format(filename + ".local")) print("qnetwork_local loaded ('{}')".format(filename + ".local"))
if not self.evaluation_mode: if not self.evaluation_mode:
self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) self.qnetwork_target.load_state_dict(torch.load(filename + ".target", map_location=self.device))
print("qnetwork_target loaded ('{}' )".format(filename + ".target")) print("qnetwork_target loaded ('{}' )".format(filename + ".target"))
else: else:
print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local", print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local",
......
...@@ -26,14 +26,14 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy ...@@ -26,14 +26,14 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
VERBOSE = True VERBOSE = True
# Checkpoint to use (remember to push it!) # Checkpoint to use (remember to push it!)
checkpoint = "./checkpoints/201106234900-5400.pth" # 15.64082361736683 Depth 1 checkpoint = "./checkpoints/201111175340-5400.pth"
# Use last action cache # Use last action cache
USE_ACTION_CACHE = False USE_ACTION_CACHE = False
USE_DEAD_LOCK_AVOIDANCE_AGENT = False USE_DEAD_LOCK_AVOIDANCE_AGENT = False
# Observation parameters (must match training parameters!) # Observation parameters (must match training parameters!)
observation_tree_depth = 1 observation_tree_depth = 2
observation_radius = 10 observation_radius = 10
observation_max_path_depth = 30 observation_max_path_depth = 30
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment