ppo_agent.py (forked from adrian_egli / neurips2020-flatland-starter-kit)
import os
import numpy as np
import torch
from torch.distributions.categorical import Categorical
from reinforcement_learning.policy import Policy
from reinforcement_learning.ppo.model import PolicyNetwork
from reinforcement_learning.ppo.replay_memory import Episode, ReplayBuffer
# Hyperparameters
BUFFER_SIZE = 128_000   # maximum number of transitions kept in the replay buffer
BATCH_SIZE = 8192       # transitions sampled per learning step
GAMMA = 0.95            # discount factor used when discounting episode rewards
LR = 0.5e-4             # Adam learning rate
CLIP_FACTOR = .005      # PPO clipping range for the probability ratio
UPDATE_EVERY = 30       # environment steps between gradient updates

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device)
class PPOAgent(Policy):
    """PPO policy shared by all agents, with one episode buffer per agent handle."""

    def __init__(self, state_size, action_size, num_agents):
        self.action_size = action_size
        self.state_size = state_size
        self.num_agents = num_agents
        self.policy = PolicyNetwork(state_size, action_size).to(device)
        self.old_policy = PolicyNetwork(state_size, action_size).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR)
        self.episodes = [Episode() for _ in range(num_agents)]
        self.memory = ReplayBuffer(BUFFER_SIZE)
        self.t_step = 0
        self.loss = 0
    def reset(self):
        self.finished = [False] * len(self.episodes)
        self.tot_reward = [0] * self.num_agents
    # Decide on an action to take in the environment
    def act(self, state, eps=None):
        if eps is not None:
            # Epsilon-greedy action selection
            if np.random.random() < eps:
                return np.random.choice(np.arange(self.action_size))

        self.policy.eval()
        with torch.no_grad():
            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
            ret = Categorical(output).sample().item()
            return ret
    # Record the results of the agent's action and update the model
    def step(self, handle, state, action, reward, next_state, done):
        if not self.finished[handle]:
            # Push experience into Episode memory
            self.tot_reward[handle] += reward
            if done == 1:
                reward = 1  # self.tot_reward[handle]
            else:
                reward = 0
            self.episodes[handle].push(state, action, reward, next_state, done)

            # When we finish the episode, discount rewards and push the experience into replay memory
            if done:
                self.episodes[handle].discount_rewards(GAMMA)
                self.memory.push_episode(self.episodes[handle])
                self.episodes[handle].reset()
                self.finished[handle] = True

        # Perform a gradient update every UPDATE_EVERY time steps
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4:
            self._learn(*self.memory.sample(BATCH_SIZE, device))
    def _clip_gradient(self, model, clip):
        """Clamp every gradient element of `model` to the range [-clip, clip]."""
        for p in model.parameters():
            if p.grad is not None:
                p.grad.data.clamp_(-clip, clip)

    def _clip_gradient_by_norm(self, model, clip):
        """Rescale gradients by a coefficient computed from the total gradient norm."""
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                module_norm = p.grad.data.norm()
                total_norm += float(module_norm) ** 2
        total_norm = np.sqrt(total_norm)
        coeff = min(1, clip / (total_norm + 1e-6))
        for p in model.parameters():
            if p.grad is not None:
                p.grad.mul_(coeff)
    def _learn(self, states, actions, rewards, next_state, done):
        self.policy.train()

        # Probability assigned to the taken actions by the current and the old policy
        responsible_outputs = torch.gather(self.policy(states), 1, actions)
        old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach()

        # Clipped surrogate loss on the probability ratio
        # rewards = rewards - rewards.mean()
        ratio = responsible_outputs / (old_responsible_outputs + 1e-5)
        clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
        loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean()
        self.loss = loss

        # Snapshot the current policy as the "old" policy, then take a gradient step
        self.old_policy.load_state_dict(self.policy.state_dict())
        self.optimizer.zero_grad()
        loss.backward()
        # self._clip_gradient(self.policy, 1.0)
        self.optimizer.step()
    # Checkpointing methods
    def save(self, filename):
        # print("Saving model from checkpoint:", filename)
        torch.save(self.policy.state_dict(), filename + ".policy")
        torch.save(self.optimizer.state_dict(), filename + ".optimizer")

    def load(self, filename):
        print("load policy from file", filename)
        if os.path.exists(filename + ".policy"):
            print(' >> ', filename + ".policy")
            try:
                self.policy.load_state_dict(torch.load(filename + ".policy", map_location=device))
            except Exception:
                print(" >> failed!")
        if os.path.exists(filename + ".optimizer"):
            print(' >> ', filename + ".optimizer")
            try:
                self.optimizer.load_state_dict(torch.load(filename + ".optimizer", map_location=device))
            except Exception:
                print(" >> failed!")
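
# A minimal sketch of how this agent might be driven, assuming the repo's
# PolicyNetwork / ReplayBuffer imports above resolve. The "environment" here is a
# hypothetical stand-in that emits random observations, not Flatland itself, and
# state_size is an assumed observation width that depends on the observation builder.
if __name__ == "__main__":
    state_size = 231   # assumed tree-observation width (depends on the obs builder)
    action_size = 5    # Flatland's RailEnv exposes 5 discrete actions
    num_agents = 2

    agent = PPOAgent(state_size, action_size, num_agents)
    agent.reset()

    for t in range(100):
        for handle in range(num_agents):
            # Random stand-in observations; a real loop would query the environment
            state = np.random.rand(state_size).astype(np.float32)
            action = agent.act(state, eps=0.1)
            next_state = np.random.rand(state_size).astype(np.float32)
            reward = -1.0
            done = 1 if t == 99 else 0
            agent.step(handle, state, action, reward, next_state, done)

    # Writes ppo_demo.policy and ppo_demo.optimizer next to the script
    agent.save("ppo_demo")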