diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py index 0350a9466746b6d4710921ff5195b40ccbf83d12..b34dd36dc1b8fb1d4ec373cd87bdc5371cfe74e2 100644 --- a/reinforcement_learning/dddqn_policy.py +++ b/reinforcement_learning/dddqn_policy.py @@ -128,10 +128,10 @@ class DDDQNPolicy(Policy): def load(self, filename): try: if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"): - self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) + self.qnetwork_local.load_state_dict(torch.load(filename + ".local", map_location=self.device)) print("qnetwork_local loaded ('{}')".format(filename + ".local")) if not self.evaluation_mode: - self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) + self.qnetwork_target.load_state_dict(torch.load(filename + ".target", map_location=self.device)) print("qnetwork_target loaded ('{}' )".format(filename + ".target")) else: print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local", diff --git a/run.py b/run.py index 0b7108fe2fe930fae71f07b1784b96b721e97fd9..8cb2630449d3b6442cb0aed3c0e3b1b31fb14484 100644 --- a/run.py +++ b/run.py @@ -26,14 +26,14 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy VERBOSE = True # Checkpoint to use (remember to push it!) -checkpoint = "./checkpoints/201106234900-5400.pth" # 15.64082361736683 Depth 1 +checkpoint = "./checkpoints/201111175340-5400.pth" # Use last action cache USE_ACTION_CACHE = False USE_DEAD_LOCK_AVOIDANCE_AGENT = False # Observation parameters (must match training parameters!) -observation_tree_depth = 1 +observation_tree_depth = 2 observation_radius = 10 observation_max_path_depth = 30