import itertools
from functools import partial

import numpy as np
from skimage.util import view_as_windows
from ray.rllib.utils import try_import_torch

torch, nn = try_import_torch()
import torch.distributions as td

def calculate_gae_buffer(values_buffer, dones_buffer, rewards_buffer, last_values, gamma, lam):
    """Compute GAE returns for a buffer of rollout segments, shape (nsegs, nsteps, ...).

    Segments are processed newest-to-oldest; the first value of each segment
    becomes the bootstrap value for the segment before it.
    """
    new_returns = np.empty_like(values_buffer)
    nsegs = values_buffer.shape[0]
    for s in reversed(range(nsegs)):
        mb_values = values_buffer[s]
        mb_rewards = rewards_buffer[s]
        mb_dones = dones_buffer[s]
        mb_returns, _ = calculate_gae(mb_values, mb_dones, mb_rewards,
                                      last_values, gamma, lam)
        new_returns[s] = mb_returns
        last_values = mb_values[0]
    return new_returns

def calculate_gae(mb_values, mb_dones, mb_rewards, last_values, gamma, lam):
    """Generalized Advantage Estimation over a single rollout of length nsteps."""
    lastgaelam = 0
    nsteps = mb_values.shape[0]
    mb_advs = np.empty_like(mb_values)
    for t in reversed(range(nsteps)):
        if t == nsteps - 1:
            nextvalues = last_values
        else:
            nextvalues = mb_values[t + 1]
        nextnonterminal = 1.0 - mb_dones[t]
        delta = mb_rewards[t] + gamma * nextvalues * nextnonterminal - mb_values[t]
        mb_advs[t] = lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
    mb_returns = mb_advs + mb_values
    return mb_returns, mb_advs
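
# Illustrative sketch of calling calculate_gae on a toy rollout; the shapes and
# hyperparameters below are assumptions for the example only, not values used
# elsewhere in this file. Returns equal advantages plus the value baseline.
def _example_calculate_gae():
    nsteps, nenvs = 4, 2
    values = np.zeros((nsteps, nenvs), dtype=np.float32)
    rewards = np.ones((nsteps, nenvs), dtype=np.float32)
    dones = np.zeros((nsteps, nenvs), dtype=np.float32)
    last_values = np.zeros(nenvs, dtype=np.float32)
    returns, advs = calculate_gae(values, dones, rewards, last_values,
                                  gamma=0.99, lam=0.95)
    assert returns.shape == advs.shape == (nsteps, nenvs)
    return returns, advs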


def _make_categorical(x, ncat, shape):
    # Reshape flat logits to (batch, shape, ncat) and wrap them in a Categorical.
    x = x.reshape((x.shape[0], shape, ncat))
    return td.Categorical(logits=x)

def dist_build(ac_space):
    # Distribution constructor for a discrete action space with ac_space.n actions.
    return partial(_make_categorical, shape=1, ncat=ac_space.n)

def neglogp_actions(pi_logits, actions):
    # Negative log-probability of the taken actions under the policy logits.
    return nn.functional.cross_entropy(pi_logits, actions, reduction='none')

def sample_actions(logits, device):
    # Sample from the categorical defined by `logits` via the Gumbel-max trick.
    u = torch.rand(logits.shape, dtype=logits.dtype, device=device)
    return torch.argmax(logits - torch.log(-torch.log(u)), dim=1)

def pi_entropy(logits):
    # Entropy of the categorical policy, computed in a numerically stable way.
    a0 = logits - torch.max(logits, dim=1, keepdim=True)[0]
    ea0 = torch.exp(a0)
    z0 = torch.sum(ea0, dim=1, keepdim=True)
    p0 = ea0 / z0
    return torch.sum(p0 * (torch.log(z0) - a0), dim=1)

def roll(arr):
    # Swap the first two axes and merge them into a single leading dimension.
    s = arr.shape
    return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])

def unroll(arr, targetshape):
    # Inverse of roll; targetshape gives the (swapped) pair of leading dimensions.
    s = arr.shape
    return arr.reshape(*targetshape, *s[1:]).swapaxes(0, 1)

def safe_mean(xs):
    # Mean that tolerates an empty sequence (logged as -inf instead of NaN).
    return -np.inf if len(xs) == 0 else np.mean(xs)
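
# Round-trip sketch for roll/unroll; the (nsteps, nenvs) layout below is an
# assumption made for illustration only.
def _example_roll_unroll():
    nsteps, nenvs = 3, 2
    x = np.arange(nsteps * nenvs * 4).reshape(nsteps, nenvs, 4)
    flat = roll(x)  # shape (nsteps * nenvs, 4), grouped by env
    restored = unroll(flat, (nenvs, nsteps))
    assert np.array_equal(restored, x)
    return flat.shape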


def pad_and_random_crop(imgs, out, pad):
    """Vectorized pad-and-random-crop augmentation (assumes square images).

    args:
        imgs: array of shape (B, H, W, C)
        out:  output crop size, e.g. 64
        pad:  zero padding added to each spatial border before cropping
    """
    imgs = np.pad(imgs, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
    n = imgs.shape[0]          # batch size
    img_size = imgs.shape[1]   # height/width after padding
    crop_max = img_size - out
    w1 = np.random.randint(0, crop_max, n)
    h1 = np.random.randint(0, crop_max, n)
    # All sliding windows of size (out, out) over each image.
    windows = view_as_windows(imgs, (1, out, out, 1))[..., 0, :, :, 0]
    # Select one random window per batch element.
    cropped = windows[np.arange(n), w1, h1]
    cropped = cropped.transpose(0, 2, 3, 1)
    return cropped

def random_cutout_color(imgs, min_cut, max_cut):
    # Paint a randomly colored box into each image. The same draw from
    # [min_cut, max_cut) is used as both the box offset and its size.
    n, h, w, c = imgs.shape
    w1 = np.random.randint(min_cut, max_cut, n)
    h1 = np.random.randint(min_cut, max_cut, n)

    cutouts = np.empty((n, h, w, c), dtype=imgs.dtype)
    rand_box = np.random.randint(0, 255, size=(n, c), dtype=imgs.dtype)
    for i, (img, w11, h11) in enumerate(zip(imgs, w1, h1)):
        cut_img = img.copy()
        # Fill rows h11:2*h11 and columns w11:2*w11 with a random color per image.
        cut_img[h11:h11 + h11, w11:w11 + w11, :] = rand_box[i]

        cutouts[i] = cut_img
    return cutouts
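
# Shape-preservation sketch for the two augmentations above; the batch size,
# image size, and pad/cut values are assumptions for illustration only.
def _example_augmentations():
    batch = np.zeros((8, 64, 64, 3), dtype=np.uint8)
    cropped = pad_and_random_crop(batch, out=64, pad=12)
    cut = random_cutout_color(batch, min_cut=10, max_cut=30)
    assert cropped.shape == batch.shape and cut.shape == batch.shape
    return cropped.dtype, cut.dtype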

def linear_schedule(initial_val, final_val, current_steps, total_steps):
    # Linear interpolation from initial_val (at step 0) to final_val (at total_steps).
    frac = 1.0 - current_steps / total_steps
    return (initial_val - final_val) * frac + final_val

def horizon_to_gamma(horizon):
    # Discount factor whose effective horizon 1 / (1 - gamma) equals `horizon`.
    return 1.0 - 1.0 / horizon

class AdaptiveDiscountTuner:
    # Tracks an exponential moving average of the discount factor implied by the
    # observed episode horizon.
    def __init__(self, gamma, momentum=0.98, eplenmult=1):
        self.gamma = gamma
        self.momentum = momentum
        self.eplenmult = eplenmult

    def update(self, horizon):
        if horizon > 0:
            htarg = horizon * self.eplenmult
            gtarg = horizon_to_gamma(htarg)
            self.gamma = self.gamma * self.momentum + gtarg * (1 - self.momentum)
        return self.gamma
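
# Sketch of the adaptive discount update; the horizons fed in below are made-up
# numbers for illustration.
def _example_adaptive_discount():
    tuner = AdaptiveDiscountTuner(gamma=0.99, momentum=0.98, eplenmult=1)
    gamma = tuner.gamma
    for horizon in (100, 200, 400):
        gamma = tuner.update(horizon)
    return gamma  # drifts slowly toward horizon_to_gamma(400) = 0.9975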
    
def flatten01(arr):
    # Merge the first two dimensions.
    return arr.reshape(-1, *arr.shape[2:])

def flatten012(arr):
    # Merge the first three dimensions.
    return arr.reshape(-1, *arr.shape[3:])

class RetuneSelector:
    def __init__(self, nenvs, ob_space, ac_space, replay_shape, skips=0, n_pi=32, num_retunes=5, flat_buffer=False):
        self.skips = skips
        self.n_pi = n_pi
        self.nenvs = nenvs

        self.exp_replay = np.empty((*replay_shape, *ob_space.shape), dtype=np.uint8)
        self.dones_replay = np.empty((*replay_shape,), dtype=bool)
        self.rewards_replay = np.empty((*replay_shape,), dtype=np.float32)

        self.replay_shape = replay_shape

        self.num_retunes = num_retunes
        self.ac_space = ac_space
        self.ob_space = ob_space

        self.cooldown_counter = skips
        self.replay_index = 0
        self.flat_buffer = flat_buffer
    def update(self, obs_batch, dones_batch, rewards_batch):
        # Store one rollout batch; return True once n_pi batches have been collected.
        if self.num_retunes == 0:
            return False

        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            return False

        self.exp_replay[self.replay_index] = obs_batch
        self.dones_replay[self.replay_index] = dones_batch
        self.rewards_replay[self.replay_index] = rewards_batch

        self.replay_index = (self.replay_index + 1) % self.n_pi
        return self.replay_index == 0
        
    def retune_done(self):
        self.cooldown_counter = self.skips
        self.num_retunes -= 1
        self.replay_index = 0
        
        
    def make_minibatches(self, presleep_pi, returns_buffer, num_rollouts):
        if not self.flat_buffer:
            # Sample whole (policy-iteration, env) segments and flatten over time.
            env_segs = list(itertools.product(range(self.n_pi), range(self.nenvs)))
            np.random.shuffle(env_segs)
            env_segs = np.array(env_segs)
            for idx in range(0, len(env_segs), num_rollouts):
                esinds = env_segs[idx:idx + num_rollouts]
                mbatch = [flatten01(arr[esinds[:, 0], :, esinds[:, 1]])
                          for arr in (self.exp_replay, returns_buffer, presleep_pi)]
                yield mbatch
        else:
            # Sample individual transitions uniformly from the flattened buffer.
            nsteps = returns_buffer.shape[1]
            buffsize = self.n_pi * nsteps * self.nenvs
            inds = np.arange(buffsize)
            np.random.shuffle(inds)
            batchsize = num_rollouts * nsteps
            for start in range(0, buffsize, batchsize):
                end = start + batchsize
                mbinds = inds[start:end]
                mbatch = [flatten012(arr)[mbinds]
                          for arr in (self.exp_replay, returns_buffer, presleep_pi)]

                yield mbatch
        
class RewardNormalizer(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, gamma=0.99, cliprew=10.0, epsilon=1e-8):
        self.epsilon = epsilon
        self.gamma = gamma
        self.ret_rms = RunningMeanStd(shape=())
        self.cliprew = cliprew
        self.ret = 0.  # becomes a per-env array after the first normalize() call

    def normalize(self, rews, news, resetrew):
        self.ret = self.ret * self.gamma + rews
        self.ret_rms.update(self.ret)
        rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        if resetrew:
            # news must be boolean (True/False) so it can be used as a positional mask.
            self.ret[np.array(news, dtype=bool)] = 0.
        return rews
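
# Usage sketch: normalize rewards for a vector of 4 envs; the reward values and
# done flags below are made up for illustration.
def _example_reward_normalizer():
    rn = RewardNormalizer(gamma=0.99, cliprew=10.0)
    rews = np.array([1.0, 0.5, -0.2, 2.0])
    news = np.array([False, False, True, False])
    normed = rn.normalize(rews, news, resetrew=True)
    assert normed.shape == rews.shape
    return normed
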
    
class RunningMeanStd(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    # Combine two sets of first/second moments (Chan et al. parallel variance update).
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count
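
# Consistency sketch: feeding data in two chunks should give (approximately) the
# same statistics as a single pass; the sample data below is made up for illustration.
def _example_running_mean_std():
    rms = RunningMeanStd(shape=())
    data = np.random.randn(1000)
    rms.update(data[:500])
    rms.update(data[500:])
    # The initial epsilon count introduces a tiny bias, so compare approximately.
    assert np.isclose(rms.mean, data.mean(), atol=1e-2)
    assert np.isclose(rms.var, data.var(), atol=1e-2)
    return rms.mean, rms.var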