diff --git a/configs/pascal_voc/ssd300_voc.py b/configs/pascal_voc/ssd300_voc.py index 79d60e95c740636610d266a681ba4f48d6b2c66b..88f662ca3062c04997dbbf15589f0ba7f886a757 100644 --- a/configs/pascal_voc/ssd300_voc.py +++ b/configs/pascal_voc/ssd300_voc.py @@ -1,3 +1,4 @@ +benchmark = True # model settings input_size = 300 model = dict( diff --git a/configs/pascal_voc/ssd512_voc.py b/configs/pascal_voc/ssd512_voc.py index 25e1369b9e84edb1cea7c77bcce1aa12128baa95..c670cc7bd8df887df0f1a0318226dd394cd88266 100644 --- a/configs/pascal_voc/ssd512_voc.py +++ b/configs/pascal_voc/ssd512_voc.py @@ -1,3 +1,4 @@ +benchmark = True # model settings input_size = 512 model = dict( diff --git a/configs/ssd300_coco.py b/configs/ssd300_coco.py index 781d2df0b927791d16f30753a1cfabd6011e5963..1c14cea56f5f99ae132c9b6f400ab2e82c91693d 100644 --- a/configs/ssd300_coco.py +++ b/configs/ssd300_coco.py @@ -1,3 +1,4 @@ +benchmark = True # model settings input_size = 300 model = dict( diff --git a/configs/ssd512_coco.py b/configs/ssd512_coco.py index 1d9b35299eafc8d06b4c70a4ab27b4fe2a981daf..e410b949aee73512467c1673f1f71a3b295a16a1 100644 --- a/configs/ssd512_coco.py +++ b/configs/ssd512_coco.py @@ -1,3 +1,4 @@ +benchmark = True # model settings input_size = 512 model = dict( diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py index 1bc842a843b00986ca9cef9d864f098a34d9627f..2ab3be99d07f6517f9370f43994d22d8d374818b 100644 --- a/mmdet/datasets/custom.py +++ b/mmdet/datasets/custom.py @@ -47,9 +47,9 @@ class CustomDataset(Dataset): with_mask=True, with_crowd=True, with_label=True, - test_mode=False, extra_aug=None, - resize_keep_ratio=True): + resize_keep_ratio=True, + test_mode=False): # prefix of images path self.img_prefix = img_prefix diff --git a/tools/train.py b/tools/train.py index 4ec7e5615430d16c9b6bf0c315dd84df42512c1e..663fdd7a2ce9d9c9846614f68db2fe1a4d09b637 100644 --- a/tools/train.py +++ b/tools/train.py @@ -8,6 +8,7 @@ from mmdet.datasets import get_dataset from mmdet.apis import (train_detector, init_dist, get_root_logger, set_random_seed) from mmdet.models import build_detector +import torch def parse_args(): @@ -42,6 +43,9 @@ def main(): args = parse_args() cfg = Config.fromfile(args.config) + # set benchmark + if cfg.get('benchmark', False): + torch.backends.cudnn.benchmark = True # update configs according to CLI args if args.work_dir is not None: cfg.work_dir = args.work_dir