# ppo_agent.py
import os

import numpy as np
import torch
from torch.distributions.categorical import Categorical

from reinforcement_learning.policy import Policy
from reinforcement_learning.ppo.model import PolicyNetwork
from reinforcement_learning.ppo.replay_memory import Episode, ReplayBuffer

# Hyperparameters used by PPOAgent below.

# Passed to the ReplayBuffer constructor (presumably its capacity in
# transitions -- confirm against replay_memory.ReplayBuffer).
BUFFER_SIZE = 128_000
# Number of samples requested from the replay buffer per learning step;
# learning only starts once the buffer holds more than 4 * BATCH_SIZE entries.
BATCH_SIZE = 8192
# Discount factor handed to Episode.discount_rewards().
GAMMA = 0.95
# Learning rate of the Adam optimizer.
LR = 0.5e-4
# The probability ratio is clamped to [1 - CLIP_FACTOR, 1 + CLIP_FACTOR].
CLIP_FACTOR = .005
# A learning step is attempted once every UPDATE_EVERY calls to PPOAgent.step().
UPDATE_EVERY = 30

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device)


class PPOAgent(Policy):
    """Multi-agent policy trained with a clipped probability-ratio loss.

    Each agent handle owns an ``Episode`` that collects its transitions.
    When an agent's episode ends, the episode rewards are discounted with
    ``GAMMA`` and the episode is pushed into a shared ``ReplayBuffer``.
    Every ``UPDATE_EVERY`` calls to :meth:`step` a batch is sampled from
    that buffer and one optimizer step is taken on ``self.policy``.

    Attributes:
        policy: Network being optimized and used for acting.
        old_policy: Snapshot of ``policy`` taken just before the most recent
            optimizer step; it is the denominator of the probability ratio.
        loss: Loss of the most recent learning step (``0`` before the first).
    """

    def __init__(self, state_size, action_size, num_agents):
        """Build the networks, optimizer and per-agent episode storage.

        Args:
            state_size: Forwarded to ``PolicyNetwork`` as its first argument.
            action_size: Number of discrete actions; also forwarded to
                ``PolicyNetwork``.
            num_agents: Number of agent handles that will call :meth:`step`.
        """
        self.action_size = action_size
        self.state_size = state_size
        self.num_agents = num_agents
        self.policy = PolicyNetwork(state_size, action_size).to(device)
        self.old_policy = PolicyNetwork(state_size, action_size).to(device)
        # Both networks are initialised independently; start the snapshot
        # from the same weights so the first probability ratio is meaningful
        # instead of comparing two unrelated random networks.
        self.old_policy.load_state_dict(self.policy.state_dict())
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR)
        self.episodes = [Episode() for _ in range(num_agents)]
        self.memory = ReplayBuffer(BUFFER_SIZE)
        self.t_step = 0
        self.loss = 0
        # Per-episode bookkeeping; set here as well so that step() is usable
        # even when the caller has not invoked reset() yet.
        self.finished = [False] * num_agents
        self.tot_reward = [0] * num_agents

    def reset(self):
        """Clear the per-agent ``finished`` flags and accumulated rewards."""
        self.finished = [False] * len(self.episodes)
        self.tot_reward = [0] * self.num_agents

    # Decide on an action to take in the environment

    def act(self, state, eps=None):
        """Choose an action for ``state``.

        Args:
            state: NumPy array holding a single observation (it is converted
                with ``torch.from_numpy`` and given a leading batch axis).
            eps: Optional exploration probability. When given, a uniformly
                random action is returned with probability ``eps``.

        Returns:
            The index of the selected action.
        """
        # Epsilon-greedy exploration; the random draw only happens when an
        # epsilon was supplied.
        if eps is not None and np.random.random() < eps:
            return np.random.choice(np.arange(self.action_size))

        self.policy.eval()
        with torch.no_grad():
            state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(device)
            # The network output is handed to Categorical as its `probs`
            # argument and an action is sampled from that distribution.
            action_probs = self.policy(state_tensor)
            return Categorical(action_probs).sample().item()

    # Record the results of the agent's action and update the model
    def step(self, handle, state, action, reward, next_state, done):
        """Store one transition for agent ``handle`` and maybe learn.

        Args:
            handle: Index of the agent in the per-agent lists.
            state: Observation the action was taken in.
            action: Action that was taken.
            reward: Environment reward. It is only accumulated in
                ``tot_reward``; the reward actually stored is 1 on the
                terminal transition (``done == 1``) and 0 otherwise.
            next_state: Observation after the action.
            done: Truthy when the agent's episode has ended.
        """
        if not self.finished[handle]:
            # Push experience into Episode memory
            self.tot_reward[handle] += reward
            # Sparse signal: only the terminal transition carries a reward.
            reward = 1 if done == 1 else 0

            episode = self.episodes[handle]
            episode.push(state, action, reward, next_state, done)

            # When we finish the episode, discount rewards and push the
            # experience into replay memory
            if done:
                episode.discount_rewards(GAMMA)
                self.memory.push_episode(episode)
                episode.reset()
                self.finished[handle] = True

        # Perform a gradient update every UPDATE_EVERY time steps. The counter
        # advances on every call, including calls for finished agents.
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4:
            self._learn(*self.memory.sample(BATCH_SIZE, device))

    def _clip_gradient(self, model, clip):
        """Clamp every gradient element of ``model`` to ``[-clip, clip]`` in place.

        Parameters that have no gradient (``grad is None``) are skipped.
        For norm-based clipping use ``torch.nn.utils.clip_grad_norm_`` instead.
        """
        for param in model.parameters():
            if param.grad is not None:
                param.grad.clamp_(-clip, clip)

    def _learn(self, states, actions, rewards, next_state, done):
        """Take one optimizer step on the clipped surrogate loss.

        Args:
            states: Batch of observations fed to both networks.
            actions: Indices used to gather along dimension 1 of the network
                output (so they must be an integer tensor with the same
                number of dimensions as that output).
            rewards: Discounted rewards weighting the probability ratio.
            next_state: Unused; part of the replay-sample layout.
            done: Unused; part of the replay-sample layout.
        """
        self.policy.train()

        # Probability each network assigns to the action that was taken.
        new_probs = torch.gather(self.policy(states), 1, actions)
        # The snapshot network is a constant in the loss; no graph is needed.
        with torch.no_grad():
            old_probs = torch.gather(self.old_policy(states), 1, actions)

        # The small constant guards against division by zero.
        ratio = new_probs / (old_probs + 1e-5)
        clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
        loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean()
        # Keep the value for reporting without holding on to the autograd graph.
        self.loss = loss.detach()

        # Snapshot the current weights *before* they are updated, so the next
        # call compares the updated policy against this pre-update version.
        self.old_policy.load_state_dict(self.policy.state_dict())
        self.optimizer.zero_grad()
        loss.backward()
        # NOTE: gradient clipping via self._clip_gradient is currently not applied.
        self.optimizer.step()

    # Checkpointing methods
    def save(self, filename):
        """Write the policy and optimizer state to ``filename + ".policy"``
        and ``filename + ".optimizer"``."""
        torch.save(self.policy.state_dict(), filename + ".policy")
        torch.save(self.optimizer.state_dict(), filename + ".optimizer")

    def load(self, filename):
        """Best-effort restore of the state written by :meth:`save`.

        Missing files are skipped silently; files that fail to load are
        reported on stdout and otherwise ignored.
        """
        print("load policy from file", filename)
        self._load_state(self.policy, filename + ".policy")
        self._load_state(self.optimizer, filename + ".optimizer")

    @staticmethod
    def _load_state(target, path):
        """Load the state dict stored at ``path`` into ``target`` if the file exists."""
        if not os.path.exists(path):
            return
        print(' >> ', path)
        try:
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # that come from a trusted source.
            target.load_state_dict(torch.load(path, map_location=device))
        except Exception as exc:
            # Deliberately non-fatal, but report why the checkpoint was rejected.
            print(" >> failed!", exc)