Skip to content
Snippets Groups Projects
Commit 628441b7 authored by ThangVu's avatar ThangVu
Browse files

caffe2 preprocess in group norm unit test

parent 8500c14e
No related branches found
No related tags found
No related merge requests found
......@@ -197,8 +197,7 @@ def main_worker(gpu, ngpus_per_node, args):
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
std=[1/255, 1/255, 1/255])
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([
......@@ -321,6 +320,7 @@ def validate(val_loader, model, criterion, args):
if args.gpu is not None:
input = input.cuda(args.gpu, non_blocking=True)
target = target.cuda(args.gpu, non_blocking=True)
input = torch.cat([input[:, 2:3, :, :], input[:, 1:2, :, :], input[:, 0:1, :, :]], dim=1)
# compute output
output = model(input)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment