Commit 6d45b551 authored by Dipam Chakraborty

ppg initial commit: add an auxiliary value head and optimizer, and a PPG-style auxiliary training phase (aux_train) that replaces the augmentation-based retune

parent ade85808
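
In brief, the commit swaps the augmentation-based retune pass for a PPG-style auxiliary phase: the policy head is pulled back toward its pre-phase logits with a KL term while the value and auxiliary-value heads regress stored value targets. A minimal, hedged sketch of that joint objective, matching the loss = vf_loss + pi_loss + aux_loss line added in tune_policy below (function name and shapes here are illustrative, not part of the commit):

# Hedged sketch of the auxiliary-phase objective introduced by this commit:
#   0.5 * MSE(v, v_targ) + KL(pi_old || pi) + 0.5 * MSE(v_aux, v_targ)
import torch
import torch.nn.functional as F

def joint_aux_loss(vpred, aux_vpred, target_vf, pi_logits, target_pi_logits):
    tpi_log = F.log_softmax(target_pi_logits, dim=1)   # frozen pre-phase policy
    pi_log = F.log_softmax(pi_logits, dim=1)           # current policy
    pi_loss = torch.mean(torch.sum(torch.exp(tpi_log) * (tpi_log - pi_log), dim=1))
    vf_loss = 0.5 * torch.mean((vpred - target_vf) ** 2)
    aux_loss = 0.5 * torch.mean((aux_vpred - target_vf) ** 2)
    return vf_loss + pi_loss + aux_loss
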
@@ -44,7 +44,8 @@ class CustomTorchPolicy(TorchPolicy):
)
self.framework = "torch"
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
self.aux_optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
self.max_reward = self.config['env_config']['return_max']
self.rewnorm = RewardNormalizer(cliprew=self.max_reward) ## TODO: Might need to go to custom state
self.reward_deque = deque(maxlen=100)
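
The auxiliary phase gets its own Adam instance over the same parameters, so its moment estimates never mix with the policy-phase optimizer's. A small hedged sketch of that setup (toy module, not the repository's model):

# Hedged sketch: two Adam instances over the same parameters keep independent
# first/second-moment state, so auxiliary-phase steps do not perturb the
# policy-phase optimizer statistics, and vice versa.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                                        # stand-in for the policy network
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)      # policy phase
aux_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)  # auxiliary phase, separate state
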
@@ -124,14 +125,6 @@ class CustomTorchPolicy(TorchPolicy):
return {} # Not doing last optimization step - This is intentional due to noisy gradients
obs = samples['obs']
## Distill with augmentation
should_retune = self.retune_selector.update(obs)
if should_retune:
self.retune_with_augmentation(obs)
self.update_batch_time()
return {}
## Value prediction
next_obs = unroll(samples['new_obs'], ts)[-1]
@@ -190,6 +183,13 @@ class CustomTorchPolicy(TorchPolicy):
self._batch_train(apply_grad, self.accumulate_train_batches,
lrnow, cliprange, vfcliprange, max_grad_norm, ent_coef, vf_coef, *slices)
## Distill with aux head
should_retune = self.retune_selector.update(obs, returns)
if should_retune:
self.aux_train()
self.update_batch_time()
return {}
self.update_gamma(samples)
self.update_lr()
self.update_ent_coef()
@@ -215,10 +215,7 @@ class CustomTorchPolicy(TorchPolicy):
neglogpac = neglogp_actions(pi_logits, actions)
entropy = torch.mean(pi_entropy(pi_logits))
vpredclipped = values + torch.clamp(vpred - values, -vfcliprange, vfcliprange)
vf_losses1 = torch.pow((vpred - returns), 2)
vf_losses2 = torch.pow((vpredclipped - returns), 2)
vf_loss = .5 * torch.mean(torch.max(vf_losses1, vf_losses2))
vf_loss = .5 * torch.mean(torch.pow((vpred - returns), 2))
ratio = torch.exp(neglogpac_old - neglogpac)
pg_losses1 = -advs * ratio
@@ -231,64 +228,65 @@ class CustomTorchPolicy(TorchPolicy):
loss.backward()
if apply_grad:
nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
def retune_with_augmentation(self, obs):
def aux_train(self):
for g in self.aux_optimizer.param_groups:
g['lr'] = self.lr
nbatch_train = self.mem_limited_batch_size
aux_nbatch_train = self.config['aux_mbsize']
retune_epochs = self.config['retune_epochs']
replay_size = self.retune_selector.replay_size
replay_vf = np.empty((replay_size,), dtype=np.float32)
replay_pi = np.empty((replay_size, self.retune_selector.ac_space.n), dtype=np.float32)
# Store current value function and policy logits
for start in range(0, replay_size, nbatch_train):
end = start + nbatch_train
replay_batch = self.retune_selector.exp_replay[start:end]
replay_vf[start:end], replay_pi[start:end] = self.model.vf_pi(replay_batch,
ret_numpy=True, no_grad=True, to_torch=True)
_, replay_pi[start:end] = self.model.vf_pi(replay_batch,
ret_numpy=True, no_grad=True, to_torch=True)
optim_count = 0
# Tune vf and pi heads to older predictions with augmented observations
inds = np.arange(len(self.retune_selector.exp_replay))
for ep in range(retune_epochs):
np.random.shuffle(inds)
for start in range(0, replay_size, nbatch_train):
end = start + nbatch_train
for start in range(0, replay_size, aux_nbatch_train):
end = start + aux_nbatch_train
mbinds = inds[start:end]
optim_count += 1
apply_grad = (optim_count % self.accumulate_train_batches) == 0
slices = [self.retune_selector.exp_replay[mbinds],
self.to_tensor(replay_vf[mbinds]),
self.to_tensor(self.retune_selector.vtarg_replay[mbinds]),
self.to_tensor(replay_pi[mbinds])]
self.tune_policy(apply_grad, *slices, 0.5)
self.tune_policy(*slices)
self.retune_selector.retune_done()
def tune_policy(self, apply_grad, obs, target_vf, target_pi, retune_vf_loss_coeff):
obs_aug = np.empty(obs.shape, obs.dtype)
aug_idx = np.random.randint(3, size=len(obs))
obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)
obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)
obs_aug[aug_idx == 2] = obs[aug_idx == 2]
obs_aug = self.to_tensor(obs_aug)
def tune_policy(self, obs, target_vf, target_pi):
# obs_aug = np.empty(obs.shape, obs.dtype)
# aug_idx = np.random.randint(3, size=len(obs))
# obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)
# obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)
# obs_aug[aug_idx == 2] = obs[aug_idx == 2]
# obs_aug = self.to_tensor(obs_aug)
obs = self.to_tensor(obs)
with torch.no_grad():
tpi_log_softmax = nn.functional.log_softmax(target_pi, dim=1)
tpi_softmax = torch.exp(tpi_log_softmax)
vpred, pi_logits = self.model.vf_pi(obs_aug, ret_numpy=False, no_grad=False, to_torch=False)
vpred, pi_logits = self.model.vf_pi(obs, ret_numpy=False, no_grad=False, to_torch=False)
aux_vpred = self.model.aux_value_function()
pi_log_softmax = nn.functional.log_softmax(pi_logits, dim=1)
pi_loss = torch.mean(torch.sum(tpi_softmax * (tpi_log_softmax - pi_log_softmax) , dim=1)) # kl_div torch 1.3.1 has numerical issues
vf_loss = .5 * torch.mean(torch.pow(vpred - target_vf, 2))
aux_loss = .5 * torch.mean(torch.pow(aux_vpred - target_vf, 2))
loss = retune_vf_loss_coeff * vf_loss + pi_loss
loss = loss / self.accumulate_train_batches
loss = vf_loss + pi_loss + aux_loss
loss.backward()
if apply_grad:
self.optimizer.step()
self.optimizer.zero_grad()
self.aux_optimizer.step()
self.aux_optimizer.zero_grad()
def best_reward_model_select(self, samples):
self.timesteps_total += self.nbatch
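
The KL term in tune_policy is written out by hand because, as the in-line comment notes, kl_div in torch 1.3.1 has numerical issues. On a newer PyTorch the same quantity is available directly; a hedged equivalence check (batch size and action count below are made up):

# Hedged sketch: on recent PyTorch, F.kl_div with log-probability input and
# reduction="batchmean" matches the manual KL(pi_old || pi) computed in
# tune_policy above.
import torch
import torch.nn.functional as F

target_pi = torch.randn(8, 15)   # pre-phase logits (action count assumed)
pi_logits = torch.randn(8, 15)   # current logits

tpi_log = F.log_softmax(target_pi, dim=1)
pi_log = F.log_softmax(pi_logits, dim=1)
manual_kl = torch.mean(torch.sum(torch.exp(tpi_log) * (tpi_log - pi_log), dim=1))
builtin_kl = F.kl_div(pi_log, torch.exp(tpi_log), reduction="batchmean")
assert torch.allclose(manual_kl, builtin_kl, atol=1e-6)
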
@@ -344,6 +342,7 @@ class CustomTorchPolicy(TorchPolicy):
"reward_deque": self.reward_deque,
"batch_end_time": self.batch_end_time,
"num_retunes": self.retune_selector.num_retunes,
# "retune_selector": self.retune_selector,
"gamma": self.gamma,
"maxrewep_lenbuf": self.maxrewep_lenbuf,
"lr": self.lr,
@@ -361,6 +360,7 @@ class CustomTorchPolicy(TorchPolicy):
self.reward_deque = custom_state_vars["reward_deque"]
self.batch_end_time = custom_state_vars["batch_end_time"]
self.retune_selector.set_num_retunes(custom_state_vars["num_retunes"])
# self.retune_selector = custom_state_vars["num_retunes"]
self.gamma = self.adaptive_discount_tuner.gamma = custom_state_vars["gamma"]
self.maxrewep_lenbuf = custom_state_vars["maxrewep_lenbuf"]
self.lr =custom_state_vars["lr"]
@@ -381,6 +381,10 @@ class CustomTorchPolicy(TorchPolicy):
k: v
for k, v in self.optimizer.state_dict().items()
}
weights["aux_optimizer_state"] = {
k: v
for k, v in self.aux_optimizer.state_dict().items()
}
weights["custom_state_vars"] = self.get_custom_state_vars()
return weights
@@ -389,8 +393,13 @@ class CustomTorchPolicy(TorchPolicy):
def set_weights(self, weights):
self.set_model_weights(weights["current_weights"])
self.set_optimizer_state(weights["optimizer_state"])
self.set_aux_optimizer_state(weights["aux_optimizer_state"])
self.set_custom_state_vars(weights["custom_state_vars"])
def set_aux_optimizer_state(self, aux_optimizer_state):
aux_optimizer_state = convert_to_torch_tensor(aux_optimizer_state, device=self.device)
self.aux_optimizer.load_state_dict(aux_optimizer_state)
def set_optimizer_state(self, optimizer_state):
optimizer_state = convert_to_torch_tensor(optimizer_state, device=self.device)
self.optimizer.load_state_dict(optimizer_state)
......
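
Since get_weights/set_weights now also carry aux_optimizer_state, a checkpoint round-trip preserves the Adam moments of both phases. A hedged usage sketch (policy is assumed to be an instance of the CustomTorchPolicy above):

# Hedged round-trip sketch: the weights dict now holds model weights, both
# optimizer states, and the custom state vars, so a restore resumes the
# auxiliary phase with intact Adam statistics.
weights = policy.get_weights()
assert "optimizer_state" in weights and "aux_optimizer_state" in weights
policy.set_weights(weights)   # restores model, both optimizers, custom state vars
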
import logging
from ray.rllib.agents import with_common_config
from .custom_torch_policy import CustomTorchPolicy
from .custom_torch_ppg import CustomTorchPolicy
from ray.rllib.agents.trainer_template import build_trainer
logger = logging.getLogger(__name__)
@@ -84,12 +84,13 @@ DEFAULT_CONFIG = with_common_config({
"max_minibatch_size": 2048,
"updates_per_batch": 8,
"aux_mbsize": 512,
})
# __sphinx_doc_end__
# yapf: enable
PPOTrainer = build_trainer(
name="CustomTorchPPOAgent",
PPGTrainer = build_trainer(
name="CustomTorchPPGAgent",
default_config=DEFAULT_CONFIG,
default_policy=CustomTorchPolicy)
@@ -95,6 +95,7 @@ class RetuneSelector:
self.skips = skips + (-skips) % nbatch
self.replay_size = replay_size + (-replay_size) % nbatch
self.exp_replay = np.empty((self.replay_size, *ob_space.shape), dtype=np.uint8)
self.vtarg_replay = np.empty((self.replay_size), dtype=np.float32)
self.batch_size = nbatch
self.batches_in_replay = self.replay_size // nbatch
@@ -106,7 +107,7 @@ class RetuneSelector:
self.replay_index = 0
self.buffer_full = False
def update(self, obs_batch):
def update(self, obs_batch, vtarg_batch):
if self.num_retunes == 0:
return False
@@ -117,6 +118,8 @@ class RetuneSelector:
start = self.replay_index * self.batch_size
end = start + self.batch_size
self.exp_replay[start:end] = obs_batch
self.vtarg_replay[start:end] = vtarg_batch
self.replay_index = (self.replay_index + 1) % self.batches_in_replay
self.buffer_full = self.buffer_full or (self.replay_index == 0)
@@ -144,9 +147,8 @@ class RewardNormalizer(object):
def normalize(self, rews, news):
self.ret = self.ret * self.gamma + rews
self.ret_rms.update(self.ret)
# rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
# self.ret[np.array(news, dtype=bool)] = 0. ## Values should be True of False to set positional index
rews = np.float32(rews > 0)
rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
self.ret[np.array(news, dtype=bool)] = 0. ## Values should be True or False to set positional index
return rews
class RunningMeanStd(object):
......
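
The diff also re-enables return-based reward normalization (the previous code binarized rewards with rews > 0): rewards are divided by the running standard deviation of the discounted return and clipped to cliprew, and the return accumulator is reset where episodes end. A hedged, self-contained sketch of that arithmetic (the running-variance update itself is left to RunningMeanStd):

# Hedged sketch of the restored normalization step: scale by the running std
# of the discounted return, clip, and zero the return accumulator on dones.
import numpy as np

gamma, cliprew, epsilon = 0.999, 10.0, 1e-8
ret = np.zeros(4, dtype=np.float32)            # per-env discounted-return accumulator
ret_var = 1.0                                  # stands in for RunningMeanStd's variance
rews = np.array([1.0, 0.0, 5.0, -2.0], dtype=np.float32)
news = np.array([False, True, False, False])   # episode-done flags

ret = ret * gamma + rews
norm_rews = np.clip(rews / np.sqrt(ret_var + epsilon), -cliprew, cliprew)
ret[news] = 0.0
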
@@ -29,10 +29,15 @@ def _import_custom_torch_agent():
from .custom_torch_agent.ppo import PPOTrainer
return PPOTrainer
def _import_custom_torch_ppg():
from .custom_ppg.ppg import PPGTrainer
return PPGTrainer
CUSTOM_ALGORITHMS = {
"custom/CustomRandomAgent": _import_custom_random_agent,
"RandomPolicy": _import_random_policy,
"CustomPPOAgent": _import_custom_ppo_agent,
"CustomTorchPPOAgent": _import_custom_torch_agent
"CustomTorchPPOAgent": _import_custom_torch_agent,
"CustomTorchPPGAgent": _import_custom_torch_ppg
}
procgen-ppo:
env: frame_stacked_procgen
run: CustomTorchPPOAgent
run: CustomTorchPPGAgent
disable_evaluation_worker: True
# === Stop Conditions ===
stop:
@@ -29,12 +29,12 @@ procgen-ppo:
return_blind: 1
return_max: 10
gamma: 0.996
gamma: 0.999
lambda: 0.95
lr: 5.0e-4
# Number of SGD iterations in each outer loop
num_sgd_iter: 3
vf_loss_coeff: 0.5
num_sgd_iter: 1
vf_loss_coeff: 1.0
entropy_coeff: 0.01
clip_param: 0.2
vf_clip_param: 0.2
@@ -48,19 +48,21 @@ procgen-ppo:
# Custom switches
retune_skips: 300000
retune_replay_size: 200000
num_retunes: 8
retune_epochs: 3
num_retunes: 20
retune_epochs: 6
standardize_rewards: True
aux_mbsize: 512
adaptive_gamma: False
final_lr: 2.0e-4
lr_schedule: 'linear'
lr_schedule: 'None'
final_entropy_coeff: 0.002
entropy_schedule: False
# Memory management, if batch size overflow, batch splitting is done to handle it
max_minibatch_size: 2048
updates_per_batch: 8
normalize_actions: False
clip_rewards: null
@@ -74,16 +76,16 @@ procgen-ppo:
# === Settings for Model ===
model:
custom_model: impala_torch_custom
custom_model: impala_torch_ppg
custom_options:
# depths: [64, 128, 128]
# nlatents: 1024
depths: [32, 64, 64]
nlatents: 512
init_normed: False
use_layernorm: True
init_normed: True
use_layernorm: False
num_workers: 7
num_workers: 4
num_envs_per_worker: 16
rollout_fragment_length: 256
......
@@ -92,14 +92,19 @@ class ImpalaCNN(TorchModelV2, nn.Module):
nn.init.zeros_(self.hidden_fc.bias)
self.pi_fc = nn.Linear(in_features=nlatents, out_features=num_outputs)
self.value_fc = nn.Linear(in_features=nlatents, out_features=1)
self.aux_vf = nn.Linear(in_features=nlatents, out_features=1)
if init_normed:
self.pi_fc.weight.data *= 0.1 / self.pi_fc.weight.norm(dim=1, p=2, keepdim=True)
self.value_fc.weight.data *= 0.1 / self.value_fc.weight.norm(dim=1, p=2, keepdim=True)
self.aux_vf.weight.data *= 0.1 / self.aux_vf.weight.norm(dim=1, p=2, keepdim=True)
else:
nn.init.orthogonal_(self.pi_fc.weight, gain=0.01)
nn.init.orthogonal_(self.value_fc.weight, gain=1)
nn.init.orthogonal_(self.aux_vf.weight, gain=1)
nn.init.zeros_(self.pi_fc.bias)
nn.init.zeros_(self.value_fc.bias)
nn.init.zeros_(self.aux_vf.bias)
if self.use_layernorm:
self.layernorm = nn.LayerNorm(nlatents)
@@ -121,8 +126,9 @@ class ImpalaCNN(TorchModelV2, nn.Module):
else:
x = nn.functional.relu(x)
logits = self.pi_fc(x)
value = self.value_fc(x)
value = self.value_fc(x.detach())
self._value = value.squeeze(1)
self._aux_value = self.aux_vf(x).squeeze(1)
return logits, state
@override(TorchModelV2)
@@ -130,6 +136,10 @@ class ImpalaCNN(TorchModelV2, nn.Module):
assert self._value is not None, "must call forward() first"
return self._value
def aux_value_function(self):
return self._aux_value
def vf_pi(self, obs, no_grad=False, ret_numpy=False, to_torch=False):
if to_torch:
obs = torch.tensor(obs).to(self.device)
@@ -150,4 +160,4 @@ class ImpalaCNN(TorchModelV2, nn.Module):
else:
return v, pi
ModelCatalog.register_custom_model("impala_torch_custom", ImpalaCNN)
ModelCatalog.register_custom_model("impala_torch_ppg", ImpalaCNN)
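
In the model, value_fc now reads a detached copy of the latent (x.detach()) while the new aux_vf head stays attached, so the value loss no longer back-propagates into the shared trunk during the policy phase; value information reaches the trunk only through the auxiliary phase's aux_vf head. A small hedged sketch of the gradient-flow consequence (toy layers, not the ImpalaCNN):

# Hedged sketch: gradients from the detached value head never reach the shared
# trunk, whereas gradients from the attached aux head do.
import torch
import torch.nn as nn

trunk = nn.Linear(8, 4)       # stand-in for the shared IMPALA trunk
value_fc = nn.Linear(4, 1)
aux_vf = nn.Linear(4, 1)

x = trunk(torch.randn(2, 8))
value = value_fc(x.detach())  # no gradient path back into the trunk
aux_value = aux_vf(x)         # gradient path into the trunk remains

(value.sum() + aux_value.sum()).backward()
assert trunk.weight.grad is not None      # reached only via aux_vf
assert value_fc.weight.grad is not None   # the value head itself still trains
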
@@ -6,8 +6,8 @@ set -e
#########################################
# export EXPERIMENT_DEFAULT="experiments/impala-baseline.yaml"
export EXPERIMENT_DEFAULT="experiments/custom-torch-ppo.yaml"
# export EXPERIMENT_DEFAULT="experiments/custom-ppo.yaml"
# export EXPERIMENT_DEFAULT="experiments/custom-torch-ppo.yaml"
export EXPERIMENT_DEFAULT="experiments/custom-ppg.yaml"
export EXPERIMENT=${EXPERIMENT:-$EXPERIMENT_DEFAULT}
if [[ -z $AICROWD_IS_GRADING ]]; then
......