diff --git a/tools/train_imagenet/train_imagenet.py b/tools/train_imagenet/train_imagenet.py index 0a9b367d7e3a9ea937a0d6ab5f6c24ffdee10cb3..202344dc54d99138c45a16538865f8c067498589 100644 --- a/tools/train_imagenet/train_imagenet.py +++ b/tools/train_imagenet/train_imagenet.py @@ -197,8 +197,7 @@ def main_worker(gpu, ngpus_per_node, args): traindir = os.path.join(args.data, 'train') valdir = os.path.join(args.data, 'val') normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - + std=[1/255, 1/255, 1/255]) train_dataset = datasets.ImageFolder( traindir, transforms.Compose([ @@ -321,6 +320,7 @@ def validate(val_loader, model, criterion, args): if args.gpu is not None: input = input.cuda(args.gpu, non_blocking=True) target = target.cuda(args.gpu, non_blocking=True) + input = torch.cat([input[:, 2:3, :, :], input[:, 1:2, :, :], input[:, 0:1, :, :]], dim=1) # compute output output = model(input)