Commit 67d3c8b5 authored by mohanty

Add working example of pytorch training

parent 7cc42435
#!/usr/bin/env python
import numpy as np
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from evaluator.dataset import ZEWDPCBaseDataset, ZEWDPCProtectedDataset
class RuntimeDataset(torch.utils.data.Dataset):
"""
A Dataset class to hold processed images (and labels) during runtime
"""
def __init__(self, images, labels=False):
self.images = images
self.labels = labels
self.transform = False
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
_sample = {}
_sample["image"] = self.images[idx]
# Apply Transform if present
if self.transform:
_sample["image"] = self.transform(_sample["image"])
# Include Label if present
if self.labels:
_sample["label"] = self.labels[idx]
return _sample
def set_transform(self, transform):
self.transform = transform
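# Illustrative usage of RuntimeDataset (a hedged sketch with made-up values;
# the pattern mirrors how purchase_phase builds its post-purchase dataloader
# further below):
#
#   ds = RuntimeDataset(images=[img_a, img_b], labels=[[0, 1, 0, 0], [1, 1, 0, 0]])
#   ds.set_transform(transforms.ToTensor())
#   loader = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=True)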
RESIZE_DIM = 32
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(RESIZE_DIM**2, 4)
def forward(self, x):
x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
x = self.fc1(x)
x = torch.sigmoid(x)
return x
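# NOTE: Net ends in a sigmoid so that each of the 4 outputs can be read as an
# independent per-label probability; the training code below pairs this with
# nn.BCELoss. An equivalent, and often more numerically stable, alternative is
# to drop the sigmoid here and train with nn.BCEWithLogitsLoss instead.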
class ZEWDPCBaseModelRun:
"""
Template Submission Class for the ZEW Data Purchasing Challenge 2022.
    The submission template exposes the following hooks:
- pre_training_phase
- purchase_phase
- prediction_phase
- save_checkpoint
- load_checkpoint
Please refer to the inline documentation for further details.
    You are allowed to add other member functions; however, you must not
    change the names of these hooks, or your submissions will not be
    evaluated by the automated evaluators.
"""
def __init__(self):
self.evaluation_state = {}
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self._init_model()
self._setup_transforms()
    def _init_model(self):
        # Move the model to the selected device so that it matches the device
        # of the input batches used in the training and prediction loops.
        self.model = Net().to(self.device)
def _setup_transforms(self):
# NOTE: This is a toy example for demonstration purposes.
self.image_preprocess = transforms.Compose(
[
transforms.ToTensor(), # Convert to Tensor
transforms.Resize((RESIZE_DIM, RESIZE_DIM)), # Resize to 32x32px
transforms.Grayscale(num_output_channels=1), # Convert to Grayscale
transforms.Normalize(mean=[0.5], std=[0.5]), # Normalize
]
)
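        # NOTE: ToTensor() expects a PIL Image or numpy ndarray, and the
        # tensor-based Resize/Grayscale steps above need a reasonably recent
        # torchvision (tensor transforms landed around 0.8). The 0.5 mean/std
        # are placeholder values, not statistics computed from the challenge data.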
def _train_one_epoch(self, model, dataloader, optimizer, criterion, epoch_number=0):
"""
Helper function to train a model on a specific dataset
for a single epoch.
"""
model.train()
training_running_loss = 0.0
for i, data_batch in tqdm(
enumerate(dataloader),
total=len(dataloader),
desc="Epoch #{}".format(epoch_number),
leave=False,
):
data = data_batch["image"].to(self.device)
target = torch.stack(data_batch["label"]).float().to(self.device)
optimizer.zero_grad()
            outputs = model(data)  # Net.forward() already applies a sigmoid to the outputs
loss = criterion(outputs, target)
training_running_loss += loss.item()
# backpropagation
loss.backward()
# update optimizer parameters
optimizer.step()
train_loss = training_running_loss / len(dataloader)
return train_loss
def _train_N_epochs(
self,
model,
dataloader,
optimizer,
criterion,
max_epochs=5,
register_progress=lambda x: False,
description=False
):
loss = []
for epoch_idx in tqdm(range(max_epochs), desc=description):
epoch_loss = self._train_one_epoch(
model,
dataloader,
optimizer,
criterion,
epoch_number=epoch_idx + 1,
)
loss.append(epoch_loss)
# Compute and Register Progress
progress = float(epoch_idx + 1) / max_epochs
            register_progress(progress)
        return loss
def pre_training_phase(
self, training_dataset: ZEWDPCBaseDataset, register_progress=lambda x: False
):
"""
# Pre-training Phase
-------------------------
        Pretrain your model on the available labelled dataset here.
Hook for the Pre-Training Phase of the Competition, where you
have access to a training_dataset, which is an instance of the
`ZEWDPCBaseDataset` class (see dataset.py for more details).
        You are allowed to pre-train on this data while you prepare
        for the purchase phase of the competition.
If you train some models, you can instantiate them as `self.model`,
as long as you implement self-contained checkpointing in the
        `self.save_checkpoint` and `self.load_checkpoint` hooks, as the
        hooks for the different phases of the competition can be called
        in different executions of the BaseRun.
"""
print("\n================> Pre-Training Phase\n")
lr = 0.0001
MAX_EPOCHS = 5
batch_size = 4
optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
criterion = nn.BCELoss()
training_dataset.set_transform(self.image_preprocess)
training_dataloader = torch.utils.data.DataLoader(
training_dataset, batch_size=batch_size, shuffle=True, num_workers=0
)
self._train_N_epochs(
self.model,
training_dataloader,
optimizer,
criterion,
            max_epochs=MAX_EPOCHS,
register_progress=register_progress,
description="Pre Training Phase"
)
print("Execution Complete of Training Phase.")
def purchase_phase(
self,
unlabelled_dataset: ZEWDPCProtectedDataset,
training_dataset: ZEWDPCBaseDataset,
budget=1000,
register_progress=lambda x: False,
):
"""
# Purchase Phase
-------------------------
In this phase of the competition, you have access to
the unlabelled_dataset (an instance of `ZEWDPCProtectedDataset`)
and the training_dataset (an instance of `ZEWDPCBaseDataset`)
        (see dataset.py for more details), and a purchase budget.
        You can iterate over both datasets and access the images without restriction.
However, you can probe the labels of the unlabelled_dataset only until you
run out of the label purchasing budget.
"""
print("\n================> Purchase Phase | Budget = {}\n".format(budget))
register_progress(0.0) # Register Progress
purchased_images = []
purchased_labels = []
############################################################
############################################################
#
# Purchase Labels
############################################################
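        # NOTE: This toy example simply walks the unlabelled dataset in order and
        # purchases labels until the budget runs out. A competitive submission
        # would instead decide *which* indices to probe (for example, based on
        # the current model's prediction uncertainty) before calling
        # unlabelled_dataset.purchase_label(idx) as below.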
for sample in tqdm(unlabelled_dataset, desc="Label Purchase"):
idx = sample["idx"]
image = sample["image"]
# Budgeting & Purchasing Labels
if budget > 0:
label = unlabelled_dataset.purchase_label(idx)
purchased_images.append(image)
purchased_labels.append(label)
budget -= 1
############################################################
############################################################
#
# Train on Purchased Labels
############################################################
current_progress = 0.5
register_progress(0.5) # Register Progress - Mark as 50% complete
def register_partial_progress(sub_phase_progress=0.0):
"""
            Helper function to translate the progress of a sub-phase into the
            progress of the overall phase.
            sub_phase_progress is a progress indicator in [0, 1] representing
            the progress within the sub-phase.
"""
overall_phase_progress = (
current_progress + (1 - current_progress) * sub_phase_progress
)
register_progress(overall_phase_progress)
# Instantiate Params
lr = 0.0001
MAX_EPOCHS = 5
batch_size = 4
optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
criterion = nn.BCELoss()
# Instantiate Dataset & Dataloaders
purchased_dataset = RuntimeDataset(
images=purchased_images, labels=purchased_labels
)
purchased_dataset.set_transform(self.image_preprocess)
dataloader = torch.utils.data.DataLoader(
purchased_dataset, batch_size=batch_size, shuffle=True, num_workers=0
)
# Train for N epochs
self._train_N_epochs(
self.model,
dataloader,
optimizer,
criterion,
max_epochs=MAX_EPOCHS,
register_progress=register_partial_progress,
description="Post Purchase Training"
)
print("Execution Complete of Purchase Phase.")
def prediction_phase(
self,
test_dataset: ZEWDPCBaseDataset,
register_progress=lambda x: False,
):
"""
# Prediction Phase
-------------------------
In this phase of the competition, you have access to a test set, and you
are supposed to make predictions using your trained models.
Returns:
np.ndarray of shape (n, 4)
where n is the number of samples in the test set
and 4 refers to the 4 labels to be predicted for each sample
for the multi-label classification problem.
"""
print(
"\n================> Prediction Phase : - on {} images\n".format(
len(test_dataset)
)
)
test_dataset.set_transform(self.image_preprocess)
# Setup Test Runtime Params
inference_batch_size = 4
# Setup Inference Dataloader
inference_dataloader = torch.utils.data.DataLoader(
test_dataset,
batch_size=inference_batch_size,
shuffle=False, # IMP: Shuffle should be False here to ensure predictions are aligned with the test samples
num_workers=0,
)
        self.model.eval()  # switch the model to eval mode for inference
        predictions = []
for _idx, data_batch in tqdm(
enumerate(inference_dataloader),
total=len(inference_dataloader),
desc="Prediction Phase",
):
data = data_batch["image"].to(self.device)
            with torch.no_grad():  # gradients are not needed at inference time
                _prediction = self.model(data)
threshold = 0.5
_prediction = (_prediction > threshold) * 1.0
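            # NOTE: 0.5 is an arbitrary cut-off on the sigmoid outputs;
            # per-label thresholds tuned on held-out data may work better.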
predictions += _prediction.tolist()
# Mark Progress
inference_progress = (_idx + 1) / len(inference_dataloader)
register_progress(inference_progress)
predictions = np.array(predictions)
print("Execution Complete of Purchase Phase.")
return predictions
def save_checkpoint(self, checkpoint_path):
"""
Self-contained checkpoint code to be included here,
which can capture the state of your run (including any trained models, etc)
at the provided path.
This is critical to implement, as the execution of the different phases can
happen using different instances of the BaseRun. See below for examples.
"""
## Save Model
torch.save(self.model.state_dict(), checkpoint_path)
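        # A richer checkpoint could also capture optimizer state or purchase
        # bookkeeping by saving a dict instead of a bare state_dict, e.g.
        # (hedged sketch, not required by the template):
        #   torch.save({"model": self.model.state_dict()}, checkpoint_path)
        # and then reading it back with torch.load(...)["model"] in load_checkpoint.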
def load_checkpoint(self, checkpoint_path):
"""
Self-contained checkpoint code to be included here,
which can load the state of your run (including any trained models, etc)
from a provided path (previously saved using `self.save_checkpoint`)
This is critical to implement, as the execution of the different phases can
happen using different instances of the BaseRun. See below for examples.
"""
## Load Model
# Assumes that self._init_model() has already been called before
        self.model.load_state_dict(
            torch.load(checkpoint_path, map_location=self.device)
        )
if __name__ == "__main__":
####################################################################################
    ## You need to implement the `ZEWDPCBaseModelRun` class in this file for this challenge.
## Code for running all the phases locally is written in `main.py` for illustration
## purposes.
##
    ## Check out the inline documentation of `ZEWDPCBaseModelRun` for more details.
####################################################################################
####################################################################################
####################################################################################
##
## Dataset Initialization
####################################################################################
DATASET_SHUFFLE_SEED = 1022022
# Instantiate Training Dataset
training_dataset = ZEWDPCBaseDataset(
images_dir="./data/debug/images",
labels_path="./data/debug/labels.csv",
shuffle_seed=DATASET_SHUFFLE_SEED,
)
# Instantiate Unlabelled Dataset
unlabelled_dataset = ZEWDPCProtectedDataset(
images_dir="./data/debug/images",
labels_path="./data/debug/labels.csv",
budget=100, # Configurable Parameter
shuffle_seed=DATASET_SHUFFLE_SEED,
)
# Instantiate Validation Dataset
val_dataset = ZEWDPCBaseDataset(
images_dir="./data/debug/images",
labels_path="./data/debug/labels.csv",
drop_labels=True,
shuffle_seed=DATASET_SHUFFLE_SEED,
)
    # A second instantiation of the validation dataset with the labels present
    # - helpful later when computing the scores.
val_dataset_gt = ZEWDPCBaseDataset(
images_dir="./data/debug/images",
labels_path="./data/debug/labels.csv",
drop_labels=False,
shuffle_seed=DATASET_SHUFFLE_SEED,
)
####################################################################################
####################################################################################
##
## Phase 1 : Pre-Training Phase
####################################################################################
run = ZEWDPCBaseModelRun()
run.pre_training_phase(training_dataset)
run.save_checkpoint("/tmp/pretrainig_phase_checkpoint.pth")
# NOTE:It is critical that the checkpointing works in a self-contained way
# As, the evaluators might choose to run the different phases separately.
del run
####################################################################################
####################################################################################
##
## Phase 2 : Purchase Phase
####################################################################################
run = ZEWDPCBaseModelRun()
run.load_checkpoint("/tmp/pretrainig_phase_checkpoint.pth")
run.purchase_phase(unlabelled_dataset, training_dataset, budget=3000)
run.save_checkpoint("/tmp/purchase_phase_checkpoint.pth")
del run
####################################################################################
####################################################################################
##
## Phase 3 : Prediction Phase
####################################################################################
run = ZEWDPCBaseModelRun()
run.load_checkpoint("/tmp/purchase_phase_checkpoint.pth")
predictions = run.prediction_phase(val_dataset)
assert type(predictions) == np.ndarray
assert predictions.shape == (len(val_dataset), 4)
####################################################################################
####################################################################################
##
## Phase 4 : Evaluation Phase
####################################################################################
from evaluator.evaluation_metrics import (
accuracy_score,
hamming_loss,
exact_match_ratio,
)
y_true = val_dataset_gt._get_all_labels()
y_pred = predictions
    accuracy = accuracy_score(y_true, y_pred)
    hamming_loss_score = hamming_loss(y_true, y_pred)
    exact_match_ratio_score = exact_match_ratio(y_true, y_pred)
    print("Accuracy Score : ", accuracy)
print("Hamming Loss : ", hamming_loss_score)
print("Exact Match Ratio : ", exact_match_ratio_score)