from ray.rllib.policy.torch_policy import TorchPolicy
import numpy as np
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, convert_to_torch_tensor
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.annotations import override
from collections import deque
from .utils import *
import time

torch, nn = try_import_torch()
from torch.cuda.amp import autocast, GradScaler

class CustomTorchPolicy(TorchPolicy):
    """Custom PPO policy implemented directly on top of TorchPolicy.
    If you are using tensorflow/pytorch to build custom policies,
    you might find `build_tf_policy` and `build_torch_policy` to
    be useful.
    Adapted from the examples at https://docs.ray.io/en/master/rllib-concepts.html
    """

    def __init__(self, observation_space, action_space, config):
        self.config = config
        self.action_space = action_space
        self.observation_space = observation_space

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"], framework="torch")
        self.model = ModelCatalog.get_model_v2(
                        obs_space=observation_space,
                        action_space=action_space,
                        num_outputs=logit_dim,
                        model_config=self.config["model"],
                        framework="torch",
                        device=self.device,
                     )

        TorchPolicy.__init__(
            self,
            observation_space=observation_space,
            action_space=action_space,
            config=config,
            model=self.model,
            loss=None,
            action_distribution_class=dist_class,
        )
        
        self.framework = "torch"

    
    def init_training(self):
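        # Training-time state (optimizer, reward normalizer, schedules, and batching/bookkeeping
        # variables) is set up here rather than in __init__.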
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.max_reward = self.config['env_config']['return_max']
        self.rewnorm = RewardNormalizer(cliprew=self.max_reward) ## TODO: Might need to go to custom state
        self.reward_deque = deque(maxlen=100)
        self.best_reward = -np.inf
        self.best_weights = None
        self.time_elapsed = 0
        self.batch_end_time = time.time()
        self.timesteps_total = 0
        self.best_rew_tsteps = 0
        
        nw = self.config['num_workers'] if self.config['num_workers'] > 0 else 1
        self.nbatch = nw * self.config['num_envs_per_worker'] * self.config['rollout_fragment_length']
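        # Each rollout of nbatch samples is split into `updates_per_batch` logical batches; each
        # logical batch is further divided into memory-limited minibatches whose gradients are
        # accumulated before an optimizer step.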
        self.actual_batch_size = self.nbatch // self.config['updates_per_batch']
        self.accumulate_train_batches = int(np.ceil( self.actual_batch_size / self.config['max_minibatch_size'] ))
        self.mem_limited_batch_size = self.actual_batch_size // self.accumulate_train_batches
        if self.nbatch % self.actual_batch_size != 0 or self.nbatch % self.mem_limited_batch_size != 0:
            print("#################################################")
            print("WARNING: MEMORY LIMITED BATCHING NOT SET PROPERLY")
            print("#################################################")
        self.retune_selector = RetuneSelector(self.nbatch, self.observation_space, self.action_space, 
                                              skips = self.config['retune_skips'], 
                                              replay_size = self.config['retune_replay_size'], 
                                              num_retunes = self.config['num_retunes'])
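        # Buffer of raw uint8 observations, filled via retune_selector.update() and replayed during
        # the augmented distillation (retune) phase.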
        self.exp_replay = np.empty((self.retune_selector.replay_size, *self.observation_space.shape), dtype=np.uint8)
        self.target_timesteps = 8_000_000
        self.buffer_time = 20 # TODO: Could try to do a median or mean time step check instead
        self.max_time = self.config['max_time']
        self.maxrewep_lenbuf = deque(maxlen=100)
        self.gamma = self.config['gamma']
        self.adaptive_discount_tuner = AdaptiveDiscountTuner(self.gamma, momentum=0.98, eplenmult=3)
        
        self.lr = self.config['lr']
        self.ent_coef = self.config['entropy_coeff']
        
        self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
        self.save_success = 0
        self.retunes_completed = 0
        self.amp_scaler = GradScaler()
        
    def to_tensor(self, arr):
        return torch.from_numpy(arr).to(self.device)
    
    @override(TorchPolicy)
    def extra_action_out(self, input_dict, state_batches, model, action_dist):
        return {'values': model._value.tolist()}
    
        
    @override(TorchPolicy)
    def learn_on_batch(self, samples):
        """Fused compute gradients and apply gradients call.
        Either this or the combination of compute/apply grads must be
        implemented by subclasses.
        Returns:
            grad_info: dictionary of extra metadata from compute_gradients().
        Examples:
            >>> samples = ev.sample()
            >>> ev.learn_on_batch(samples)
        Reference: https://github.com/ray-project/ray/blob/master/rllib/policy/policy.py#L279-L316
        """
        
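        # Custom PPO update: normalize rewards, compute GAE advantages, run several epochs of
        # clipped-PPO minibatch updates with gradient accumulation, then periodically distill the
        # network on augmented observations and refresh the adaptive hyperparameters.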
        ## Config data values
        nbatch = self.nbatch
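        # `unroll` (from .utils) reshapes the flat sample arrays so they can be indexed per timestep.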
        nbatch_train = self.mem_limited_batch_size 
        gamma, lam = self.gamma, self.config['lambda']
        nsteps = self.config['rollout_fragment_length']
        nenvs = nbatch//nsteps
        ts = (nenvs, nsteps)
        mb_dones = unroll(samples['dones'], ts)
        
        ## Reward normalization - using no normalization at all works well for many envs
        if self.config['standardize_rewards']:
            mb_origrewards = unroll(samples['rewards'], ts)
            mb_rewards =  np.zeros_like(mb_origrewards)
            mb_rewards[0] = self.rewnorm.normalize(mb_origrewards[0], self.last_dones, 
                                                   self.config["return_reset"])
            for ii in range(1, nsteps):
                mb_rewards[ii] = self.rewnorm.normalize(mb_origrewards[ii], mb_dones[ii-1],
                                                       self.config["return_reset"])
            self.last_dones = mb_dones[-1]
        else:
            mb_rewards = unroll(samples['rewards'], ts)
        
        # Weird hack that helps in many envs (Yes keep it after reward normalization)
        rew_scale = self.config["scale_reward"]
        if rew_scale != 1.0:
            mb_rewards *= rew_scale
        
        should_skip_train_step = self.best_reward_model_select(samples)
        if should_skip_train_step:
            self.update_batch_time()
            return {} # Not doing last optimization step - This is intentional due to noisy gradients
          
        obs = samples['obs']

        ## Value prediction
        next_obs = unroll(samples['new_obs'], ts)[-1]
        last_values, _ = self.model.vf_pi(next_obs, ret_numpy=True, no_grad=True, to_torch=True)
        values = samples['values']
        
        ## GAE
        mb_values = unroll(values, ts)
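        # GAE(lambda): walk backwards through the rollout, bootstrapping from the value of the
        # next observation and masking with (1 - done).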
        mb_returns = np.zeros_like(mb_rewards)
        mb_advs = np.zeros_like(mb_rewards)
        lastgaelam = 0
        for t in reversed(range(nsteps)):
            if t == nsteps - 1:
                nextvalues = last_values
            else:
                nextvalues = mb_values[t+1]
            nextnonterminal = 1.0 - mb_dones[t]
            delta = mb_rewards[t] + gamma * nextvalues * nextnonterminal - mb_values[t]
            mb_advs[t] = lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
        mb_returns = mb_advs + mb_values
        
        ## Data from config
        cliprange, vfcliprange = self.config['clip_param'], self.config['vf_clip_param']
        lrnow = self.lr
        max_grad_norm = self.config['grad_clip']
        ent_coef, vf_coef = self.ent_coef, self.config['vf_loss_coeff']
        
        neglogpacs = -samples['action_logp'] ## np.isclose seems to be True always, otherwise compute again if needed
        noptepochs = self.config['num_sgd_iter']
        actions = samples['actions']
        returns = roll(mb_returns)
        
        ## Train multiple epochs
        optim_count = 0
        inds = np.arange(nbatch)
        for _ in range(noptepochs):
            np.random.shuffle(inds)
            normalized_advs = returns - values
            # Can do this because actual_batch_size is a multiple of mem_limited_batch_size
            for start in range(0, nbatch, self.actual_batch_size):
                end = start + self.actual_batch_size
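                # Normalize advantages within each logical batch (zero mean, unit std).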
                mbinds = inds[start:end]
                advs_batch = normalized_advs[mbinds].copy()
                normalized_advs[mbinds] = (advs_batch - np.mean(advs_batch)) / (np.std(advs_batch) + 1e-8) 
            for start in range(0, nbatch, nbatch_train):
                end = start + nbatch_train
                mbinds = inds[start:end]
                slices = (self.to_tensor(arr[mbinds]) for arr in (obs, returns, actions, values, neglogpacs, normalized_advs))
                optim_count += 1
                apply_grad = (optim_count % self.accumulate_train_batches) == 0
                self._batch_train(apply_grad, self.accumulate_train_batches,
                                  lrnow, cliprange, vfcliprange, max_grad_norm, ent_coef, vf_coef, *slices)
        
        ## Distill with augmentation
        should_retune = self.retune_selector.update(obs, self.exp_replay)
        if should_retune:
            self.retune_with_augmentation()
            self.update_batch_time()
            return {}
                
        self.update_gamma(samples)
        self.update_lr()
        self.update_ent_coef()
            
        self.update_batch_time()
        return {}
    
    def update_batch_time(self):
        self.time_elapsed += time.time() - self.batch_end_time
        self.batch_end_time = time.time()
        
    def _batch_train(self, apply_grad, num_accumulate, 
                     lr, cliprange, vfcliprange, max_grad_norm,
                     ent_coef, vf_coef,
                     obs, returns, actions, values, neglogpac_old, advs):
        
        for g in self.optimizer.param_groups:
            g['lr'] = lr
        vpred, pi_logits = self.model.vf_pi(obs, ret_numpy=False, no_grad=False, to_torch=False)
        neglogpac = neglogp_actions(pi_logits, actions)
        entropy = torch.mean(pi_entropy(pi_logits))

        vpredclipped = values + torch.clamp(vpred - values, -vfcliprange, vfcliprange)
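        # PPO objective: pessimistic (max) of clipped/unclipped squared value errors, plus the
        # clipped policy surrogate on the probability ratio, minus an entropy bonus.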
        vf_losses1 = torch.pow((vpred - returns), 2)
        vf_losses2 = torch.pow((vpredclipped - returns), 2)
        vf_loss = .5 * torch.mean(torch.max(vf_losses1, vf_losses2))

        ratio = torch.exp(neglogpac_old - neglogpac)
        pg_losses1 = -advs * ratio
        pg_losses2 = -advs * torch.clamp(ratio, 1-cliprange, 1+cliprange)
        pg_loss = torch.mean(torch.max(pg_losses1, pg_losses2))

        loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef
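        # Gradient accumulation: backward() runs on every minibatch, but gradients are clipped and
        # the optimizer stepped only once every `num_accumulate` minibatches.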
        
        loss = loss / num_accumulate

        loss.backward()
        if apply_grad:
            nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()

        
    def retune_with_augmentation(self):
        nbatch_train = self.mem_limited_batch_size 
        retune_epochs = self.config['retune_epochs']
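        # Record the network's current value and policy outputs on the replayed observations, then
        # re-fit both heads to those targets on augmented copies of the same observations.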
        replay_size = self.retune_selector.replay_size
        replay_vf = np.empty((replay_size,), dtype=np.float32)
        replay_pi = np.empty((replay_size, self.retune_selector.ac_space.n), dtype=np.float32)

        # Store current value function and policy logits
        for start in range(0, replay_size, nbatch_train):
            end = start + nbatch_train
            replay_batch = self.exp_replay[start:end]
            replay_vf[start:end], replay_pi[start:end] = self.model.vf_pi(replay_batch, 
                                                                          ret_numpy=True, no_grad=True, to_torch=True)
        
        optim_count = 0
        # Tune vf and pi heads to older predictions with augmented observations
        inds = np.arange(len(self.exp_replay))
        for ep in range(retune_epochs):
            np.random.shuffle(inds)
            for start in range(0, replay_size, nbatch_train):
                end = start + nbatch_train
                mbinds = inds[start:end]
                optim_count += 1
                apply_grad = (optim_count % self.accumulate_train_batches) == 0
                slices = [self.exp_replay[mbinds], 
                          self.to_tensor(replay_vf[mbinds]), 
                          self.to_tensor(replay_pi[mbinds])]
                self.tune_policy(apply_grad, *slices, 0.5)
        
        self.retunes_completed += 1
        self.retune_selector.retune_done()
 
    def tune_policy(self, apply_grad, obs, target_vf, target_pi, retune_vf_loss_coeff):
        obs_aug = np.empty(obs.shape, obs.dtype)
        aug_idx = np.random.randint(3, size=len(obs))
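        # Each observation receives one of three augmentations: pad-and-random-crop, a random
        # colored cutout, or no change.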
        obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)
        obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)
        obs_aug[aug_idx == 2] = obs[aug_idx == 2]
        obs_aug = self.to_tensor(obs_aug)
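        # Precompute the target policy's (log-)probabilities once; the distillation loss below is a
        # KL divergence between the frozen target policy and the current policy.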
        with torch.no_grad():
            tpi_log_softmax = nn.functional.log_softmax(target_pi, dim=1)
            tpi_softmax = torch.exp(tpi_log_softmax)
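        # The distillation step optionally runs under mixed precision (autocast + shared GradScaler).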
        if not self.config['aux_phase_mixed_precision']:
            loss = self._retune_calc_loss(obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff)
            loss.backward()
            if apply_grad:
                self.optimizer.step()
                self.optimizer.zero_grad()
        else:
            with autocast():
                loss = self._retune_calc_loss(obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff)
            self.amp_scaler.scale(loss).backward()
            
            if apply_grad:
                self.amp_scaler.step(self.optimizer)
                self.amp_scaler.update()
                self.optimizer.zero_grad()
            
    def _retune_calc_loss(self, obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff):
        vpred, pi_logits = self.model.vf_pi(obs_aug, ret_numpy=False, no_grad=False, to_torch=False)
        pi_log_softmax = nn.functional.log_softmax(pi_logits, dim=1)
        pi_loss = torch.mean(torch.sum(tpi_softmax * (tpi_log_softmax - pi_log_softmax), dim=1)) # kl_div in torch 1.3.1 has numerical issues
        vf_loss = .5 * torch.mean(torch.pow(vpred - target_vf, 2))
        
        loss = retune_vf_loss_coeff * vf_loss + pi_loss
        loss = loss / self.accumulate_train_batches
        return loss
        
    def best_reward_model_select(self, samples):
        self.timesteps_total += self.nbatch
        
        ## Best reward model selection
        eprews = [info['episode']['r'] for info in samples['infos'] if 'episode' in info]
        self.reward_deque.extend(eprews)
        mean_reward = safe_mean(eprews) if len(eprews) >= 100 else safe_mean(self.reward_deque)
        if self.best_reward < mean_reward:
            self.best_reward = mean_reward
            self.best_weights = self.get_weights()["current_weights"]
            self.best_rew_tsteps = self.timesteps_total
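        # Once the timestep or wall-clock budget is exhausted, restore the best weights seen so far
        # and signal the caller to skip further training.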
           
        if self.timesteps_total > self.target_timesteps or (self.time_elapsed + self.buffer_time) > self.max_time:
            if self.best_weights is not None:
                self.set_model_weights(self.best_weights)
                return True
            
        return False
    
    def update_lr(self):
        if self.config['lr_schedule'] == 'linear':
            self.lr = linear_schedule(initial_val=self.config['lr'],
                                      final_val=self.config['final_lr'],
                                      current_steps=self.timesteps_total,
                                      total_steps=self.target_timesteps)
            
        elif self.config['lr_schedule'] == 'exponential':
            self.lr = 0.997 * self.lr 

    
    def update_ent_coef(self):
        if self.config['entropy_schedule']:
            self.ent_coef = linear_schedule(initial_val=self.config['entropy_coeff'], 
                                            final_val=self.config['final_entropy_coeff'], 
                                            current_steps=self.timesteps_total, 
                                            total_steps=self.target_timesteps)
    
    def update_gamma(self, samples):
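        # Adaptive discount: use the ~80th percentile length of recent max-reward episodes as the
        # target horizon for the discount tuner.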
        if self.config['adaptive_gamma']:
            epinfobuf = [info['episode'] for info in samples['infos'] if 'episode' in info]
            self.maxrewep_lenbuf.extend([epinfo['l'] for epinfo in epinfobuf if epinfo['r'] >= self.max_reward])
            sorted_nth = lambda buf, n: np.nan if len(buf) < 100 else sorted(buf)[n]
            target_horizon = sorted_nth(self.maxrewep_lenbuf, 80)
            self.gamma = self.adaptive_discount_tuner.update(target_horizon)

        
    def get_custom_state_vars(self):
        return {
            "time_elapsed": self.time_elapsed,
            "timesteps_total": self.timesteps_total,
            "best_weights": self.best_weights,
            "reward_deque": self.reward_deque,
            "batch_end_time": self.batch_end_time,
            "gamma": self.gamma,
            "maxrewep_lenbuf": self.maxrewep_lenbuf,
            "num_retunes": self.retune_selector.num_retunes,
            "lr": self.lr,
            "ent_coef": self.ent_coef,
            "rewnorm": self.rewnorm,
            "best_rew_tsteps": self.best_rew_tsteps,
            "best_reward": self.best_reward,
            "last_dones": self.last_dones,
            "retunes_completed": self.retunes_completed,
        }
    
    def set_custom_state_vars(self, custom_state_vars):
        self.time_elapsed = custom_state_vars["time_elapsed"]
        self.timesteps_total = custom_state_vars["timesteps_total"]
        self.best_weights = custom_state_vars["best_weights"]
        self.reward_deque = custom_state_vars["reward_deque"]
        self.batch_end_time = custom_state_vars["batch_end_time"]
        self.gamma = self.adaptive_discount_tuner.gamma = custom_state_vars["gamma"]
        self.retune_selector.set_num_retunes(custom_state_vars["num_retunes"])
        self.maxrewep_lenbuf = custom_state_vars["maxrewep_lenbuf"]
        self.lr = custom_state_vars["lr"]
        self.ent_coef = custom_state_vars["ent_coef"]
        self.rewnorm = custom_state_vars["rewnorm"]
        self.best_rew_tsteps = custom_state_vars["best_rew_tsteps"]
        self.best_reward = custom_state_vars["best_reward"]
        self.last_dones = custom_state_vars["last_dones"]
        self.retunes_completed = custom_state_vars["retunes_completed"]
    
    @override(TorchPolicy)
    def get_weights(self):
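        # Weights are exported as CPU numpy arrays so they can be serialized outside of torch.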
        weights = {}
        weights["current_weights"] = {
            k: v.cpu().detach().numpy()
            for k, v in self.model.state_dict().items()
        }
        return weights
        
    
    @override(TorchPolicy)
    def set_weights(self, weights):
        self.set_model_weights(weights["current_weights"])
        
    def set_optimizer_state(self, optimizer_state, amp_scaler_state):
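        # Restore optimizer and AMP GradScaler state, converting stored arrays back to torch
        # tensors on the current device.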
        optimizer_state = convert_to_torch_tensor(optimizer_state, device=self.device)
        self.optimizer.load_state_dict(optimizer_state)
        
        amp_scaler_state = convert_to_torch_tensor(amp_scaler_state, device=self.device)
        self.amp_scaler.load_state_dict(amp_scaler_state)
        
    def set_model_weights(self, model_weights):
        model_weights = convert_to_torch_tensor(model_weights, device=self.device)
        self.model.load_state_dict(model_weights)