diff --git a/evaluator/dataset.py b/evaluator/dataset.py index 796330fdf86bb5d4c07c92c5872d9958c00ba023..9090f93e0677065c07510ff5e5a6e068df0e6b89 100644 --- a/evaluator/dataset.py +++ b/evaluator/dataset.py @@ -117,7 +117,7 @@ class ZEWDPCBaseDataset(Dataset): Useful during the evaluation for comparison with the ground_truth. """ return self.labels_df[self.labels_column_names].to_numpy() - + def _get_filename(self, idx): row = self._get_row(idx) filename = row.filename.replace(".png", "").replace(".jpg", "") @@ -191,6 +191,7 @@ class ZEWDPCBaseDataset(Dataset): idx = idx.item() return idx + class ZEWDPCProtectedDataset(ZEWDPCBaseDataset): """ A protected Dataset access object which wraps over an `ZEWDPCBaseDataset` object @@ -249,22 +250,19 @@ class ZEWDPCProtectedDataset(ZEWDPCBaseDataset): return self.budget - len(self.purchases) - - if __name__ == "__main__": ########################################################################### ########################################################################### - ## + ## ## BaseDataset Access Examples ########################################################################### - ########################################################################### + ########################################################################### dataset = ZEWDPCBaseDataset( images_dir="./data/prepared/v0.1/dataset_debug", labels_path="./data/prepared/v0.1/dataset_debug/labels.csv", drop_labels=False, ) - print("Labels Dictionary :", dataset.labels_column_names) """ @@ -273,10 +271,10 @@ if __name__ == "__main__": """ for sample in tqdm.tqdm(dataset): """ - Each of the samples will have the following structure : - + Each of the samples will have the following structure : + { - 'idx': 0, + 'idx': 0, 'image': array([[[110, 128, 140], [110, 128, 139], [110, 128, 140], @@ -285,29 +283,28 @@ if __name__ == "__main__": [134, 154, 168], [137, 158, 173]]]), 'label': [0, 0, 1, 1] - } - - where : + } + + where : `idx` : contains the reference id for this image `image` : contains the image as an numpy array loaded by skimage.io.imread - `label` : contains the associated labels for this data point. - The values at each of the indices in the label represent the presence or absence - of the following features : + `label` : contains the associated labels for this data point. + The values at each of the indices in the label represent the presence or absence + of the following features : ['scratch_small', 'scratch_large', 'dent_small', 'dent_large'] - - If `drop_labels` is passed as True during the instantiation of the class, - then the `labels` key is not included in the sample. + + If `drop_labels` is passed as True during the instantiation of the class, + then the `labels` key is not included in the sample. """ print(sample) break - ########################################################################### ########################################################################### - ## + ## ## Protected Dataset Access Examples ########################################################################### - ########################################################################### + ########################################################################### p_dataset = ZEWDPCProtectedDataset( images_dir="./data/prepared/v0.1/dataset_debug", labels_path="./data/prepared/v0.1/dataset_debug/labels.csv", @@ -332,11 +329,11 @@ if __name__ == "__main__": # Labels, instead have to be "purchased" label = p_dataset.purchase_label(idx) print(label, p_dataset.check_available_budget()) - # When the budget for accessing the labels has been exhausted, the - # Protected Dataset will throw an OutOfBudetException. + # When the budget for accessing the labels has been exhausted, the + # Protected Dataset will throw an OutOfBudetException. if idx == 50: - # Example of transform applied to the images in the dataset + # Example of transform applied to the images in the dataset # midway. preprocess = transforms.Compose( [ @@ -348,4 +345,3 @@ if __name__ == "__main__": p_dataset.set_transform(preprocess) print(sample.keys()) input("Transofrm applied. Press any key....") - diff --git a/evaluator/evaluation_metrics.py b/evaluator/evaluation_metrics.py index b8b782966383b5a271a077a643bd427e3fe219f0..111d652a21e09055e000d87e403a14c08cfda486 100644 --- a/evaluator/evaluation_metrics.py +++ b/evaluator/evaluation_metrics.py @@ -1,5 +1,5 @@ import torch -import numpy as np +import numpy as np from sklearn.metrics import accuracy_score from sklearn.metrics import hamming_loss @@ -8,5 +8,5 @@ from sklearn.metrics import hamming_loss def exact_match_ratio(y_true, y_pred): if type(y_pred) == torch.Tensor: y_pred = y_pred.numpy() - - return np.all(y_pred == y_true, axis=1).mean() \ No newline at end of file + + return np.all(y_pred == y_true, axis=1).mean() diff --git a/evaluator/exceptions.py b/evaluator/exceptions.py index b54f5169fcd5076f6bb421827ee8a3da9ebdc00e..a104460f225e2e60a1ba4c826bc2bd2773729067 100644 --- a/evaluator/exceptions.py +++ b/evaluator/exceptions.py @@ -3,9 +3,10 @@ class OutOfBudetException(Exception): """Out of labelling budget""" + def __init__(self, available_budget): self.available_budget = available_budget self.message = "Already Exhausted Label Purchasing Budget of : {}".format( self.available_budget ) - super().__init__(self.message) \ No newline at end of file + super().__init__(self.message)