Commit 0d488563 authored by Dipam Chakraborty

torch.distributions & aug in aux phase

parent 925d66aa
@@ -9,6 +9,7 @@ from .utils import *
import time
torch, nn = try_import_torch()
import torch.distributions as td

class CustomTorchPolicy(TorchPolicy):
    """Example of a random policy
@@ -84,6 +85,7 @@ class CustomTorchPolicy(TorchPolicy):
        self.ent_coef = config['entropy_coeff']
        self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
        self.make_distr = dist_build(action_space)

    def to_tensor(self, arr):
        return torch.from_numpy(arr).to(self.device)
@@ -135,8 +137,6 @@
            end = start + nbatch_train
            values[start:end], _ = self.model.vf_pi(samples['obs'][start:end], ret_numpy=True, no_grad=True, to_torch=True)

        ## GAE
        mb_values = unroll(values, ts)
        mb_returns = np.zeros_like(mb_rewards)
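The `## GAE` section is only partially visible in this hunk. For orientation, below is a minimal sketch of the standard generalized advantage estimation recursion that code like this typically implements; `gamma`, `lam`, `mb_dones`, and `last_values` are assumed names for illustration, not identifiers taken from this repository.

# Minimal GAE sketch; gamma, lam, mb_dones and last_values are assumed
# names, not necessarily those used in this codebase.
import numpy as np

def gae_returns(mb_rewards, mb_values, mb_dones, last_values, gamma=0.99, lam=0.95):
    nsteps = len(mb_rewards)
    mb_advs = np.zeros_like(mb_rewards)
    lastgaelam = 0.0
    for t in reversed(range(nsteps)):
        nextvalues = last_values if t == nsteps - 1 else mb_values[t + 1]
        nextnonterminal = 1.0 - mb_dones[t]
        delta = mb_rewards[t] + gamma * nextvalues * nextnonterminal - mb_values[t]
        lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
        mb_advs[t] = lastgaelam
    return mb_advs + mb_values  # returns = advantages + value baseline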
@@ -205,12 +205,10 @@
        for g in self.optimizer.param_groups:
            g['lr'] = lr
        # Advantages are normalized with the full-size batch instead of the memory-limited batch
        # advs = returns - values
        # advs = (advs - torch.mean(advs)) / (torch.std(advs) + 1e-8)
        vpred, pi_logits = self.model.vf_pi(obs, ret_numpy=False, no_grad=False, to_torch=False)
        neglogpac = neglogp_actions(pi_logits, actions)
        entropy = torch.mean(pi_entropy(pi_logits))
        pd = self.make_distr(pi_logits)
        neglogpac = -pd.log_prob(actions[..., None]).squeeze(1)
        entropy = torch.mean(pd.entropy())
        vf_loss = .5 * torch.mean(torch.pow((vpred - returns), 2))
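This hunk swaps the hand-rolled `neglogp_actions`/`pi_entropy` helpers for a `torch.distributions` object built by `self.make_distr`. For orientation, a sketch of how these quantities typically combine into the PPO clipped surrogate; `neglogpac_old`, `advs`, `cliprange`, and `vf_coef` are assumed inputs not shown in this diff.

# Hypothetical assembly of the PPO clipped objective from the pieces above;
# neglogpac_old, advs, cliprange and vf_coef are assumed, not from this diff.
ratio = torch.exp(neglogpac_old - neglogpac)         # pi_new(a|s) / pi_old(a|s)
pg_loss1 = -advs * ratio
pg_loss2 = -advs * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = torch.mean(torch.max(pg_loss1, pg_loss2))  # clipped policy loss
loss = pg_loss - self.ent_coef * entropy + vf_coef * vf_loss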
@@ -262,23 +260,22 @@
            self.retune_selector.retune_done()

    def tune_policy(self, obs, target_vf, target_pi):
        # obs_aug = np.empty(obs.shape, obs.dtype)
        # aug_idx = np.random.randint(3, size=len(obs))
        # obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)
        # obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)
        # obs_aug[aug_idx == 2] = obs[aug_idx == 2]
        # obs_aug = self.to_tensor(obs_aug)
        obs = self.to_tensor(obs)
        with torch.no_grad():
            tpi_log_softmax = nn.functional.log_softmax(target_pi, dim=1)
            tpi_softmax = torch.exp(tpi_log_softmax)
        vpred, pi_logits = self.model.vf_pi(obs, ret_numpy=False, no_grad=False, to_torch=False)
        obs_aug = np.empty(obs.shape, obs.dtype)
        aug_idx = np.random.randint(6, size=len(obs))
        obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)
        obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)
        obs_aug[aug_idx >= 2] = obs[aug_idx >= 2]
        obs_aug = self.to_tensor(obs_aug)
        vpred, pi_logits = self.model.vf_pi(obs_aug, ret_numpy=False, no_grad=False, to_torch=False)
        aux_vpred = self.model.aux_value_function()
        pi_log_softmax = nn.functional.log_softmax(pi_logits, dim=1)
        pi_loss = torch.mean(torch.sum(tpi_softmax * (tpi_log_softmax - pi_log_softmax), dim=1))  # kl_div torch 1.3.1 has numerical issues
        vf_loss = .5 * torch.mean(torch.pow(vpred - target_vf, 2))
        aux_loss = .5 * torch.mean(torch.pow(aux_vpred - target_vf, 2))
        target_pd = self.make_distr(target_pi)
        pd = self.make_distr(pi_logits)
        pi_loss = td.kl_divergence(target_pd, pd).mean()
        loss = vf_loss + pi_loss + aux_loss
        loss.backward()
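Two things change in `tune_policy`: the image-augmentation block is enabled (each observation is, with probability 1/6, pad-and-random-cropped, with probability 1/6 given a random color cutout, and otherwise passed through unchanged), and the manual KL term (the `tpi_softmax * (tpi_log_softmax - pi_log_softmax)` line, kept above with its note that `kl_div` was numerically unreliable in torch 1.3.1) is replaced by `td.kl_divergence` on Categorical distributions. Both compute KL(target ‖ current); a small sanity check with made-up logits:

# Sanity check with hypothetical logits: td.kl_divergence on Categoricals
# matches the manual sum(p_t * (log p_t - log p)) form this commit removes.
import torch
import torch.distributions as td

target_logits = torch.randn(8, 15)
logits = torch.randn(8, 15)

manual = torch.sum(
    torch.softmax(target_logits, dim=1)
    * (torch.log_softmax(target_logits, dim=1) - torch.log_softmax(logits, dim=1)),
    dim=1,
)
dist_kl = td.kl_divergence(td.Categorical(logits=target_logits),
                           td.Categorical(logits=logits))
assert torch.allclose(manual, dist_kl, atol=1e-5)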
@@ -4,6 +4,15 @@ from collections import deque
from skimage.util import view_as_windows
torch, nn = try_import_torch()
import torch.distributions as td
from functools import partial

def _make_categorical(x, ncat, shape):
    x = x.reshape((x.shape[0], shape, ncat))
    return td.Categorical(logits=x)

def dist_build(ac_space):
    return partial(_make_categorical, shape=1, ncat=ac_space.n)

def neglogp_actions(pi_logits, actions):
    return nn.functional.cross_entropy(pi_logits, actions, reduction='none')
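`dist_build` returns a factory that reshapes raw logits to `(batch, 1, ncat)` and wraps them in a `td.Categorical`, giving the distribution a trailing batch dimension of 1; that is why the policy code above calls `pd.log_prob(actions[..., None]).squeeze(1)`. A usage sketch with a hypothetical `Discrete(15)` action space:

# Hypothetical usage of dist_build; the Discrete(15) space and tensor shapes
# are illustrative, not taken from this repository's configs.
import gym
import torch

ac_space = gym.spaces.Discrete(15)
make_distr = dist_build(ac_space)

pi_logits = torch.randn(8, 15)            # raw policy logits for a batch of 8
actions = torch.randint(0, 15, (8,))

pd = make_distr(pi_logits)                # Categorical with batch shape (8, 1)
neglogpac = -pd.log_prob(actions[..., None]).squeeze(1)

# Agrees with the cross-entropy helper kept above:
assert torch.allclose(neglogpac, neglogp_actions(pi_logits, actions), atol=1e-6)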