#!/usr/bin/env python

import os
import numpy as np
from tqdm.auto import tqdm
import torch
from torchvision import transforms as T

from evaluator.dataset import ZEWDPCBaseDataset, ZEWDPCProtectedDataset

from baseline_utils.model import ResnetPredictor
from baseline_utils.predict import predict_on_dataset
from baseline_utils.training import train_on_dataset
from baseline_utils.dataset import SimpleDataset

from purchase_strategies.random_purchase import random_purchase
from purchase_strategies.morefaults_purchase import purchase_data_with_more_faults
from purchase_strategies.purchase_uncertain import purchase_uncertain_images
from purchase_strategies.balance_labels import match_labels_to_target_dist

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html

class Hparams:
    NUM_CLASSES = 6
    USE_PRETRAINED = True
    FEATURE_EXTRACTING = False
    NUM_EPOCHS = 4
    NUM_EPOCHS_PRETRAIN = 6
    BATCH_SIZE = 64
    VALIDATION_PERCENTAGE = 0.1
    VALIDATION_INTERVAL = 2
    DEVICE = 'cuda'

class ZEWDPCBaseRun:
    """
    Template Submission Class for the ZEW Data Purchasing Challenge 2022.

    The submission template exposes the following hooks:
        - pre_training_phase
        - purchase_phase
        - prediction_phase
        - save_checkpoint
        - load_checkpoint

        Please refer to the inline documentation for further details.
        You may add any other member functions; however, you must not change
        the names of these hooks, otherwise your submissions will not be
        evaluated by the automated evaluators.
    """

    def __init__(self):
        # self._seed(42)

        self.evaluation_state = {}

        self.model = ResnetPredictor(use_pretrained=Hparams.USE_PRETRAINED, 
                                     feature_extracting=Hparams.FEATURE_EXTRACTING, 
                                     num_classes=Hparams.NUM_CLASSES)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)

    def _seed(self, seed):       
        self.seed = seed
        torch.manual_seed(seed)
        torch.use_deterministic_algorithms(True)
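        # Note: for fully reproducible runs one would typically also seed Python's
        # `random` module and NumPy (e.g. np.random.seed(seed)); this baseline
        # only seeds torch.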

    def pre_training_phase(
        self,
        training_dataset: ZEWDPCBaseDataset,
        compute_budget=10**10,
        register_progress=lambda x: False,
    ):
        """
        # Pre-training Phase
        -------------------------
        Pre-train your model on the available training dataset here.
        Hook for the Pre-Training Phase of the Competition, where you
        have access to a training_dataset, which is an instance of the
        `ZEWDPCBaseDataset` class (see `evaluator/dataset.py` for more details).

        You are allowed to pre-train on this data while you prepare
        for the `purchase_phase` of the competition.

        If you train some models, you can instantiate them as `self.model`,
        as long as you implement self-contained checkpointing in the
        `self.save_checkpoint` and `self.load_checkpoint` hooks, as the
        hooks for the different phases of the competition can be called
        in different executions of the BaseRun.

        The `compute_budget` argument holds a floating point number representing
        the time available (in seconds) for **BOTH** the `pre_training_phase` and
        the `purchase_phase`.
        Exceeding the time will lead to a TimeOut error.

        You have access to a `register_progress` function, to which you can
        pass a value in [0, 1] to report your progress on the training phase;
        it is displayed on the submission dashboard during the evaluation.
        If a value outside these bounds is provided, it will be clipped to
        this range.

        """
        print("\n================> Pre-Training Phase\n")

        criterion = torch.nn.BCEWithLogitsLoss()
        self.model = train_on_dataset(
                        self.model,
                        training_dataset,
                        Hparams.NUM_EPOCHS_PRETRAIN,
                        Hparams.BATCH_SIZE,
                        Hparams.VALIDATION_PERCENTAGE,
                        Hparams.VALIDATION_INTERVAL,
                        Hparams.DEVICE,
                        criterion=criterion,
                        optimizer=self.optimizer,
                        register_progress_fn=register_progress,  # [Optional, but recommended] Mark Progress
                    )

        print("Execution Complete of Training Phase.")

    def purchase_phase(
        self,
        unlabelled_dataset: ZEWDPCProtectedDataset,
        training_dataset: ZEWDPCBaseDataset,
        purchase_budget=1500,
        compute_budget=10**10,
        register_progress=lambda x: False,
    ):
        """
        # Purchase Phase
        -------------------------
        In this phase of the competition, you have access to
        the unlabelled_dataset (an instance of `ZEWDPCProtectedDataset`)
        and the training_dataset (an instance of `ZEWDPCBaseDataset`)
        (see `evaluator/dataset.py` for more details), a purchase budget, and a compute budget.

        You can iterate over both the datasets and access the images without restrictions.
        However, you can probe the labels of the unlabelled_dataset only until you
        run out of the label purchasing budget.

        The `compute_budget` argument holds a floating point number representing
        the time available (in seconds) for **BOTH** the `pre_training_phase` and
        the `purchase_phase`.
        Exceeding the time will lead to a TimeOut error.

        """
        print("\n================> Purchase Phase | Budget = {}\n".format(purchase_budget))

        register_progress(0.0)  # Register Progress

        ##### Sample a small amount and train further #####
        random_sample_budget = purchase_budget*2//10 # 20%
        images, labels = random_purchase(unlabelled_dataset, random_sample_budget)

        total_unlabelled_images = images
        total_unlabelled_labels = labels

        #######################################################################################################################
        ## Further train on the combined data
        total_unlabelled_dataset = SimpleDataset(total_unlabelled_images, total_unlabelled_labels)
        combined_dataset = torch.utils.data.ConcatDataset([training_dataset, total_unlabelled_dataset])

        criterion = torch.nn.BCEWithLogitsLoss()
        self.model = train_on_dataset(
                        self.model,
                        combined_dataset,
                        Hparams.NUM_EPOCHS,
                        Hparams.BATCH_SIZE,
                        Hparams.VALIDATION_PERCENTAGE,
                        Hparams.VALIDATION_INTERVAL,
                        Hparams.DEVICE,
                        criterion=criterion,
                        optimizer=self.optimizer,
                    )

        register_progress(len(unlabelled_dataset.purchases)/purchase_budget)

        # Predict on all images
        predicted_labels = predict_on_dataset(self.model, unlabelled_dataset,
                                              Hparams.BATCH_SIZE, Hparams.DEVICE)

        # Remove already purchased images from prediction list
        for label_idx in total_unlabelled_labels:
            predicted_labels.pop(label_idx)

        #######################################################################################################################
        ##### Purchase images with more faults #####
        morefaults_budget = purchase_budget*3//10 # 30%
        images, labels = purchase_data_with_more_faults(unlabelled_dataset,
                                                        predicted_labels,
                                                        morefaults_budget)

        total_unlabelled_images.update(images)
        total_unlabelled_labels.update(labels)

        # Remove already purchased images from prediction list
        for label_idx in labels:
            predicted_labels.pop(label_idx)

        register_progress(len(unlabelled_dataset.purchases)/purchase_budget)

        ######################################################################################################################
        ##### Purchase uncertain images #####
        uncertain_budget = purchase_budget*3//10 # 30%
        images, labels = purchase_uncertain_images(unlabelled_dataset,
                                                   predicted_labels,
                                                   uncertain_budget)
        
        total_unlabelled_images.update(images)
        total_unlabelled_labels.update(labels)

        # Remove already purchased images from prediction list
        for label_idx in labels:
            predicted_labels.pop(label_idx)
        
        register_progress(len(unlabelled_dataset.purchases)/purchase_budget)

        ######################################################################################################################
        # Further train on the combined data
        total_unlabelled_dataset = SimpleDataset(total_unlabelled_images, total_unlabelled_labels)
        combined_dataset = torch.utils.data.ConcatDataset([training_dataset, total_unlabelled_dataset])

        criterion = torch.nn.BCEWithLogitsLoss()
        self.model = train_on_dataset(
                        self.model,
                        combined_dataset,
                        Hparams.NUM_EPOCHS,
                        Hparams.BATCH_SIZE,
                        Hparams.VALIDATION_PERCENTAGE,
                        Hparams.VALIDATION_INTERVAL,
                        Hparams.DEVICE,
                        criterion=criterion,
                        optimizer=self.optimizer,
                    )

        # Predict on all images
        predicted_labels = predict_on_dataset(self.model, unlabelled_dataset,
                                              Hparams.BATCH_SIZE, Hparams.DEVICE)

        # Remove already purchased images from prediction list
        for label_idx in total_unlabelled_labels:
            predicted_labels.pop(label_idx)

        #######################################################################################################################
        ##### Balance the dataset labels with the remaining purchase budget #####
        rebalance_budget = purchase_budget - len(unlabelled_dataset.purchases)
        target_distribution = [0.166, 0.166, 0.166, 0.166, 0.166, 0.17]
        assert abs(sum(target_distribution) - 1) < 1e-6  # avoid exact float comparison
        images, labels = match_labels_to_target_dist(unlabelled_dataset,
                                                     total_unlabelled_labels,
                                                     predicted_labels,
                                                     target_distribution,
                                                     rebalance_budget)
        
        total_unlabelled_images.update(images)
        total_unlabelled_labels.update(labels)

        register_progress(len(unlabelled_dataset.purchases)/purchase_budget)

        print("Execution Complete of Purchase Phase.")

        # Participants DO NOT need to return anything from the purchase phase.
        # The indices passed to unlabelled_dataset.purchase_label(idx) are registered by the evaluator
        # and are used for the respective purchased labels.

        return total_unlabelled_labels
    
    def prediction_phase(
        self,
        test_dataset: ZEWDPCBaseDataset,
        register_progress=lambda x: False,
    ):
        """
        # Prediction Phase
        -------------------------
        In this phase of the competition, you have access to the test dataset, and you
        are supposed to make predictions using your trained models.

        Returns:
            np.ndarray of shape (n, 6)
                where n is the number of samples in the test set
                and 6 refers to the 6 labels to be predicted for each sample
                for the multi-label classification problem.

        PARTICIPANT_TODO: Add your code here
        """
        print(
            "\n================> Prediction Phase : - on {} images\n".format(
                len(test_dataset)
            )
        )

        batch_size = 16

        self.model.eval()

        transform = T.Compose(
            [
                T.ToTensor(),
                *self.model.required_transforms,
            ]
        )
        test_dataset.set_transform(transform)
        dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False
        )
        outputs = []

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)

        for data in tqdm(dataloader, total=len(dataloader)):
            with torch.no_grad():
                image = data["image"].to(device)
                output = self.model(image)
                # Multi-label task: apply a sigmoid over the 6 logits
                # (argmax would collapse them into a single class index).
                output_with_activation = torch.sigmoid(output).cpu().numpy()
                outputs.append(output_with_activation)


        outputs = np.concatenate(outputs, axis=0)
        outputs = (outputs > 0.5).astype(int)  # threshold sigmoid probabilities into binary labels

        register_progress(1.0)

        print("Execution of Prediction Phase complete.")
        return outputs


    def save_checkpoint(self, checkpoint_folder):
        """
        Self-contained checkpoint code to be included here,
        which can capture the state of your run (including any trained models, etc)
        at the provided folder path.

        This is critical to implement, as the execution of the different phases can
        happen using different instances of the BaseRun. See below for examples.

        PARTICIPANT_TODO: Add your code here
        """
        checkpoint_path = os.path.join(checkpoint_folder, "model.pth")
        torch.save({
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }, checkpoint_path)

    def load_checkpoint(self, checkpoint_folder):
        """
        Self-contained checkpoint code to be included here,
        which can load the state of your run (including any trained models, etc)
        from a provided checkpoint_folder path 
        (previously saved using `self.save_checkpoint`)

        This is critical to implement, as the execution of the different phases can
        happen using different instances of the BaseRun. See below for examples.

        PARTICIPANT_TODO: Add your code here
        """
        checkpoint_path = os.path.join(checkpoint_folder, "model.pth")
        # Match the architecture used in __init__ so the checkpoint loads cleanly.
        self.model = ResnetPredictor(use_pretrained=False,
                                     feature_extracting=Hparams.FEATURE_EXTRACTING,
                                     num_classes=Hparams.NUM_CLASSES)
        load_dict = torch.load(checkpoint_path)
        self.model.load_state_dict(load_dict["model"])
        # Rebuild the optimizer around the new model's parameters before restoring its state.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
        self.optimizer.load_state_dict(load_dict["optimizer"])



if __name__ == "__main__":
    ####################################################################################
    ## You need to implement the `ZEWDPCBaseRun` class in this file for this challenge.
    ## Code for running all the phases locally is written in `main.py` for illustration
    ## purposes.
    ##
    ## Check out the inline documentation of `ZEWDPCBaseRun` for more details.
    ####################################################################################
    import local_evaluation