diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py index a7431f85201def6f189ccdc6101a89428b598e47..350119a225dff9feef6f8ab0589e476126f4ac2b 100644 --- a/reinforcement_learning/ppo/ppo_agent.py +++ b/reinforcement_learning/ppo/ppo_agent.py @@ -39,6 +39,11 @@ class PPOAgent(Policy): # Decide on an action to take in the environment def act(self, state, eps=None): + if eps is not None: + # Epsilon-greedy action selection + if np.random.random() < eps: + return np.random.choice(np.arange(self.action_size)) + self.policy.eval() with torch.no_grad(): output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))