Skip to content
Snippets Groups Projects
Commit 3ce894b2 authored by Jyotish P's avatar Jyotish P
Browse files

Format evaluator files using black

parent bd08e564
No related branches found
No related tags found
1 merge request!1Restructure code
...@@ -117,7 +117,7 @@ class ZEWDPCBaseDataset(Dataset): ...@@ -117,7 +117,7 @@ class ZEWDPCBaseDataset(Dataset):
Useful during the evaluation for comparison with the ground_truth. Useful during the evaluation for comparison with the ground_truth.
""" """
return self.labels_df[self.labels_column_names].to_numpy() return self.labels_df[self.labels_column_names].to_numpy()
def _get_filename(self, idx): def _get_filename(self, idx):
row = self._get_row(idx) row = self._get_row(idx)
filename = row.filename.replace(".png", "").replace(".jpg", "") filename = row.filename.replace(".png", "").replace(".jpg", "")
...@@ -191,6 +191,7 @@ class ZEWDPCBaseDataset(Dataset): ...@@ -191,6 +191,7 @@ class ZEWDPCBaseDataset(Dataset):
idx = idx.item() idx = idx.item()
return idx return idx
class ZEWDPCProtectedDataset(ZEWDPCBaseDataset): class ZEWDPCProtectedDataset(ZEWDPCBaseDataset):
""" """
A protected Dataset access object which wraps over an `ZEWDPCBaseDataset` object A protected Dataset access object which wraps over an `ZEWDPCBaseDataset` object
...@@ -249,22 +250,19 @@ class ZEWDPCProtectedDataset(ZEWDPCBaseDataset): ...@@ -249,22 +250,19 @@ class ZEWDPCProtectedDataset(ZEWDPCBaseDataset):
return self.budget - len(self.purchases) return self.budget - len(self.purchases)
if __name__ == "__main__": if __name__ == "__main__":
########################################################################### ###########################################################################
########################################################################### ###########################################################################
## ##
## BaseDataset Access Examples ## BaseDataset Access Examples
########################################################################### ###########################################################################
########################################################################### ###########################################################################
dataset = ZEWDPCBaseDataset( dataset = ZEWDPCBaseDataset(
images_dir="./data/prepared/v0.1/dataset_debug", images_dir="./data/prepared/v0.1/dataset_debug",
labels_path="./data/prepared/v0.1/dataset_debug/labels.csv", labels_path="./data/prepared/v0.1/dataset_debug/labels.csv",
drop_labels=False, drop_labels=False,
) )
print("Labels Dictionary :", dataset.labels_column_names) print("Labels Dictionary :", dataset.labels_column_names)
""" """
...@@ -273,10 +271,10 @@ if __name__ == "__main__": ...@@ -273,10 +271,10 @@ if __name__ == "__main__":
""" """
for sample in tqdm.tqdm(dataset): for sample in tqdm.tqdm(dataset):
""" """
Each of the samples will have the following structure : Each of the samples will have the following structure :
{ {
'idx': 0, 'idx': 0,
'image': array([[[110, 128, 140], 'image': array([[[110, 128, 140],
[110, 128, 139], [110, 128, 139],
[110, 128, 140], [110, 128, 140],
...@@ -285,29 +283,28 @@ if __name__ == "__main__": ...@@ -285,29 +283,28 @@ if __name__ == "__main__":
[134, 154, 168], [134, 154, 168],
[137, 158, 173]]]), [137, 158, 173]]]),
'label': [0, 0, 1, 1] 'label': [0, 0, 1, 1]
} }
where : where :
`idx` : contains the reference id for this image `idx` : contains the reference id for this image
`image` : contains the image as an numpy array loaded by skimage.io.imread `image` : contains the image as an numpy array loaded by skimage.io.imread
`label` : contains the associated labels for this data point. `label` : contains the associated labels for this data point.
The values at each of the indices in the label represent the presence or absence The values at each of the indices in the label represent the presence or absence
of the following features : of the following features :
['scratch_small', 'scratch_large', 'dent_small', 'dent_large'] ['scratch_small', 'scratch_large', 'dent_small', 'dent_large']
If `drop_labels` is passed as True during the instantiation of the class, If `drop_labels` is passed as True during the instantiation of the class,
then the `labels` key is not included in the sample. then the `labels` key is not included in the sample.
""" """
print(sample) print(sample)
break break
########################################################################### ###########################################################################
########################################################################### ###########################################################################
## ##
## Protected Dataset Access Examples ## Protected Dataset Access Examples
########################################################################### ###########################################################################
########################################################################### ###########################################################################
p_dataset = ZEWDPCProtectedDataset( p_dataset = ZEWDPCProtectedDataset(
images_dir="./data/prepared/v0.1/dataset_debug", images_dir="./data/prepared/v0.1/dataset_debug",
labels_path="./data/prepared/v0.1/dataset_debug/labels.csv", labels_path="./data/prepared/v0.1/dataset_debug/labels.csv",
...@@ -332,11 +329,11 @@ if __name__ == "__main__": ...@@ -332,11 +329,11 @@ if __name__ == "__main__":
# Labels, instead have to be "purchased" # Labels, instead have to be "purchased"
label = p_dataset.purchase_label(idx) label = p_dataset.purchase_label(idx)
print(label, p_dataset.check_available_budget()) print(label, p_dataset.check_available_budget())
# When the budget for accessing the labels has been exhausted, the # When the budget for accessing the labels has been exhausted, the
# Protected Dataset will throw an OutOfBudetException. # Protected Dataset will throw an OutOfBudetException.
if idx == 50: if idx == 50:
# Example of transform applied to the images in the dataset # Example of transform applied to the images in the dataset
# midway. # midway.
preprocess = transforms.Compose( preprocess = transforms.Compose(
[ [
...@@ -348,4 +345,3 @@ if __name__ == "__main__": ...@@ -348,4 +345,3 @@ if __name__ == "__main__":
p_dataset.set_transform(preprocess) p_dataset.set_transform(preprocess)
print(sample.keys()) print(sample.keys())
input("Transofrm applied. Press any key....") input("Transofrm applied. Press any key....")
import torch import torch
import numpy as np import numpy as np
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from sklearn.metrics import hamming_loss from sklearn.metrics import hamming_loss
...@@ -8,5 +8,5 @@ from sklearn.metrics import hamming_loss ...@@ -8,5 +8,5 @@ from sklearn.metrics import hamming_loss
def exact_match_ratio(y_true, y_pred): def exact_match_ratio(y_true, y_pred):
if type(y_pred) == torch.Tensor: if type(y_pred) == torch.Tensor:
y_pred = y_pred.numpy() y_pred = y_pred.numpy()
return np.all(y_pred == y_true, axis=1).mean() return np.all(y_pred == y_true, axis=1).mean()
\ No newline at end of file
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
class OutOfBudetException(Exception): class OutOfBudetException(Exception):
"""Out of labelling budget""" """Out of labelling budget"""
def __init__(self, available_budget): def __init__(self, available_budget):
self.available_budget = available_budget self.available_budget = available_budget
self.message = "Already Exhausted Label Purchasing Budget of : {}".format( self.message = "Already Exhausted Label Purchasing Budget of : {}".format(
self.available_budget self.available_budget
) )
super().__init__(self.message) super().__init__(self.message)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment