From dc5edc3843c8bc8bc00876a9322dd409bdae6565 Mon Sep 17 00:00:00 2001
From: yhcao6 <yhcao6@gmail.com>
Date: Fri, 21 Dec 2018 15:47:41 +0800
Subject: [PATCH] add benchmark set, reorder parameter of custom dataset

---
 configs/pascal_voc/ssd300_voc.py | 1 +
 configs/pascal_voc/ssd512_voc.py | 1 +
 configs/ssd300_coco.py           | 1 +
 configs/ssd512_coco.py           | 1 +
 mmdet/datasets/custom.py         | 4 ++--
 tools/train.py                   | 4 ++++
 6 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/configs/pascal_voc/ssd300_voc.py b/configs/pascal_voc/ssd300_voc.py
index 79d60e9..88f662c 100644
--- a/configs/pascal_voc/ssd300_voc.py
+++ b/configs/pascal_voc/ssd300_voc.py
@@ -1,3 +1,4 @@
+benchmark = True
 # model settings
 input_size = 300
 model = dict(
diff --git a/configs/pascal_voc/ssd512_voc.py b/configs/pascal_voc/ssd512_voc.py
index 25e1369..c670cc7 100644
--- a/configs/pascal_voc/ssd512_voc.py
+++ b/configs/pascal_voc/ssd512_voc.py
@@ -1,3 +1,4 @@
+benchmark = True
 # model settings
 input_size = 512
 model = dict(
diff --git a/configs/ssd300_coco.py b/configs/ssd300_coco.py
index 781d2df..1c14cea 100644
--- a/configs/ssd300_coco.py
+++ b/configs/ssd300_coco.py
@@ -1,3 +1,4 @@
+benchmark = True
 # model settings
 input_size = 300
 model = dict(
diff --git a/configs/ssd512_coco.py b/configs/ssd512_coco.py
index 1d9b352..e410b94 100644
--- a/configs/ssd512_coco.py
+++ b/configs/ssd512_coco.py
@@ -1,3 +1,4 @@
+benchmark = True
 # model settings
 input_size = 512
 model = dict(
diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py
index 1bc842a..2ab3be9 100644
--- a/mmdet/datasets/custom.py
+++ b/mmdet/datasets/custom.py
@@ -47,9 +47,9 @@ class CustomDataset(Dataset):
                  with_mask=True,
                  with_crowd=True,
                  with_label=True,
-                 test_mode=False,
                  extra_aug=None,
-                 resize_keep_ratio=True):
+                 resize_keep_ratio=True,
+                 test_mode=False):
         # prefix of images path
         self.img_prefix = img_prefix
 
diff --git a/tools/train.py b/tools/train.py
index 4ec7e56..663fdd7 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -8,6 +8,7 @@ from mmdet.datasets import get_dataset
 from mmdet.apis import (train_detector, init_dist, get_root_logger,
                         set_random_seed)
 from mmdet.models import build_detector
+import torch
 
 
 def parse_args():
@@ -42,6 +43,9 @@ def main():
     args = parse_args()
 
     cfg = Config.fromfile(args.config)
+    # set benchmark
+    if cfg.get('benchmark', False):
+        torch.backends.cudnn.benchmark = True
     # update configs according to CLI args
     if args.work_dir is not None:
         cfg.work_dir = args.work_dir
-- 
GitLab