Commit cca87e73 authored by Chakraborty

merge ppg

parents cbeaecd6 606189ba
@@ -22,6 +22,8 @@ class CustomTorchPolicy(TorchPolicy):
     def __init__(self, observation_space, action_space, config):
         self.config = config
+        self.action_space = action_space
+        self.observation_space = observation_space
         self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         dist_class, logit_dim = ModelCatalog.get_action_dist(
@@ -44,8 +46,12 @@ class CustomTorchPolicy(TorchPolicy):
             loss=None,
             action_distribution_class=dist_class,
         )
         self.framework = "torch"
+
+    def init_training(self):
+        """ Init once only for the policy - Surely there should be a better way to do this """
         aux_params = set(self.model.aux_vf.parameters())
         value_params = set(self.model.value_fc.parameters())
         network_params = set(self.model.parameters())
@@ -80,7 +86,7 @@ class CustomTorchPolicy(TorchPolicy):
             print("WARNING: MEMORY LIMITED BATCHING NOT SET PROPERLY")
             print("#################################################")
         replay_shape = (n_pi, nsteps, nenvs)
-        self.retune_selector = RetuneSelector(nenvs, observation_space, action_space, replay_shape,
+        self.retune_selector = RetuneSelector(nenvs, self.observation_space, self.action_space, replay_shape,
                                               skips = self.config['skips'],
                                               n_pi = n_pi,
                                               num_retunes = self.config['num_retunes'],
@@ -93,11 +99,11 @@ class CustomTorchPolicy(TorchPolicy):
         self.gamma = self.config['gamma']
         self.adaptive_discount_tuner = AdaptiveDiscountTuner(self.gamma, momentum=0.98, eplenmult=3)
-        self.lr = config['lr']
-        self.ent_coef = config['entropy_coeff']
+        self.lr = self.config['lr']
+        self.ent_coef = self.config['entropy_coeff']
         self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
-        self.make_distr = dist_build(action_space)
+        self.make_distr = dist_build(self.action_space)
         self.retunes_completed = 0
         self.amp_scaler = GradScaler()
......
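The hunks above move optimizer and retune-selector construction out of `__init__` into a new `init_training()` method, which the trainer calls exactly once after `after_init` (see the `build_trainer` diff below). A minimal sketch of that deferred-initialization pattern, with an illustrative `SimplePolicy` class that is not the repo's actual policy:

```python
import torch
import torch.nn as nn

class SimplePolicy:
    """Illustrative stand-in for CustomTorchPolicy (assumption, not the repo's class)."""

    def __init__(self, model: nn.Module, config: dict):
        # Construction stays lightweight: only store references.
        self.model = model
        self.config = config
        self.optimizer = None

    def init_training(self):
        # Training-only state is built once, after the trainer has finished
        # constructing the policy (mirrors the after_init hook in the diff).
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.config.get("lr", 1e-3))

policy = SimplePolicy(nn.Linear(4, 2), {"lr": 5e-4})
policy.init_training()  # called by the trainer, not from __init__
```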
@@ -139,6 +139,9 @@ def build_trainer(name,
                    **optimizer_config)
            if after_init:
                after_init(self)
+            policy = Trainer.get_policy(self)
+            policy.init_training()

        @override(Trainer)
        def _train(self):
@@ -192,11 +195,14 @@ def build_trainer(name,
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            policy = Trainer.get_policy(self)
-            state["custom_state_vars"] = policy.get_custom_state_vars()
-            state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
-            state["aux_optimizer_state"] = {k: v for k, v in policy.aux_optimizer.state_dict().items()}
-            state["value_optimizer_state"] = {k: v for k, v in policy.value_optimizer.state_dict().items()}
-            state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
+            try:
+                state["custom_state_vars"] = policy.get_custom_state_vars()
+                state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
+                state["aux_optimizer_state"] = {k: v for k, v in policy.aux_optimizer.state_dict().items()}
+                state["value_optimizer_state"] = {k: v for k, v in policy.value_optimizer.state_dict().items()}
+                state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
+            except:
+                print("################# WARNING: SAVING STATE VARS AND OPTIMIZER FAILED ################")
            if self.train_exec_impl:
                state["train_exec_impl"] = (
@@ -207,9 +213,12 @@ def build_trainer(name,
            Trainer.__setstate__(self, state)
            policy = Trainer.get_policy(self)
            self.state = state["trainer_state"].copy()
-            policy.set_optimizer_state(state["optimizer_state"], state["aux_optimizer_state"],
-                                       state["value_optimizer_state"], state["amp_scaler_state"])
-            policy.set_custom_state_vars(state["custom_state_vars"])
+            try:
+                policy.set_optimizer_state(state["optimizer_state"], state["aux_optimizer_state"],
+                                           state["value_optimizer_state"], state["amp_scaler_state"])
+                policy.set_custom_state_vars(state["custom_state_vars"])
+            except:
+                print("################# WARNING: LOADING STATE VARS AND OPTIMIZER FAILED ################")
            if self.train_exec_impl:
                self.train_exec_impl.shared_metrics.get().restore(
......
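The `__getstate__`/`__setstate__` changes above wrap optimizer and AMP-scaler serialization in try/except, so a failed save or restore degrades to a printed warning instead of killing the run (useful on preemptible machines). A minimal sketch of the underlying state-dict round trip, assuming a single Adam optimizer standing in for the policy's `optimizer`, `aux_optimizer`, and `value_optimizer`:

```python
import torch
from torch.cuda.amp import GradScaler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=torch.cuda.is_available())

# Save: state_dict() returns plain, picklable dicts.
state = {
    "optimizer_state": optimizer.state_dict(),
    "amp_scaler_state": scaler.state_dict(),
}

# Restore into freshly built objects of the same shape.
new_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
new_scaler = GradScaler(enabled=torch.cuda.is_available())
try:
    new_optimizer.load_state_dict(state["optimizer_state"])
    new_scaler.load_state_dict(state["amp_scaler_state"])
except Exception:
    # Mirrors the diff: losing optimizer/scaler state is tolerated,
    # losing model weights is not.
    print("WARNING: restoring optimizer/scaler state failed")
```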
@@ -9,6 +9,7 @@ from .utils import *
 import time

 torch, nn = try_import_torch()
+from torch.cuda.amp import autocast, GradScaler

 class CustomTorchPolicy(TorchPolicy):
     """Example of a random policy
@@ -20,6 +21,8 @@ class CustomTorchPolicy(TorchPolicy):
     def __init__(self, observation_space, action_space, config):
         self.config = config
+        self.action_space = action_space
+        self.observation_space = observation_space
         self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         dist_class, logit_dim = ModelCatalog.get_action_dist(
@@ -42,7 +45,11 @@ class CustomTorchPolicy(TorchPolicy):
             loss=None,
             action_distribution_class=dist_class,
         )
         self.framework = "torch"
+
+    def init_training(self):
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
         self.max_reward = self.config['env_config']['return_max']
         self.rewnorm = RewardNormalizer(cliprew=self.max_reward)  ## TODO: Might need to go to custom state
@@ -63,24 +70,25 @@ class CustomTorchPolicy(TorchPolicy):
             print("#################################################")
             print("WARNING: MEMORY LIMITED BATCHING NOT SET PROPERLY")
             print("#################################################")
-        self.retune_selector = RetuneSelector(self.nbatch, observation_space, action_space,
+        self.retune_selector = RetuneSelector(self.nbatch, self.observation_space, self.action_space,
                                               skips = self.config['retune_skips'],
                                               replay_size = self.config['retune_replay_size'],
                                               num_retunes = self.config['num_retunes'])
-        self.exp_replay = np.zeros((self.retune_selector.replay_size, *observation_space.shape), dtype=np.uint8)
+        self.exp_replay = np.empty((self.retune_selector.replay_size, *self.observation_space.shape), dtype=np.uint8)
         self.target_timesteps = 8_000_000
         self.buffer_time = 20  # TODO: Could try to do a median or mean time step check instead
-        self.max_time = 10000000000000 # ignore timekeeping because spot instances are messing it up
+        self.max_time = self.config['max_time']
         self.maxrewep_lenbuf = deque(maxlen=100)
         self.gamma = self.config['gamma']
         self.adaptive_discount_tuner = AdaptiveDiscountTuner(self.gamma, momentum=0.98, eplenmult=3)
-        self.lr = config['lr']
-        self.ent_coef = config['entropy_coeff']
+        self.lr = self.config['lr']
+        self.ent_coef = self.config['entropy_coeff']
         self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
         self.save_success = 0
         self.retunes_completed = 0
+        self.amp_scaler = GradScaler()

     def to_tensor(self, arr):
         return torch.from_numpy(arr).to(self.device)
@@ -268,7 +276,6 @@ class CustomTorchPolicy(TorchPolicy):
                          self.to_tensor(replay_pi[mbinds])]
                self.tune_policy(apply_grad, *slices, 0.5)

-        self.exp_replay.fill(0)
        self.retunes_completed += 1
        self.retune_selector.retune_done()
@@ -282,6 +289,23 @@ class CustomTorchPolicy(TorchPolicy):
        with torch.no_grad():
            tpi_log_softmax = nn.functional.log_softmax(target_pi, dim=1)
            tpi_softmax = torch.exp(tpi_log_softmax)
+        if not self.config['aux_phase_mixed_precision']:
+            loss = self._retune_calc_loss(obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff)
+            loss.backward()
+            if apply_grad:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+        else:
+            with autocast():
+                loss = self._retune_calc_loss(obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff)
+            self.amp_scaler.scale(loss).backward()
+            if apply_grad:
+                self.amp_scaler.step(self.optimizer)
+                self.amp_scaler.update()
+                self.optimizer.zero_grad()
+
+    def _retune_calc_loss(self, obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff):
        vpred, pi_logits = self.model.vf_pi(obs_aug, ret_numpy=False, no_grad=False, to_torch=False)
        pi_log_softmax = nn.functional.log_softmax(pi_logits, dim=1)
        pi_loss = torch.mean(torch.sum(tpi_softmax * (tpi_log_softmax - pi_log_softmax), dim=1))  # kl_div torch 1.3.1 has numerical issues
@@ -289,11 +313,7 @@ class CustomTorchPolicy(TorchPolicy):
        loss = retune_vf_loss_coeff * vf_loss + pi_loss
        loss = loss / self.accumulate_train_batches
+        return loss
-        loss.backward()
-        if apply_grad:
-            self.optimizer.step()
-            self.optimizer.zero_grad()

    def best_reward_model_select(self, samples):
        self.timesteps_total += self.nbatch
@@ -384,24 +404,20 @@ class CustomTorchPolicy(TorchPolicy):
            k: v.cpu().detach().numpy()
            for k, v in self.model.state_dict().items()
        }
-        # weights["optimizer_state"] = {
-        #     k: v
-        #     for k, v in self.optimizer.state_dict().items()
-        # }
-        # weights["custom_state_vars"] = self.get_custom_state_vars()
        return weights

    @override(TorchPolicy)
    def set_weights(self, weights):
        self.set_model_weights(weights["current_weights"])
-        # self.set_optimizer_state(weights["optimizer_state"])
-        # self.set_custom_state_vars(weights["custom_state_vars"])

-    def set_optimizer_state(self, optimizer_state):
+    def set_optimizer_state(self, optimizer_state, amp_scaler_state):
        optimizer_state = convert_to_torch_tensor(optimizer_state, device=self.device)
        self.optimizer.load_state_dict(optimizer_state)
+        amp_scaler_state = convert_to_torch_tensor(amp_scaler_state, device=self.device)
+        self.amp_scaler.load_state_dict(amp_scaler_state)

    def set_model_weights(self, model_weights):
        model_weights = convert_to_torch_tensor(model_weights, device=self.device)
        self.model.load_state_dict(model_weights)
\ No newline at end of file
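The new `aux_phase_mixed_precision` branch above follows the standard `torch.cuda.amp` recipe: compute the loss under `autocast()`, scale it before `backward()`, and let the `GradScaler` drive `step()`/`update()`, while the `apply_grad` flag gates the optimizer step so gradients accumulate across the minibatches already divided by `self.accumulate_train_batches`. A minimal self-contained sketch of that loop, with a toy linear model standing in for the policy network (an assumption for illustration, not the repo's model):

```python
import torch
from torch.cuda.amp import autocast, GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=(device == "cuda"))
accumulate = 4  # plays the role of accumulate_train_batches

for step, batch in enumerate(torch.randn(16, 8).split(4)):
    batch = batch.to(device)
    with autocast(enabled=(device == "cuda")):
        loss = model(batch).pow(2).mean() / accumulate
    scaler.scale(loss).backward()        # scaled backward; gradients accumulate
    if (step + 1) % accumulate == 0:     # analogous to the diff's apply_grad flag
        scaler.step(optimizer)           # unscales grads, skips the step on inf/nan
        scaler.update()
        optimizer.zero_grad()
```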
@@ -139,6 +139,9 @@ def build_trainer(name,
                    **optimizer_config)
            if after_init:
                after_init(self)
+            policy = Trainer.get_policy(self)
+            policy.init_training()

        @override(Trainer)
        def _train(self):
@@ -192,45 +195,13 @@ def build_trainer(name,
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            policy = Trainer.get_policy(self)
-            state["custom_state_vars"] = policy.get_custom_state_vars()
-            state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
-            ## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
-            # save_success = False
-            # max_size = 3_700_000_000
-            # if policy.exp_replay.nbytes < max_size:
-            #     state["replay_buffer"] = policy.exp_replay
-            #     state["buffer_saved"] = 1
-            #     policy.save_success = 1
-            #     save_success = True
-            # elif policy.exp_replay.shape[-1] == 6: # only for frame stack = 2
-            #     eq = np.all(policy.exp_replay[1:,...,:3] == policy.exp_replay[:-1,...,-3:], axis=(-3,-2,-1))
-            #     non_eq = np.where(1 - eq)
-            #     images_non_eq = policy.exp_replay[non_eq]
-            #     images_last = policy.exp_replay[-1,...,-3:]
-            #     images_first = policy.exp_replay[0,...,:3]
-            #     if policy.exp_replay[1:,...,:3].nbytes < max_size:
-            #         state["sliced_buffer"] = policy.exp_replay[1:,...,:3]
-            #         state["buffer_saved"] = 2
-            #         policy.save_success = 2
-            #         save_success = True
-            #     else:
-            #         comp = compress(policy.exp_replay[1:,...,:3].copy(), level=9)
-            #         if getsizeof(comp) < max_size:
-            #             state["compressed_buffer"] = comp
-            #             state["buffer_saved"] = 3
-            #             policy.save_success = 3
-            #             save_success = True
-            #     if save_success:
-            #         state["matched_frame_data"] = [non_eq, images_non_eq, images_last, images_first]
-            # if not save_success:
-            #     state["buffer_saved"] = -1
-            #     policy.save_success = -1
-            #     print("####################### BUFFER SAVE FAILED #########################")
-            # else:
-            #     state["retune_selector"] = policy.retune_selector
+            try:
+                state["custom_state_vars"] = policy.get_custom_state_vars()
+                state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
+                state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
+            except:
+                print("################# WARNING: SAVING STATE VARS AND OPTIMIZER FAILED ################")
            if self.train_exec_impl:
                state["train_exec_impl"] = (
                    self.train_exec_impl.shared_metrics.get().save())
@@ -240,28 +211,11 @@ def build_trainer(name,
            Trainer.__setstate__(self, state)
            policy = Trainer.get_policy(self)
            self.state = state["trainer_state"].copy()
-            policy.set_optimizer_state(state["optimizer_state"])
-            policy.set_custom_state_vars(state["custom_state_vars"])
-            ## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
-            # buffer_saved = state.get("buffer_saved", -1)
-            # policy.save_success = buffer_saved
-            # if buffer_saved == 1:
-            #     policy.exp_replay = state["replay_buffer"]
-            # elif buffer_saved > 1:
-            #     non_eq, images_non_eq, images_last, images_first = state["matched_frame_data"]
-            #     policy.exp_replay[non_eq] = images_non_eq
-            #     policy.exp_replay[-1,...,-3:] = images_last
-            #     policy.exp_replay[0,...,:3] = images_first
-            #     if buffer_saved == 2:
-            #         policy.exp_replay[1:,...,:3] = state["sliced_buffer"]
-            #     elif buffer_saved == 3:
-            #         ts = policy.exp_replay[1:,...,:3].shape
-            #         dt = policy.exp_replay.dtype
-            #         decomp = decompress(state["compressed_buffer"])
-            #         policy.exp_replay[1:,...,:3] = np.array(np.frombuffer(decomp, dtype=dt).reshape(ts))
-            # if buffer_saved > 0:
-            #     policy.retune_selector = state["retune_selector"]
+            try:
+                policy.set_optimizer_state(state["optimizer_state"], state["amp_scaler_state"])
+                policy.set_custom_state_vars(state["custom_state_vars"])
+            except:
+                print("################# WARNING: LOADING STATE VARS AND OPTIMIZER FAILED ################")
            if self.train_exec_impl:
                self.train_exec_impl.shared_metrics.get().restore(
......
@@ -88,6 +88,8 @@ DEFAULT_CONFIG = with_common_config({
    "updates_per_batch": 8,
    "scale_reward": 1.0,
    "return_reset": True,
+    "aux_phase_mixed_precision": False,
+    "max_time": 100000000,
})
# __sphinx_doc_end__
# yapf: enable
......
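The two keys added to `DEFAULT_CONFIG` provide algorithm-level defaults for the switches used in the policy hunks above: `aux_phase_mixed_precision` selects the fp32 versus autocast branch, and `max_time` replaces the previously hard-coded wall-clock limit. Per-experiment YAML files, such as the ones below, override these defaults. A simplified sketch of that default-plus-override merge, standing in for RLlib's `with_common_config`/trainer config merging (an approximation, not RLlib's actual deep merge):

```python
# Algorithm defaults ship with the trainer; experiment YAML overrides them.
DEFAULT_CONFIG = {
    "aux_phase_mixed_precision": False,  # fp32 aux/retune phase unless overridden
    "max_time": 100000000,               # effectively "no wall-clock limit"
}

experiment_overrides = {"aux_phase_mixed_precision": True, "max_time": 7200}

config = {**DEFAULT_CONFIG, **experiment_overrides}
assert config["aux_phase_mixed_precision"] is True
assert config["max_time"] == 7200
```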
@@ -8,7 +8,7 @@ procgen-ppo:
    time_total_s: 7200

    # === Settings for Checkpoints ===
-    checkpoint_freq: 1
+    checkpoint_freq: 100
    checkpoint_at_end: True
    keep_checkpoints_num: 5
@@ -47,7 +47,7 @@ procgen-ppo:
    # Custom switches
    skips: 2
    n_pi: 16
-    num_retunes: 26
+    num_retunes: 14
    retune_epochs: 6
    standardize_rewards: True
    aux_mbsize: 4
@@ -60,7 +60,7 @@ procgen-ppo:
    value_lr: 1.0e-3
    same_lr_everywhere: False
    aux_phase_mixed_precision: True
-    single_optimizer: False
+    single_optimizer: True
    max_time: 7200
    adaptive_gamma: False
@@ -70,7 +70,7 @@ procgen-ppo:
    entropy_schedule: False

    # Memory management, if batch size overflow, batch splitting is done to handle it
-    max_minibatch_size: 1500
+    max_minibatch_size: 1000
    updates_per_batch: 8
    normalize_actions: False
@@ -87,8 +87,6 @@ procgen-ppo:
    model:
      custom_model: impala_torch_ppg
      custom_model_config:
-        # depths: [16, 32, 32]
-        # nlatents: 256
        depths: [32, 64, 64]
        nlatents: 512
        init_normed: True
@@ -96,7 +94,7 @@ procgen-ppo:
        diff_framestack: True

    num_workers: 7
-    num_envs_per_worker: 9
+    num_envs_per_worker: 16
    rollout_fragment_length: 256
......
@@ -46,13 +46,15 @@ procgen-ppo:
    no_done_at_end: False

    # Custom switches
-    retune_skips: 450000
-    retune_replay_size: 200000
-    num_retunes: 11
+    retune_skips: 100000
+    retune_replay_size: 400000
+    num_retunes: 14
    retune_epochs: 3
    standardize_rewards: True
    scale_reward: 1.0
    return_reset: False
+    aux_phase_mixed_precision: True
+    max_time: 7200
    adaptive_gamma: False
    final_lr: 5.0e-5
@@ -61,7 +63,7 @@ procgen-ppo:
    entropy_schedule: False

    # Memory management, if batch size overflow, batch splitting is done to handle it
-    max_minibatch_size: 2048
+    max_minibatch_size: 1000
    updates_per_batch: 8
    normalize_actions: False
@@ -77,12 +79,12 @@ procgen-ppo:
    # === Settings for Model ===
    model:
      custom_model: impala_torch_custom
-      custom_options:
+      custom_model_config:
        depths: [32, 64, 64]
        nlatents: 512
        use_layernorm: True
        diff_framestack: True
-        d2rl: True
+        d2rl: False

    num_workers: 7
    num_envs_per_worker: 16
......
@@ -87,11 +87,11 @@ class ImpalaCNN(TorchModelV2, nn.Module):
        nn.Module.__init__(self)
        self.device = device