diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py index 9a43affc9993151a2b7ca39f372b1410d04a570a..8a734c7626cf315d86c0463ea38383a28c75d31d 100644 --- a/reinforcement_learning/ppo_agent.py +++ b/reinforcement_learning/ppo_agent.py @@ -85,7 +85,7 @@ class ActorCriticModel(nn.Module): return obj def load(self, filename): - print("load policy from file", filename) + print("load model from file", filename) self.actor = self._load(self.actor, filename + ".actor") self.critic = self._load(self.critic, filename + ".value") @@ -284,6 +284,8 @@ class PPOPolicy(LearningPolicy): obj.load_state_dict(torch.load(filename, map_location=self.device)) except: print(" >> failed!") + else: + print(" >> file not found!") return obj def load(self, filename):