Dipam Chakraborty / neurips-2020-procgen-competition

Commit 0d488563, authored Oct 07, 2020 by Dipam Chakraborty
Parent: 925d66aa

torch.distributions & aug in aux phase

Changes: 2 files
algorithms/custom_ppg/custom_torch_ppg.py
```diff
@@ -9,6 +9,7 @@ from .utils import *
 import time
 torch, nn = try_import_torch()
+import torch.distributions as td

 class CustomTorchPolicy(TorchPolicy):
     """Example of a random policy
```
```diff
@@ -84,6 +85,7 @@ class CustomTorchPolicy(TorchPolicy):
         self.ent_coef = config['entropy_coeff']
         self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
+        self.make_distr = dist_build(action_space)

     def to_tensor(self, arr):
         return torch.from_numpy(arr).to(self.device)
```
```diff
@@ -135,8 +137,6 @@ class CustomTorchPolicy(TorchPolicy):
             end = start + nbatch_train
             values[start:end], _ = self.model.vf_pi(samples['obs'][start:end], ret_numpy=True, no_grad=True, to_torch=True)

         ## GAE
         mb_values = unroll(values, ts)
         mb_returns = np.zeros_like(mb_rewards)
```
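The minibatched `vf_pi` calls above only refresh `values`; the `## GAE` comment marks where the advantage recursion follows, outside this hunk. For orientation, here is a generic baselines-style GAE(λ) sketch matching this buffer layout. The function name, `gamma`, `lam`, and the `last_*` arguments are assumptions for illustration, not lines from this file:

```python
import numpy as np

# Generic baselines-style GAE(lambda) over a (T, num_envs) rollout.
# Function and argument names are assumptions, not from this repo.
def gae_returns(mb_rewards, mb_values, mb_dones, last_values, last_dones,
                gamma=0.999, lam=0.95):
    T = len(mb_rewards)
    mb_advs = np.zeros_like(mb_rewards)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        if t == T - 1:
            nextnonterminal = 1.0 - last_dones   # mask if rollout ended an episode
            nextvalues = last_values             # bootstrap value past the rollout
        else:
            nextnonterminal = 1.0 - mb_dones[t + 1]
            nextvalues = mb_values[t + 1]
        # TD residual, zeroed across episode boundaries
        delta = mb_rewards[t] + gamma * nextvalues * nextnonterminal - mb_values[t]
        lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
        mb_advs[t] = lastgaelam
    return mb_advs + mb_values   # returns = advantages + value baseline
```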
```diff
@@ -205,12 +205,10 @@ class CustomTorchPolicy(TorchPolicy):
         for g in self.optimizer.param_groups:
             g['lr'] = lr
-        # Advantages are normalized with full size batch instead of memory limited batch
-        # advs = returns - values
-        # advs = (advs - torch.mean(advs)) / (torch.std(advs) + 1e-8)
         vpred, pi_logits = self.model.vf_pi(obs, ret_numpy=False, no_grad=False, to_torch=False)
-        neglogpac = neglogp_actions(pi_logits, actions)
-        entropy = torch.mean(pi_entropy(pi_logits))
+        pd = self.make_distr(pi_logits)
+        neglogpac = -pd.log_prob(actions[..., None]).squeeze(1)
+        entropy = torch.mean(pd.entropy())
         vf_loss = .5 * torch.mean(torch.pow((vpred - returns), 2))
```
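The swap is behaviour-preserving: for a flat `Categorical`, `-log_prob` is exactly the per-sample cross-entropy the old `neglogp_actions` helper computed. A standalone sketch of the equivalence (illustrative, not from the repo); the extra `[..., None]`/`squeeze(1)` in the diff exists only because `make_distr` reshapes logits to `(batch, 1, ncat)`:

```python
import torch
import torch.nn as nn
import torch.distributions as td

logits = torch.randn(8, 15)            # batch of 8, 15 discrete actions
actions = torch.randint(15, (8,))

# Old path: per-sample -log pi(a|s) via unreduced cross-entropy
neglogpac_old = nn.functional.cross_entropy(logits, actions, reduction='none')

# New path: the same quantity through torch.distributions
pd = td.Categorical(logits=logits)
neglogpac_new = -pd.log_prob(actions)
entropy = pd.entropy().mean()          # replaces the pi_entropy helper

assert torch.allclose(neglogpac_old, neglogpac_new, atol=1e-6)
```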
```diff
@@ -262,23 +260,22 @@ class CustomTorchPolicy(TorchPolicy):
             self.retune_selector.retune_done()

     def tune_policy(self, obs, target_vf, target_pi):
-        # obs_aug = np.empty(obs.shape, obs.dtype)
-        # aug_idx = np.random.randint(3, size=len(obs))
-        # obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)
-        # obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)
-        # obs_aug[aug_idx == 2] = obs[aug_idx == 2]
-        # obs_aug = self.to_tensor(obs_aug)
-        obs = self.to_tensor(obs)
-        with torch.no_grad():
-            tpi_log_softmax = nn.functional.log_softmax(target_pi, dim=1)
-            tpi_softmax = torch.exp(tpi_log_softmax)
-        vpred, pi_logits = self.model.vf_pi(obs, ret_numpy=False, no_grad=False, to_torch=False)
+        obs_aug = np.empty(obs.shape, obs.dtype)
+        aug_idx = np.random.randint(6, size=len(obs))
+        obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)
+        obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)
+        obs_aug[aug_idx >= 2] = obs[aug_idx >= 2]
+        obs_aug = self.to_tensor(obs_aug)
+        vpred, pi_logits = self.model.vf_pi(obs_aug, ret_numpy=False, no_grad=False, to_torch=False)
         aux_vpred = self.model.aux_value_function()
-        pi_log_softmax = nn.functional.log_softmax(pi_logits, dim=1)
-        pi_loss = torch.mean(torch.sum(tpi_softmax * (tpi_log_softmax - pi_log_softmax), dim=1))
-        # kl_div torch 1.3.1 has numerical issues
         vf_loss = .5 * torch.mean(torch.pow(vpred - target_vf, 2))
         aux_loss = .5 * torch.mean(torch.pow(aux_vpred - target_vf, 2))
+        target_pd = self.make_distr(target_pi)
+        pd = self.make_distr(pi_logits)
+        pi_loss = td.kl_divergence(target_pd, pd).mean()
         loss = vf_loss + pi_loss + aux_loss
         loss.backward()
```
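`td.kl_divergence` on two `Categorical`s reproduces the hand-rolled `softmax * (log_softmax - log_softmax)` sum it replaces; the removed comment notes the by-hand version was itself a workaround for `kl_div` numerical issues in torch 1.3.1. A quick illustrative check, not from the repo:

```python
import torch
import torch.nn as nn
import torch.distributions as td

target_pi = torch.randn(8, 15)   # frozen policy-phase logits
pi_logits = torch.randn(8, 15)   # current network logits

# Old path: KL(target || pi) written out by hand
tlog = nn.functional.log_softmax(target_pi, dim=1)
plog = nn.functional.log_softmax(pi_logits, dim=1)
kl_manual = torch.sum(torch.exp(tlog) * (tlog - plog), dim=1).mean()

# New path: the registered Categorical-to-Categorical KL
kl_td = td.kl_divergence(td.Categorical(logits=target_pi),
                         td.Categorical(logits=pi_logits)).mean()

assert torch.allclose(kl_manual, kl_td, atol=1e-6)
```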
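With `aug_idx` drawn from `randint(6)` rather than the commented-out `randint(3)`, each auxiliary-phase observation now gets a random crop with probability 1/6, a colour cutout with probability 1/6, and passes through unchanged otherwise. `pad_and_random_crop` and `random_cutout_color` live in `algorithms/custom_ppg/utils.py` but their bodies are not shown in this diff, so the stand-ins below are hypothetical, matching only the call signatures `(x, 64, 10)` and `(x, 10, 30)`:

```python
import numpy as np

# Hypothetical stand-ins for the utils.py helpers (bodies not shown in this diff).
def pad_and_random_crop(imgs, out_size, pad):
    # Pad by `pad` pixels, then crop a random out_size x out_size window per image.
    n, h, w, c = imgs.shape
    padded = np.pad(imgs, ((0, 0), (pad, pad), (pad, pad), (0, 0)), mode='edge')
    out = np.empty((n, out_size, out_size, c), imgs.dtype)
    for i in range(n):
        y = np.random.randint(0, 2 * pad + h - out_size + 1)
        x = np.random.randint(0, 2 * pad + w - out_size + 1)
        out[i] = padded[i, y:y + out_size, x:x + out_size]
    return out

def random_cutout_color(imgs, min_cut, max_cut):
    # Paint one random square per image with a random solid colour.
    out = imgs.copy()
    n, h, w, c = imgs.shape
    for i in range(n):
        s = np.random.randint(min_cut, max_cut)
        y, x = np.random.randint(0, h - s), np.random.randint(0, w - s)
        out[i, y:y + s, x:x + s] = np.random.randint(0, 256, size=c, dtype=imgs.dtype)
    return out

# The sampling scheme from tune_policy: 1-in-6 crop, 1-in-6 cutout, rest identity.
obs = np.random.randint(0, 256, size=(12, 64, 64, 3), dtype=np.uint8)
aug_idx = np.random.randint(6, size=len(obs))
obs_aug = np.empty(obs.shape, obs.dtype)
obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)
obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)
obs_aug[aug_idx >= 2] = obs[aug_idx >= 2]
```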
algorithms/custom_ppg/utils.py
```diff
@@ -4,6 +4,15 @@ from collections import deque
 from skimage.util import view_as_windows
 torch, nn = try_import_torch()
+import torch.distributions as td
+from functools import partial
+
+def _make_categorical(x, ncat, shape):
+    x = x.reshape((x.shape[0], shape, ncat))
+    return td.Categorical(logits=x)
+
+def dist_build(ac_space):
+    return partial(_make_categorical, shape=1, ncat=ac_space.n)

 def neglogp_actions(pi_logits, actions):
     return nn.functional.cross_entropy(pi_logits, actions, reduction='none')
```
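A small usage sketch for the new distribution builder, assuming a `gym.spaces.Discrete` action space (Procgen uses `Discrete(15)`); the shapes follow from `_make_categorical` reshaping logits to `(batch, shape, ncat)`:

```python
import torch
from gym.spaces import Discrete

# Assumes the repo root is on sys.path so the module is importable.
from algorithms.custom_ppg.utils import dist_build

make_distr = dist_build(Discrete(15))    # 15 actions, as in Procgen

logits = torch.randn(8, 15)              # (batch, n_actions) from the policy head
pd = make_distr(logits)                  # Categorical with batch_shape (8, 1)

actions = torch.randint(15, (8,))
# The extra axis and squeeze mirror the call site in the policy-phase loss above:
neglogpac = -pd.log_prob(actions[..., None]).squeeze(1)   # shape (8,)
entropy = pd.entropy().mean()
sampled = pd.sample().squeeze(1)         # shape (8,) sampled actions
```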