Dipam Chakraborty / neurips-2020-procgen-competition / Commits

Commit 7106fa78
Authored Oct 26, 2020 by Dipam Chakraborty

    init for training process only

Parent: 41bef706
Changes: 6 files
algorithms/custom_ppg/custom_torch_ppg.py

@@ -22,6 +22,8 @@ class CustomTorchPolicy(TorchPolicy):
     def __init__(self, observation_space, action_space, config):
         self.config = config
         self.acion_space = action_space
+        self.observation_space = observation_space
         self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         dist_class, logit_dim = ModelCatalog.get_action_dist(

@@ -44,8 +46,12 @@ class CustomTorchPolicy(TorchPolicy):
             loss=None,
             action_distribution_class=dist_class,
         )
         self.framework = "torch"
+
+    def init_training(self):
+        """ Init once only for the policy - Surely there should be a bette way to do this """
         aux_params = set(self.model.aux_vf.parameters())
         value_params = set(self.model.value_fc.parameters())
         network_params = set(self.model.parameters())

@@ -80,7 +86,7 @@ class CustomTorchPolicy(TorchPolicy):
             print("WARNING: MEMORY LIMITED BATCHING NOT SET PROPERLY")
             print("#################################################")
         replay_shape = (n_pi, nsteps, nenvs)
-        self.retune_selector = RetuneSelector(nenvs, observation_space, action_space, replay_shape,
+        self.retune_selector = RetuneSelector(nenvs, self.observation_space, self.action_space, replay_shape,
                                               skips = self.config['skips'],
                                               n_pi = n_pi,
                                               num_retunes = self.config['num_retunes'],

@@ -93,11 +99,11 @@ class CustomTorchPolicy(TorchPolicy):
         self.gamma = self.config['gamma']
         self.adaptive_discount_tuner = AdaptiveDiscountTuner(self.gamma, momentum=0.98, eplenmult=3)
-        self.lr = config['lr']
-        self.ent_coef = config['entropy_coeff']
+        self.lr = self.config['lr']
+        self.ent_coef = self.config['entropy_coeff']
         self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
-        self.make_distr = dist_build(action_space)
+        self.make_distr = dist_build(self.action_space)
         self.retunes_completed = 0
         self.amp_scaler = GradScaler()
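For readers following the diff: the change moves one-time training setup out of the policy constructor into a separate init_training method that the trainer calls once after its own initialization, which is why the lookups switch from the local config/observation_space/action_space to the self.* attributes. A minimal sketch of that split, under the assumption stated in the comments (only the init_training split and the two trainer-side calls mirror the diff; the class, network, and config values are illustrative):

    import torch

    class SketchPolicy:
        """Illustrative stand-in for CustomTorchPolicy: the constructor stays
        light, heavier setup is deferred to init_training()."""

        def __init__(self, observation_space, action_space, config):
            # Everything init_training() needs must live on self,
            # since the method takes no arguments.
            self.config = config
            self.observation_space = observation_space
            self.action_space = action_space
            self.model = torch.nn.Linear(4, 2)  # placeholder network

        def init_training(self):
            # Runs exactly once, after the trainer has finished its own setup.
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=self.config['lr'])

    # Trainer side, mirroring the custom_trainer_template.py hunk below:
    policy = SketchPolicy(None, None, {'lr': 5e-4})
    policy.init_training()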
algorithms/custom_ppg/custom_trainer_template.py

@@ -139,6 +139,9 @@ def build_trainer(name,
                 **optimizer_config)
             if after_init:
                 after_init(self)
+            policy = Trainer.get_policy(self)
+            policy.init_training()
+
         @override(Trainer)
         def _train(self):

@@ -192,11 +195,14 @@ def build_trainer(name,
             state = Trainer.__getstate__(self)
             state["trainer_state"] = self.state.copy()
             policy = Trainer.get_policy(self)
-            state["custom_state_vars"] = policy.get_custom_state_vars()
-            state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
-            state["aux_optimizer_state"] = {k: v for k, v in policy.aux_optimizer.state_dict().items()}
-            state["value_optimizer_state"] = {k: v for k, v in policy.value_optimizer.state_dict().items()}
-            state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
+            try:
+                state["custom_state_vars"] = policy.get_custom_state_vars()
+                state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
+                state["aux_optimizer_state"] = {k: v for k, v in policy.aux_optimizer.state_dict().items()}
+                state["value_optimizer_state"] = {k: v for k, v in policy.value_optimizer.state_dict().items()}
+                state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
+            except:
+                print("################# WARNING: SAVING STATE VARS AND OPTIMIZER FAILED ################")
             if self.train_exec_impl:
                 state["train_exec_impl"] = (

@@ -207,9 +213,12 @@ def build_trainer(name,
             Trainer.__setstate__(self, state)
             policy = Trainer.get_policy(self)
             self.state = state["trainer_state"].copy()
-            policy.set_optimizer_state(state["optimizer_state"], state["aux_optimizer_state"],
-                                       state["value_optimizer_state"], state["amp_scaler_state"])
-            policy.set_custom_state_vars(state["custom_state_vars"])
+            try:
+                policy.set_optimizer_state(state["optimizer_state"], state["aux_optimizer_state"],
+                                           state["value_optimizer_state"], state["amp_scaler_state"])
+                policy.set_custom_state_vars(state["custom_state_vars"])
+            except:
+                print("################# WARNING: LOADING STATE VARS AND OPTIMIZER FAILED ################")
             if self.train_exec_impl:
                 self.train_exec_impl.shared_metrics.get().restore(
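The __getstate__/__setstate__ changes wrap the optimizer and AMP-scaler serialization in try/except so a failure only prints a warning instead of breaking the whole checkpoint. A self-contained sketch of that save/restore round trip outside the Trainer machinery (names are illustrative, and except Exception is used here rather than the bare except in the diff):

    import torch

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    aux_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    state = {}
    try:
        # Copy the optimizer state dicts into the checkpoint, as in the diff.
        state["optimizer_state"] = {k: v for k, v in optimizer.state_dict().items()}
        state["aux_optimizer_state"] = {k: v for k, v in aux_optimizer.state_dict().items()}
    except Exception:
        print("WARNING: saving optimizer state failed")

    # ...later, after the trainer has been rebuilt from the checkpoint:
    try:
        optimizer.load_state_dict(state["optimizer_state"])
        aux_optimizer.load_state_dict(state["aux_optimizer_state"])
    except Exception:
        print("WARNING: loading optimizer state failed")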
algorithms/custom_torch_agent/custom_torch_policy.py

@@ -21,6 +21,8 @@ class CustomTorchPolicy(TorchPolicy):
     def __init__(self, observation_space, action_space, config):
         self.config = config
         self.acion_space = action_space
+        self.observation_space = observation_space
         self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         dist_class, logit_dim = ModelCatalog.get_action_dist(

@@ -43,7 +45,11 @@ class CustomTorchPolicy(TorchPolicy):
             loss=None,
             action_distribution_class=dist_class,
         )
         self.framework = "torch"
+
+    def init_training(self):
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
         self.max_reward = self.config['env_config']['return_max']
         self.rewnorm = RewardNormalizer(cliprew=self.max_reward) ## TODO: Might need to go to custom state

@@ -64,11 +70,11 @@ class CustomTorchPolicy(TorchPolicy):
             print("#################################################")
             print("WARNING: MEMORY LIMITED BATCHING NOT SET PROPERLY")
             print("#################################################")
-        self.retune_selector = RetuneSelector(self.nbatch, observation_space, action_space,
+        self.retune_selector = RetuneSelector(self.nbatch, self.observation_space, self.action_space,
                                               skips = self.config['retune_skips'],
                                               replay_size = self.config['retune_replay_size'],
                                               num_retunes = self.config['num_retunes'])
-        self.exp_replay = np.zeros((self.retune_selector.replay_size, *observation_space.shape), dtype=np.uint8)
+        self.exp_replay = np.empty((self.retune_selector.replay_size, *self.observation_space.shape), dtype=np.uint8)
         self.target_timesteps = 8_000_000
         self.buffer_time = 20 # TODO: Could try to do a median or mean time step check instead
         self.max_time = 10000000000000 # ignore timekeeping because spot instances are messing it up

@@ -76,8 +82,8 @@ class CustomTorchPolicy(TorchPolicy):
         self.gamma = self.config['gamma']
         self.adaptive_discount_tuner = AdaptiveDiscountTuner(self.gamma, momentum=0.98, eplenmult=3)
-        self.lr = config['lr']
-        self.ent_coef = config['entropy_coeff']
+        self.lr = self.config['lr']
+        self.ent_coef = self.config['entropy_coeff']
         self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
         self.save_success = 0

@@ -270,7 +276,6 @@ class CustomTorchPolicy(TorchPolicy):
                           self.to_tensor(replay_pi[mbinds])]
                 self.tune_policy(apply_grad, *slices, 0.5)
-
         self.exp_replay.fill(0)
         self.retunes_completed += 1
         self.retune_selector.retune_done()
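One note on the replay-buffer line: np.empty reserves the array without initializing it, whereas np.zeros guarantees zeroed contents, so with np.empty every slot must be written before it is read. A small sketch, assuming the 64x64x3 uint8 frames Procgen produces and the retune_replay_size value from the experiment YAML further down (variable names are illustrative):

    import numpy as np

    replay_size = 200_000          # retune_replay_size from the experiment YAML
    obs_shape = (64, 64, 3)        # Procgen observations are 64x64x3 uint8 frames

    # np.empty skips the zero fill that np.zeros performs; the contents are
    # undefined until each slot is assigned.
    exp_replay = np.empty((replay_size, *obs_shape), dtype=np.uint8)

    exp_replay[0] = 255            # write a slot before ever reading it
    assert int(exp_replay[0].max()) == 255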
algorithms/custom_torch_agent/custom_trainer_template.py

@@ -139,6 +139,9 @@ def build_trainer(name,
                 **optimizer_config)
             if after_init:
                 after_init(self)
+            policy = Trainer.get_policy(self)
+            policy.init_training()
+
         @override(Trainer)
         def _train(self):

@@ -192,9 +195,12 @@ def build_trainer(name,
             state = Trainer.__getstate__(self)
             state["trainer_state"] = self.state.copy()
             policy = Trainer.get_policy(self)
-            state["custom_state_vars"] = policy.get_custom_state_vars()
-            state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
-            state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
+            try:
+                state["custom_state_vars"] = policy.get_custom_state_vars()
+                state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
+                state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
+            except:
+                print("################# WARNING: SAVING STATE VARS AND OPTIMIZER FAILED ################")
             if self.train_exec_impl:
                 state["train_exec_impl"] = (

@@ -205,8 +211,11 @@ def build_trainer(name,
             Trainer.__setstate__(self, state)
             policy = Trainer.get_policy(self)
             self.state = state["trainer_state"].copy()
-            policy.set_optimizer_state(state["optimizer_state"], state["amp_scaler_state"])
-            policy.set_custom_state_vars(state["custom_state_vars"])
+            try:
+                policy.set_optimizer_state(state["optimizer_state"], state["amp_scaler_state"])
+                policy.set_custom_state_vars(state["custom_state_vars"])
+            except:
+                print("################# WARNING: LOADING STATE VARS AND OPTIMIZER FAILED ################")
             if self.train_exec_impl:
                 self.train_exec_impl.shared_metrics.get().restore(
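This trainer template checkpoints only the single optimizer plus the amp_scaler (the PPG variant above also carries aux and value optimizers). For the scaler part specifically, a standalone sketch of saving and restoring a torch.cuda.amp.GradScaler state (on a CPU-only machine the scaler is created disabled and its state_dict is empty, so the round trip is a no-op):

    import torch
    from torch.cuda.amp import GradScaler

    scaler = GradScaler()

    # state_dict() carries the current loss scale and growth bookkeeping.
    amp_scaler_state = {k: v for k, v in scaler.state_dict().items()}

    # Restore into a fresh scaler, e.g. after rebuilding from a checkpoint.
    restored_scaler = GradScaler()
    if amp_scaler_state:                 # empty when AMP is disabled
        restored_scaler.load_state_dict(amp_scaler_state)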
experiments/custom-ppg.yaml

@@ -8,7 +8,7 @@ procgen-ppo:
         time_total_s: 7200

     # === Settings for Checkpoints ===
-    checkpoint_freq: 1
+    checkpoint_freq: 100
     checkpoint_at_end: True
     keep_checkpoints_num: 5
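checkpoint_freq counts training iterations, so this change drops checkpointing from every iteration to every 100th. A rough programmatic equivalent of the checkpoint block, assuming Ray Tune's tune.run interface from the 2020-era Ray releases pinned by the competition starter kit (the trainable name is a placeholder, not taken from this file):

    from ray import tune

    # Assumed equivalent of the YAML settings shown above.
    tune.run(
        "PPO",                      # placeholder; the repo registers its own trainer
        stop={"time_total_s": 7200},
        checkpoint_freq=100,        # was 1: checkpoint every iteration
        checkpoint_at_end=True,
        keep_checkpoints_num=5,
    )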
experiments/custom-torch-ppo.yaml

@@ -46,9 +46,9 @@ procgen-ppo:
         no_done_at_end: False

         # Custom switches
-        retune_skips: 100000
+        retune_skips: 50000
         retune_replay_size: 200000
-        num_retunes: 23
+        num_retunes: 28
         retune_epochs: 3
         standardize_rewards: True
         scale_reward: 1.0
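The retune_* keys are custom entries that the policy reads through self.config (see the RetuneSelector hunks above). As a quick way to sanity-check an edit like this one, a hedged sketch that just loads the file and prints the values, assuming PyYAML is available and that, as in typical Ray Tune experiment files, these keys sit under a config: block of the procgen-ppo experiment:

    import yaml

    # Path as it appears in this commit.
    with open("experiments/custom-torch-ppo.yaml") as f:
        experiment = yaml.safe_load(f)

    cfg = experiment["procgen-ppo"]["config"]   # assumed nesting
    for key in ("retune_skips", "retune_replay_size", "num_retunes", "retune_epochs"):
        print(key, cfg[key])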