import numpy as np
from ray.rllib.utils import try_import_torch
from collections import deque
from skimage.util import view_as_windows

torch, nn = try_import_torch()
import torch.distributions as td
from functools import partial
import itertools

def _make_categorical(x, ncat, shape):
    # reshape flat logits (B, shape * ncat) -> (B, shape, ncat) and wrap them
    # in a Categorical distribution
    x = x.reshape((x.shape[0], shape, ncat))
    return td.Categorical(logits=x)

def dist_build(ac_space):
    return partial(_make_categorical, shape=1, ncat=ac_space.n)
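
# Usage sketch (illustrative; assumes a gym.spaces.Discrete action space):
#   make_dist = dist_build(ac_space)   # fixes ncat = ac_space.n, shape = 1
#   dist = make_dist(logits)           # logits: (B, ac_space.n) tensor
#   actions = dist.sample()            # shape (B, 1)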

def neglogp_actions(pi_logits, actions):
    # -log pi(a|s): cross-entropy against the taken actions, one value per sample
    return nn.functional.cross_entropy(pi_logits, actions, reduction='none')

def sample_actions(logits, device):
    # Gumbel-max trick: adding Gumbel(0, 1) noise (-log(-log(u))) to the logits
    # and taking the argmax samples from the softmax distribution
    u = torch.rand(logits.shape, dtype=logits.dtype).to(device)
    return torch.argmax(logits - torch.log(-torch.log(u)), dim=1)
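
# Usage sketch (illustrative; `policy_logits` is a hypothetical (B, n_actions) tensor):
#   actions = sample_actions(policy_logits, device=policy_logits.device)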

def pi_entropy(logits):
    # numerically stable softmax entropy: subtract the max logit before exponentiating
    a0 = logits - torch.max(logits, dim=1, keepdim=True)[0]
    ea0 = torch.exp(a0)
    z0 = torch.sum(ea0, dim=1, keepdim=True)
    p0 = ea0 / z0
    return torch.sum(p0 * (torch.log(z0) - a0), dim=1)

def roll(arr):
    # (T, N, ...) -> (N*T, ...): swap the time/env axes, then merge them
    s = arr.shape
    return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])

def unroll(arr, targetshape):
    # inverse of roll: (N*T, ...) -> (T, N, ...) given targetshape == (N, T)
    s = arr.shape
    return arr.reshape(*targetshape, *s[1:]).swapaxes(0, 1)
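
# Shape round-trip sketch (illustrative; T and N are hypothetical sizes):
#   x = np.zeros((T, N, 64, 64, 3))   # (time, env, ...) rollout tensor
#   flat = roll(x)                    # -> (N*T, 64, 64, 3)
#   back = unroll(flat, (N, T))       # -> (T, N, 64, 64, 3)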

def safe_mean(xs):
    # -inf flags "no episodes yet" instead of raising on an empty list
    return -np.inf if len(xs) == 0 else np.mean(xs)


def pad_and_random_crop(imgs, out, pad):
    """
    Vectorized pad and random crop. Assumes square images.
    args:
    imgs: shape (B,H,W,C)
    out: output size (e.g. 64)
    pad: pixels of zero padding added on each side
    """
    imgs = np.pad(imgs, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
    n = imgs.shape[0]  # batch size
    img_size = imgs.shape[1]  # padded size, e.g. 64 + 2*pad
    crop_max = img_size - out
    # random top-left corner per image; randint's high bound is exclusive,
    # so the offset crop_max itself is never drawn
    w1 = np.random.randint(0, crop_max, n)
    h1 = np.random.randint(0, crop_max, n)
    # creates all sliding-window crops of size (out, out)
    windows = view_as_windows(imgs, (1, out, out, 1))[..., 0, :, :, 0]
    # selects one random window per batch element
    cropped = windows[np.arange(n), w1, h1]
    cropped = cropped.transpose(0, 2, 3, 1)
    return cropped
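
# Usage sketch (illustrative; Procgen-style 64x64 RGB batch):
#   batch = np.random.randint(0, 256, (8, 64, 64, 3), dtype=np.uint8)
#   aug = pad_and_random_crop(batch, out=64, pad=4)   # -> (8, 64, 64, 3)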

def random_cutout_color(imgs, min_cut, max_cut):
    n, h, w, c = imgs.shape
    # one random value per image, used both as the box offset and its size
    w1 = np.random.randint(min_cut, max_cut, n)
    h1 = np.random.randint(min_cut, max_cut, n)

    cutouts = np.empty((n, h, w, c), dtype=imgs.dtype)
    # one random fill color per image (0-254; randint's high bound is exclusive)
    rand_box = np.random.randint(0, 255, size=(n, c), dtype=imgs.dtype)
    for i, (img, w11, h11) in enumerate(zip(imgs, w1, h1)):
        cut_img = img.copy()
        # paint an (h11 x w11) box anchored at (h11, w11) with the random color
        cut_img[h11:h11 + h11, w11:w11 + w11, :] = rand_box[i]
        cutouts[i] = cut_img
    return cutouts
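
# Usage sketch (illustrative):
#   batch = np.random.randint(0, 256, (8, 64, 64, 3), dtype=np.uint8)
#   aug = random_cutout_color(batch, min_cut=10, max_cut=30)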

def linear_schedule(initial_val, final_val, current_steps, total_steps):
    # linear interpolation from initial_val down to final_val over total_steps
    frac = 1.0 - current_steps / total_steps
    return (initial_val - final_val) * frac + final_val

def horizon_to_gamma(horizon):
    # discount factor whose effective horizon 1/(1 - gamma) equals `horizon`
    return 1.0 - 1.0/horizon
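
# Worked example (illustrative): horizon_to_gamma(100) == 0.99, and halfway
# through training linear_schedule(1.0, 0.0, 500, 1000) returns 0.5.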
    
class AdaptiveDiscountTuner:
    # Tracks an exponential moving average of the discount factor implied by
    # the observed episode horizon.
    def __init__(self, gamma, momentum=0.98, eplenmult=1):
        self.gamma = gamma
        self.momentum = momentum
        self.eplenmult = eplenmult

    def update(self, horizon):
        if horizon > 0:
            htarg = horizon * self.eplenmult
            gtarg = horizon_to_gamma(htarg)
            # EMA step toward the target discount
            self.gamma = self.gamma * self.momentum + gtarg * (1 - self.momentum)
        return self.gamma
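
# Usage sketch (illustrative):
#   tuner = AdaptiveDiscountTuner(gamma=0.99, momentum=0.98)
#   gamma = tuner.update(mean_episode_length)   # once per training iteration
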
def flatten01(arr):
    # collapse the first two axes into one: (A, B, ...) -> (A*B, ...)
    return arr.reshape(-1, *arr.shape[2:])

def flatten012(arr):
    # collapse the first three axes into one: (A, B, C, ...) -> (A*B*C, ...)
    return arr.reshape(-1, *arr.shape[3:])

class RetuneSelector:
    # Replay buffer of observations and value targets over the last n_pi
    # policy iterations; update() reports when the buffer is full so a
    # retune phase can run on minibatches from make_minibatches().
    def __init__(self, nenvs, ob_space, ac_space, replay_shape, skips=0, n_pi=32, num_retunes=5, flat_buffer=False):
        self.skips = skips
        self.n_pi = n_pi
        self.nenvs = nenvs

        self.exp_replay = np.empty((*replay_shape, *ob_space.shape), dtype=np.uint8)
        self.vtarg_replay = np.empty(replay_shape, dtype=np.float32)

        self.num_retunes = num_retunes
        self.ac_space = ac_space
        self.ob_space = ob_space

        self.cooldown_counter = skips
        self.replay_index = 0
        self.flat_buffer = flat_buffer

    def update(self, obs_batch, vtarg_batch):
        if self.num_retunes == 0:
            return False

        # skip `skips` update calls between retune phases
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            return False

        self.exp_replay[self.replay_index] = obs_batch
        self.vtarg_replay[self.replay_index] = vtarg_batch

        # signal a retune once the buffer wraps around, i.e. all n_pi slots are full
        self.replay_index = (self.replay_index + 1) % self.n_pi
        return self.replay_index == 0

    def retune_done(self):
        self.cooldown_counter = self.skips
        self.num_retunes -= 1
        self.replay_index = 0

    def make_minibatches(self, presleep_pi, num_rollouts):
        if not self.flat_buffer:
            # sample whole (policy-iteration, env) segments at a time
            env_segs = list(itertools.product(range(self.n_pi), range(self.nenvs)))
            np.random.shuffle(env_segs)
            env_segs = np.array(env_segs)
            for idx in range(0, len(env_segs), num_rollouts):
                esinds = env_segs[idx:idx + num_rollouts]
                mbatch = [flatten01(arr[esinds[:, 0], :, esinds[:, 1]])
                          for arr in (self.exp_replay, self.vtarg_replay, presleep_pi)]
                yield mbatch
        else:
            # sample uniformly over all individual timesteps in the buffer
            nsteps = self.vtarg_replay.shape[1]
            buffsize = self.n_pi * nsteps * self.nenvs
            inds = np.arange(buffsize)
            np.random.shuffle(inds)
            batchsize = num_rollouts * nsteps
            for start in range(0, buffsize, batchsize):
                end = start + batchsize
                mbinds = inds[start:end]
                mbatch = [flatten012(arr)[mbinds]
                          for arr in (self.exp_replay, self.vtarg_replay, presleep_pi)]
                yield mbatch
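
# Usage sketch (illustrative; replay_shape is (n_pi, nsteps, nenvs)):
#   selector = RetuneSelector(nenvs=64, ob_space=ob_space, ac_space=ac_space,
#                             replay_shape=(32, 256, 64))
#   if selector.update(obs_batch, vtarg_batch):   # True once the buffer is full
#       for obs_mb, vtarg_mb, pi_mb in selector.make_minibatches(presleep_pi, 4):
#           ...  # one retune optimization step per minibatch
#       selector.retune_done()
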
class RewardNormalizer(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, gamma=0.99, cliprew=10.0, epsilon=1e-8):
        self.epsilon = epsilon
        self.gamma = gamma
        self.ret_rms = RunningMeanStd(shape=())
        self.cliprew = cliprew
        self.ret = 0.  # scalar; broadcasts to the batch size after the first pass

    def normalize(self, rews, news, resetrew):
        # discounted running return per environment
        self.ret = self.ret * self.gamma + rews
        self.ret_rms.update(self.ret)
        # scale rewards by the running std of the return, then clip
        rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        if resetrew:
            # news must hold True or False per env to act as a positional mask
            self.ret[np.array(news, dtype=bool)] = 0.
        return rews
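
# Usage sketch (illustrative; rews and news are per-env arrays of shape (nenvs,)):
#   normalizer = RewardNormalizer(gamma=0.999)
#   scaled = normalizer.normalize(rews, news, resetrew=True)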
    
class RunningMeanStd(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon  # small prior count avoids division by zero

    def update(self, x):
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    # Chan et al.'s parallel algorithm: merge the moments of two batches
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count
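
# Sanity-check sketch (illustrative): streaming updates match the moments of
# the concatenated data, up to the small epsilon prior in `count`:
#   rms = RunningMeanStd(shape=())
#   a, b = np.random.randn(100), np.random.randn(50) + 3.0
#   rms.update(a); rms.update(b)
#   # rms.mean ~= np.mean(np.concatenate([a, b]))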