Commit f2a86f1f authored by Dipam Chakraborty

ppg grad clip

parent f07abc5e
@@ -217,41 +217,21 @@ class CustomTorchPolicy(TorchPolicy):
                             ent_coef, vf_coef,
                             obs, returns, actions, values, logp_actions_old, advs):
-        if not self.config['pi_phase_mixed_precision']:
-            loss, vf_loss = self._calc_pi_vf_loss(apply_grad, num_accumulate,
-                                                  cliprange, vfcliprange, max_grad_norm,
-                                                  ent_coef, vf_coef,
-                                                  obs, returns, actions, values, logp_actions_old, advs)
-            loss.backward()
-            vf_loss.backward()
-            if apply_grad:
-                self.optimizer.step()
-                self.optimizer.zero_grad()
-                if not self.config['single_optimizer']:
-                    self.value_optimizer.step()
-                    self.value_optimizer.zero_grad()
-        else:
-            with autocast():
-                loss, vf_loss = self._calc_pi_vf_loss(apply_grad, num_accumulate,
-                                                      cliprange, vfcliprange, max_grad_norm,
-                                                      ent_coef, vf_coef,
-                                                      obs, returns, actions, values, logp_actions_old, advs)
-            self.amp_scaler.scale(loss).backward(retain_graph=True)
-            self.amp_scaler.scale(vf_loss).backward()
-            if apply_grad:
-                self.amp_scaler.step(self.optimizer)
-                if not self.config['single_optimizer']:
-                    self.amp_scaler.step(self.value_optimizer)
-                self.amp_scaler.update()
-                self.optimizer.zero_grad()
-                if not self.config['single_optimizer']:
-                    self.value_optimizer.zero_grad()
+        loss, vf_loss = self._calc_pi_vf_loss(apply_grad, num_accumulate,
+                                              cliprange, vfcliprange, max_grad_norm,
+                                              ent_coef, vf_coef,
+                                              obs, returns, actions, values, logp_actions_old, advs)
+        loss.backward()
+        vf_loss.backward()
+        if apply_grad:
+            if self.config['grad_clip'] is not None:
+                nn.utils.clip_grad_norm_(self.model.parameters(), self.config['grad_clip'])
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+            if not self.config['single_optimizer']:
+                self.value_optimizer.step()
+                self.value_optimizer.zero_grad()
 
     def _calc_pi_vf_loss(self, apply_grad, num_accumulate,
...
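
The new branch folds gradient clipping into the gradient-accumulation update: backward passes accumulate gradients across minibatches, and the global gradient norm is clipped once, just before the optimizer step. Below is a minimal self-contained sketch of that pattern, not the repository's code: the toy model and the grad_clip / num_accumulate values are made up for illustration.

import torch
import torch.nn as nn

model = nn.Linear(8, 2)                      # stand-in for the policy/value network
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
grad_clip = 0.5                              # plays the role of config['grad_clip']
num_accumulate = 4                           # minibatches accumulated per optimizer step

for step in range(8):
    obs = torch.randn(16, 8)
    target = torch.randn(16, 2)
    # scale the loss so the accumulated gradient matches a full-batch average
    loss = nn.functional.mse_loss(model(obs), target) / num_accumulate
    loss.backward()                          # gradients accumulate across iterations
    apply_grad = (step + 1) % num_accumulate == 0
    if apply_grad:
        if grad_clip is not None:
            # clip the total gradient norm after all backward passes and
            # immediately before the optimizer step, as in the diff above
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        optimizer.zero_grad()

One design note: clipping has to sit between the last backward() and optimizer.step(), since clip_grad_norm_ rescales the .grad tensors in place. The removed mixed-precision branch never clipped; adding clipping there would also require unscaling the gradients first (GradScaler.unscale_ in torch.cuda.amp's documented pattern) so the norm is measured at its true scale.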