Dipam Chakraborty / neurips-2020-procgen-competition / Commits / 41bef706

Commit 41bef706, authored Oct 26, 2020 by Dipam Chakraborty
    ppo amp

parent df3d29a1
Changes: 4 files
algorithms/custom_torch_agent/custom_torch_policy.py
@@ -9,6 +9,7 @@ from .utils import *
 import time

 torch, nn = try_import_torch()
+from torch.cuda.amp import autocast, GradScaler

 class CustomTorchPolicy(TorchPolicy):
     """Example of a random policy
@@ -81,6 +82,7 @@ class CustomTorchPolicy(TorchPolicy):
         self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
         self.save_success = 0
         self.retunes_completed = 0
+        self.amp_scaler = GradScaler()

     def to_tensor(self, arr):
         return torch.from_numpy(arr).to(self.device)
@@ -282,6 +284,23 @@ class CustomTorchPolicy(TorchPolicy):
         with torch.no_grad():
             tpi_log_softmax = nn.functional.log_softmax(target_pi, dim=1)
             tpi_softmax = torch.exp(tpi_log_softmax)
+        if not self.config['aux_phase_mixed_precision']:
+            loss = self._retune_calc_loss(obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff)
+            loss.backward()
+            if apply_grad:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+        else:
+            with autocast():
+                loss = self._retune_calc_loss(obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff)
+            self.amp_scaler.scale(loss).backward()
+
+            if apply_grad:
+                self.amp_scaler.step(self.optimizer)
+                self.amp_scaler.update()
+                self.optimizer.zero_grad()
+
+    def _retune_calc_loss(self, obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff):
         vpred, pi_logits = self.model.vf_pi(obs_aug, ret_numpy=False, no_grad=False, to_torch=False)
         pi_log_softmax = nn.functional.log_softmax(pi_logits, dim=1)
         pi_loss = torch.mean(torch.sum(tpi_softmax * (tpi_log_softmax - pi_log_softmax), dim=1)) # kl_div torch 1.3.1 has numerical issues
@@ -289,11 +308,7 @@ class CustomTorchPolicy(TorchPolicy):
         loss = retune_vf_loss_coeff * vf_loss + pi_loss
         loss = loss / self.accumulate_train_batches
-        loss.backward()
-
-        if apply_grad:
-            self.optimizer.step()
-            self.optimizer.zero_grad()
+        return loss

     def best_reward_model_select(self, samples):
         self.timesteps_total += self.nbatch
@@ -384,24 +399,20 @@ class CustomTorchPolicy(TorchPolicy):
             k: v.cpu().detach().numpy()
             for k, v in self.model.state_dict().items()
         }
-        # weights["optimizer_state"] = {
-        #     k: v
-        #     for k, v in self.optimizer.state_dict().items()
-        # }
-        # weights["custom_state_vars"] = self.get_custom_state_vars()
         return weights

     @override(TorchPolicy)
     def set_weights(self, weights):
         self.set_model_weights(weights["current_weights"])
-        # self.set_optimizer_state(weights["optimizer_state"])
-        # self.set_custom_state_vars(weights["custom_state_vars"])

-    def set_optimizer_state(self, optimizer_state):
+    def set_optimizer_state(self, optimizer_state, amp_scaler_state):
         optimizer_state = convert_to_torch_tensor(optimizer_state, device=self.device)
         self.optimizer.load_state_dict(optimizer_state)
+
+        amp_scaler_state = convert_to_torch_tensor(amp_scaler_state, device=self.device)
+        self.amp_scaler.load_state_dict(amp_scaler_state)

     def set_model_weights(self, model_weights):
         model_weights = convert_to_torch_tensor(model_weights, device=self.device)
         self.model.load_state_dict(model_weights)
\ No newline at end of file
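
For reference, the mixed-precision branch above follows the standard torch.cuda.amp recipe: run the forward pass and loss under autocast, scale the loss before backward, and step the optimizer through the GradScaler so inf/NaN gradients skip the update. A minimal, self-contained sketch of that pattern (the model, optimizer, and loss below are generic placeholders, not the repo's CustomTorchPolicy, and it assumes a CUDA device is available):

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(64, 15).cuda()                  # stand-in for the policy network
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scaler = GradScaler()

def train_step(obs, target, apply_grad=True):
    # Forward pass and loss run in reduced precision where it is safe.
    with autocast():
        logits = model(obs)
        loss = nn.functional.cross_entropy(logits, target)
    # Scale the loss so small fp16 gradients do not underflow, then backprop.
    scaler.scale(loss).backward()
    if apply_grad:                                # mirrors the gradient-accumulation gate
        scaler.step(optimizer)                    # unscales grads; skips the step on inf/NaN
        scaler.update()                           # adjusts the scale factor for the next step
        optimizer.zero_grad()
    return loss.detach()

obs = torch.randn(32, 64, device="cuda")
target = torch.randint(0, 15, (32,), device="cuda")
train_step(obs, target)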
algorithms/custom_torch_agent/custom_trainer_template.py
@@ -194,43 +194,8 @@ def build_trainer(name,
             policy = Trainer.get_policy(self)
             state["custom_state_vars"] = policy.get_custom_state_vars()
             state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
-            ## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
-            # save_success = False
-            # max_size = 3_700_000_000
-            # if policy.exp_replay.nbytes < max_size:
-            #     state["replay_buffer"] = policy.exp_replay
-            #     state["buffer_saved"] = 1
-            #     policy.save_success = 1
-            #     save_success = True
-            # elif policy.exp_replay.shape[-1] == 6: # only for frame stack = 2
-            #     eq = np.all(policy.exp_replay[1:,...,:3] == policy.exp_replay[:-1,...,-3:], axis=(-3,-2,-1))
-            #     non_eq = np.where(1 - eq)
-            #     images_non_eq = policy.exp_replay[non_eq]
-            #     images_last = policy.exp_replay[-1,...,-3:]
-            #     images_first = policy.exp_replay[0,...,:3]
-            #     if policy.exp_replay[1:,...,:3].nbytes < max_size:
-            #         state["sliced_buffer"] = policy.exp_replay[1:,...,:3]
-            #         state["buffer_saved"] = 2
-            #         policy.save_success = 2
-            #         save_success = True
-            #     else:
-            #         comp = compress(policy.exp_replay[1:,...,:3].copy(), level=9)
-            #         if getsizeof(comp) < max_size:
-            #             state["compressed_buffer"] = comp
-            #             state["buffer_saved"] = 3
-            #             policy.save_success = 3
-            #             save_success = True
-            #     if save_success:
-            #         state["matched_frame_data"] = [non_eq, images_non_eq, images_last, images_first]
-            # if not save_success:
-            #     state["buffer_saved"] = -1
-            #     policy.save_success = -1
-            #     print("####################### BUFFER SAVE FAILED #########################")
-            # else:
-            #     state["retune_selector"] = policy.retune_selector
+            state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
             if self.train_exec_impl:
                 state["train_exec_impl"] = (
                     self.train_exec_impl.shared_metrics.get().save())
@@ -240,28 +205,8 @@ def build_trainer(name,
             Trainer.__setstate__(self, state)
             policy = Trainer.get_policy(self)
             self.state = state["trainer_state"].copy()
-            policy.set_optimizer_state(state["optimizer_state"])
+            policy.set_optimizer_state(state["optimizer_state"], state["amp_scaler_state"])
             policy.set_custom_state_vars(state["custom_state_vars"])
-            ## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
-            # buffer_saved = state.get("buffer_saved", -1)
-            # policy.save_success = buffer_saved
-            # if buffer_saved == 1:
-            #     policy.exp_replay = state["replay_buffer"]
-            # elif buffer_saved > 1:
-            #     non_eq, images_non_eq, images_last, images_first = state["matched_frame_data"]
-            #     policy.exp_replay[non_eq] = images_non_eq
-            #     policy.exp_replay[-1,...,-3:] = images_last
-            #     policy.exp_replay[0,...,:3] = images_first
-            #     if buffer_saved == 2:
-            #         policy.exp_replay[1:,...,:3] = state["sliced_buffer"]
-            #     elif buffer_saved == 3:
-            #         ts = policy.exp_replay[1:,...,:3].shape
-            #         dt = policy.exp_replay.dtype
-            #         decomp = decompress(state["compressed_buffer"])
-            #         policy.exp_replay[1:,...,:3] = np.array(np.frombuffer(decomp, dtype=dt).reshape(ts))
-            # if buffer_saved > 0:
-            #     policy.retune_selector = state["retune_selector"]
             if self.train_exec_impl:
                 self.train_exec_impl.shared_metrics.get().restore(
...
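
Because the aux-phase optimizer step now goes through a GradScaler, the scaler's state has to be checkpointed and restored alongside the optimizer state, which is what the changes above do. A hedged sketch of that round trip with generic stand-in objects (save_state and load_state are illustrative names, not the trainer's actual hooks, and the repo's convert_to_torch_tensor device handling is omitted):

import torch
from torch import nn
from torch.cuda.amp import GradScaler

model = nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters())
amp_scaler = GradScaler()

def save_state():
    # Mirror the trainer's state dict: optimizer and AMP scaler saved side by side.
    return {
        "optimizer_state": optimizer.state_dict(),
        "amp_scaler_state": amp_scaler.state_dict(),
    }

def load_state(state):
    # Restore both pieces so a resumed run keeps its loss-scale history.
    optimizer.load_state_dict(state["optimizer_state"])
    amp_scaler.load_state_dict(state["amp_scaler_state"])

load_state(save_state())   # round-trip check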
algorithms/custom_torch_agent/ppo.py
@@ -88,6 +88,7 @@ DEFAULT_CONFIG = with_common_config({
     "updates_per_batch": 8,
     "scale_reward": 1.0,
     "return_reset": True,
+    "aux_phase_mixed_precision": False,
 })
 # __sphinx_doc_end__
 # yapf: enable
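
The new key defaults to False, so existing configs keep the full-precision aux phase and the feature has to be enabled explicitly (as the experiment YAML below does). A small sketch of that default-then-override behaviour, using a plain dict merge as a stand-in for RLlib's config handling (an assumption, not the trainer's actual merge code):

# Defaults mirror the DEFAULT_CONFIG entries above; the merge itself is illustrative only.
defaults = {
    "updates_per_batch": 8,
    "scale_reward": 1.0,
    "return_reset": True,
    "aux_phase_mixed_precision": False,   # off unless a config turns it on
}
user_config = {"aux_phase_mixed_precision": True}
config = {**defaults, **user_config}      # user values win over defaults
assert config["aux_phase_mixed_precision"] is True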
experiments/custom-torch-ppo.yaml
@@ -53,6 +53,7 @@ procgen-ppo:
     standardize_rewards: True
     scale_reward: 1.0
     return_reset: False
+    aux_phase_mixed_precision: True
     adaptive_gamma: False
     final_lr: 5.0e-5
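
For completeness, a quick way to confirm the flag is picked up from the experiment file, assuming PyYAML is available and a tune-style layout with the settings under a config section (the exact nesting is an assumption here, not taken from the hunk above):

import yaml  # assumes PyYAML is installed

with open("experiments/custom-torch-ppo.yaml") as f:
    experiment = yaml.safe_load(f)

# Adjust the keys if the experiment nests its settings differently.
flag = experiment["procgen-ppo"]["config"]["aux_phase_mixed_precision"]
print(flag)  # expected: True for this experiment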