From dc5edc3843c8bc8bc00876a9322dd409bdae6565 Mon Sep 17 00:00:00 2001 From: yhcao6 <yhcao6@gmail.com> Date: Fri, 21 Dec 2018 15:47:41 +0800 Subject: [PATCH] add benchmark set, reorder parameter of custom dataset --- configs/pascal_voc/ssd300_voc.py | 1 + configs/pascal_voc/ssd512_voc.py | 1 + configs/ssd300_coco.py | 1 + configs/ssd512_coco.py | 1 + mmdet/datasets/custom.py | 4 ++-- tools/train.py | 4 ++++ 6 files changed, 10 insertions(+), 2 deletions(-) diff --git a/configs/pascal_voc/ssd300_voc.py b/configs/pascal_voc/ssd300_voc.py index 79d60e9..88f662c 100644 --- a/configs/pascal_voc/ssd300_voc.py +++ b/configs/pascal_voc/ssd300_voc.py @@ -1,3 +1,4 @@ +benchmark = True # model settings input_size = 300 model = dict( diff --git a/configs/pascal_voc/ssd512_voc.py b/configs/pascal_voc/ssd512_voc.py index 25e1369..c670cc7 100644 --- a/configs/pascal_voc/ssd512_voc.py +++ b/configs/pascal_voc/ssd512_voc.py @@ -1,3 +1,4 @@ +benchmark = True # model settings input_size = 512 model = dict( diff --git a/configs/ssd300_coco.py b/configs/ssd300_coco.py index 781d2df..1c14cea 100644 --- a/configs/ssd300_coco.py +++ b/configs/ssd300_coco.py @@ -1,3 +1,4 @@ +benchmark = True # model settings input_size = 300 model = dict( diff --git a/configs/ssd512_coco.py b/configs/ssd512_coco.py index 1d9b352..e410b94 100644 --- a/configs/ssd512_coco.py +++ b/configs/ssd512_coco.py @@ -1,3 +1,4 @@ +benchmark = True # model settings input_size = 512 model = dict( diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py index 1bc842a..2ab3be9 100644 --- a/mmdet/datasets/custom.py +++ b/mmdet/datasets/custom.py @@ -47,9 +47,9 @@ class CustomDataset(Dataset): with_mask=True, with_crowd=True, with_label=True, - test_mode=False, extra_aug=None, - resize_keep_ratio=True): + resize_keep_ratio=True, + test_mode=False): # prefix of images path self.img_prefix = img_prefix diff --git a/tools/train.py b/tools/train.py index 4ec7e56..663fdd7 100644 --- a/tools/train.py +++ b/tools/train.py @@ -8,6 +8,7 @@ from mmdet.datasets import get_dataset from mmdet.apis import (train_detector, init_dist, get_root_logger, set_random_seed) from mmdet.models import build_detector +import torch def parse_args(): @@ -42,6 +43,9 @@ def main(): args = parse_args() cfg = Config.fromfile(args.config) + # set benchmark + if cfg.get('benchmark', False): + torch.backends.cudnn.benchmark = True # update configs according to CLI args if args.work_dir is not None: cfg.work_dir = args.work_dir -- GitLab