utils.py 6.15 KB
Newer Older
Dipam Chakraborty's avatar
Dipam Chakraborty committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import numpy as np
from ray.rllib.utils import try_import_torch
from collections import deque
from skimage.util import view_as_windows

# Bind torch and its nn namespace through RLlib's helper rather than a direct import.
# NOTE(review): what the helper returns when torch is not installed is not visible
# here -- confirm before relying on `torch`/`nn` being real modules.
torch, nn = try_import_torch()

def neglogp_actions(pi_logits, actions):
    """
    Negative log-probability of each action under the given policy logits.

    args:
    pi_logits: unnormalized action scores, one row per sample
    actions: index of the action taken for each sample

    returns:
    one value per sample (no reduction over the batch)
    """
    per_sample_nll = nn.functional.cross_entropy(
        input=pi_logits,
        target=actions,
        reduction='none',
    )
    return per_sample_nll

def sample_actions(logits, device):
    """
    Draw one action index per row of logits by adding noise and taking the argmax.

    args:
    logits: unnormalized action scores with actions along dim 1
    device: device the uniform noise is moved to before it is combined with logits

    returns:
    index of the selected action for every row
    """
    uniform = torch.rand(logits.shape, dtype=logits.dtype).to(device)
    # -log(-log(u)) turns uniform noise into Gumbel noise; the argmax of
    # logits + Gumbel noise is a sample from the softmax distribution.
    gumbel = -torch.log(-torch.log(uniform))
    return torch.argmax(logits + gumbel, dim=1)

def pi_entropy(logits):
    """
    Entropy of the softmax distribution defined by each row of logits.

    args:
    logits: unnormalized action scores with actions along dim 1

    returns:
    one entropy value per row
    """
    # Subtract the row maximum first so the exponentials cannot overflow.
    shifted = logits - torch.max(logits, dim=1, keepdim=True).values
    exp_shifted = torch.exp(shifted)
    partition = torch.sum(exp_shifted, dim=1, keepdim=True)
    probs = exp_shifted / partition
    # log(partition) - shifted is -log(prob) for every action.
    neg_log_probs = torch.log(partition) - shifted
    return torch.sum(probs * neg_log_probs, dim=1)

def roll(arr):
    """
    Swap the first two axes of arr and merge them into a single leading axis.

    An input of shape (A, B, ...) yields shape (A * B, ...), ordered with the
    original second axis varying slowest.
    """
    first, second, *rest = arr.shape
    swapped = arr.swapaxes(0, 1)
    return swapped.reshape(first * second, *rest)

def unroll(arr, targetshape):
    """
    Split the leading axis of arr into targetshape, then swap the first two axes.

    args:
    arr: array whose first axis has prod(targetshape) entries
    targetshape: sizes the leading axis is split into, given before the swap
    """
    trailing = arr.shape[1:]
    expanded = arr.reshape(*targetshape, *trailing)
    return expanded.swapaxes(0, 1)

def safe_mean(xs):
    """Return the mean of xs, or -inf when xs has no elements."""
    if len(xs) == 0:
        return -np.inf
    return np.mean(xs)


def pad_and_random_crop(imgs, out, pad):
    """
    Vectorized zero-pad followed by an independent random crop per image.

    args:
    imgs: shape (B,H,W,C); H and W may differ
    out: side length of the square crop (e.g. 64)
    pad: number of zero pixels added on each side of H and W before cropping

    returns:
    array of shape (B,out,out,C) with the same dtype as imgs

    raises:
    ValueError if out is larger than the padded height or width
    """
    imgs = np.pad(imgs, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
    n, padded_h, padded_w = imgs.shape[:3]
    max_row = padded_h - out
    max_col = padded_w - out
    if max_row < 0 or max_col < 0:
        raise ValueError(
            "crop size %d exceeds padded image size (%d, %d)" % (out, padded_h, padded_w)
        )
    # randint excludes its upper bound, so +1 keeps the last valid offset
    # reachable and makes the "crop equals padded image" case (max == 0) legal.
    row0 = np.random.randint(0, max_row + 1, n)
    col0 = np.random.randint(0, max_col + 1, n)
    # Build per-image row/column index grids and gather every crop in one
    # fancy-indexing pass: (n,1,1), (n,out,1) and (n,1,out) broadcast to
    # (n,out,out); the channel axis is carried along untouched.
    span = np.arange(out)
    rows = (row0[:, None] + span)[:, :, None]
    cols = (col0[:, None] + span)[:, None, :]
    batch = np.arange(n)[:, None, None]
    return imgs[batch, rows, cols]

def random_cutout_color(imgs, min_cut, max_cut):
    """
    Return a copy of imgs with one randomly colored box painted into each image.

    args:
    imgs: shape (B,H,W,C) with an integer dtype (the color draw uses imgs.dtype)
    min_cut: lowest box offset that can be drawn (inclusive)
    max_cut: upper limit of the box offset (exclusive)

    NOTE(review): the box spans [offset, 2 * offset) on both axes, so its side
    length equals its offset, and color values are drawn from [0, 255) -- confirm
    both are intended.
    """
    batch, _, _, channels = imgs.shape
    x_offsets = np.random.randint(min_cut, max_cut, batch)
    y_offsets = np.random.randint(min_cut, max_cut, batch)
    colors = np.random.randint(0, 255, size=(batch, channels), dtype=imgs.dtype)

    painted = imgs.copy()
    for idx, (left, top, color) in enumerate(zip(x_offsets, y_offsets, colors)):
        # One flat color per image, broadcast across the whole box.
        painted[idx, top:2 * top, left:2 * left, :] = color
    return painted

def linear_schedule(initial_val, final_val, current_steps, total_steps):
    """
    Linearly interpolate from initial_val (at step 0) to final_val (at total_steps).

    The result is not clamped: current_steps outside [0, total_steps] extrapolates.
    """
    remaining = 1.0 - current_steps / total_steps
    span = initial_val - final_val
    return span * remaining + final_val

def horizon_to_gamma(horizon):
    """Convert a horizon length into the discount factor 1 - 1/horizon."""
    per_step_decay = 1.0 / horizon
    return 1.0 - per_step_decay
    
class AdaptiveDiscountTuner:
    """
    Tracks a discount factor that drifts toward the value implied by a horizon.

    args:
    gamma: starting discount factor
    momentum: weight kept on the current gamma at every update
    eplenmult: multiplier applied to the horizon before it is converted to gamma
    """

    def __init__(self, gamma, momentum=0.98, eplenmult=1):
        self.gamma, self.momentum, self.eplenmult = gamma, momentum, eplenmult

    def update(self, horizon):
        """
        Blend gamma toward the target for this horizon and return the new gamma.

        Horizons that are not strictly positive leave gamma unchanged.
        """
        if not horizon > 0:
            return self.gamma

        target_gamma = horizon_to_gamma(horizon * self.eplenmult)
        keep = self.momentum
        self.gamma = self.gamma * keep + target_gamma * (1 - keep)
        return self.gamma

class RetuneSelector:
    """
    Decides when a replay buffer has collected enough batches for a retune.

    Each cycle first skips a fixed number of batches (the cooldown), then copies
    incoming observation batches into the replay buffer until it has been filled
    once. retune_done() starts the next cycle and consumes one retune.

    args:
    nbatch: number of rows in every observation batch
    ob_space: stored as-is; not used by this class
    ac_space: stored as-is; not used by this class
    skips: rows to skip before collecting, rounded up to a multiple of nbatch
    replay_size: rows in the replay buffer, rounded up to a multiple of nbatch
    num_retunes: how many retunes remain
    """

    def __init__(self, nbatch, ob_space, ac_space, skips = 800_000, replay_size = 200_000, num_retunes = 5):
        self.batch_size = nbatch
        # (-x) % nbatch is the distance from x up to the next multiple of nbatch.
        self.skips = skips + (-skips) % nbatch
        self.replay_size = replay_size + (-replay_size) % nbatch
        self.batches_in_replay = self.replay_size // nbatch

        self.ob_space = ob_space
        self.ac_space = ac_space
        self.num_retunes = num_retunes

        self._reset_cycle()

    def _reset_cycle(self):
        """Restart the cooldown and begin filling the replay buffer from slot 0."""
        self.cooldown_counter = self.skips // self.batch_size
        self.replay_index = 0
        self.buffer_full = False

    def update(self, obs_batch, exp_replay):
        """
        Consume one observation batch; return True once the buffer has been filled.

        While no retunes remain or the cooldown is running, the batch is dropped
        and False is returned. Otherwise obs_batch is written into the next slot
        of exp_replay (modified in place).
        """
        if self.num_retunes == 0:
            return False

        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            return False

        first_row = self.replay_index * self.batch_size
        exp_replay[first_row:first_row + self.batch_size] = obs_batch

        self.replay_index = (self.replay_index + 1) % self.batches_in_replay
        if self.replay_index == 0:
            # Wrapping back to slot 0 means every slot has been written once.
            self.buffer_full = True

        return self.buffer_full

    def retune_done(self):
        """Mark one retune as finished and start a new cooldown/collection cycle."""
        self._reset_cycle()
        self.num_retunes -= 1

    def set_num_retunes(self, nr):
        """Overwrite the number of remaining retunes."""
        self.num_retunes = nr

class RewardNormalizer(object):
    """
    Scales rewards by the running standard deviation of a discounted return.

    args:
    gamma: discount applied to the accumulated return at every step
    cliprew: normalized rewards are clipped to [-cliprew, cliprew]
    epsilon: added to the variance before the square root
    """
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, gamma=0.99, cliprew=10.0, epsilon=1e-8):
        self.gamma = gamma
        self.cliprew = cliprew
        self.epsilon = epsilon
        self.ret_rms = RunningMeanStd(shape=())
        self.ret = 0.  # scalar until the first normalize() call gives it rews' shape

    def normalize(self, rews, news, reset_returns=True):
        """
        Update the running statistics with rews and return the normalized rewards.

        args:
        rews: array of rewards, one entry per environment
        news: per-environment flags; entries that are true zero the accumulated
            return afterwards (presumably episode boundaries -- confirm with caller)
        reset_returns: when False the accumulated returns are never zeroed
        """
        self.ret = self.ret * self.gamma + rews
        self.ret_rms.update(self.ret)

        scale = np.sqrt(self.ret_rms.var + self.epsilon)
        limit = self.cliprew
        normalized = np.clip(rews / scale, -limit, limit)

        if reset_returns:
            # The mask must be boolean (True or False) so it selects by position
            # rather than being read as integer indices.
            flagged = np.array(news, dtype=bool)
            self.ret[flagged] = 0.
        return normalized
    
class RunningMeanStd(object):
    """
    Running mean and variance accumulated one batch at a time.

    args:
    epsilon: initial sample count, so the first merge never divides by zero
    shape: shape of a single sample; mean and var are float64 arrays of this shape
    """
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        self.count = epsilon
        self.mean = np.zeros(shape, dtype='float64')
        self.var = np.ones(shape, dtype='float64')

    def update(self, x):
        """Fold a batch x (samples along axis 0) into the running statistics."""
        self.update_from_moments(
            np.mean(x, axis=0),
            np.var(x, axis=0),
            x.shape[0],
        )

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Fold precomputed batch moments into the running statistics."""
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
        self.mean, self.var, self.count = merged

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """
    Merge the moments of a new batch into existing running moments.

    args:
    mean, var, count: moments and sample count accumulated so far
    batch_mean, batch_var, batch_count: moments and sample count of the new batch

    returns:
    (new_mean, new_var, new_count) describing both sets of samples together
    """
    total = count + batch_count
    shift = batch_mean - mean

    merged_mean = mean + shift * batch_count / total
    # Sum of squared deviations of the union: both parts' own sums plus a
    # correction for the distance between the two means.
    sum_sq = var * count + batch_var * batch_count + np.square(shift) * count * batch_count / total

    return merged_mean, sum_sq / total, total