From 825cfa0ce2d29ce7e42ad38f6c068c051bc3e940 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Tue, 11 Dec 2018 21:52:02 +0800
Subject: [PATCH] add class attribute CLASSES to Dataset

---
 mmdet/core/evaluation/class_names.py | 16 ++++++++--------
 mmdet/datasets/coco.py               | 15 +++++++++++++++
 mmdet/datasets/concat_dataset.py     | 14 ++++++++------
 mmdet/datasets/custom.py             |  6 ++++--
 mmdet/datasets/repeat_dataset.py     |  8 +++++---
 mmdet/models/detectors/base.py       |  7 ++++---
 tools/test.py                        |  7 ++++---
 7 files changed, 48 insertions(+), 25 deletions(-)

diff --git a/mmdet/core/evaluation/class_names.py b/mmdet/core/evaluation/class_names.py
index 04f8063..a8d0a3b 100644
--- a/mmdet/core/evaluation/class_names.py
+++ b/mmdet/core/evaluation/class_names.py
@@ -63,18 +63,18 @@ def imagenet_vid_classes():
 def coco_classes():
     return [
         'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
-        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
-        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
+        'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
+        'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
         'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
         'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
-        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
-        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
+        'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
+        'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
         'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
-        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
-        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
-        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+        'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
+        'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
+        'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
         'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
-        'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+        'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
     ]
 
 
diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py
index 4058db9..886efbd 100644
--- a/mmdet/datasets/coco.py
+++ b/mmdet/datasets/coco.py
@@ -6,6 +6,21 @@ from .custom import CustomDataset
 
 class CocoDataset(CustomDataset):
 
+    CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+               'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
+               'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
+               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+               'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
+               'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
+               'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+               'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+               'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+               'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
+               'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
+               'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
+               'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')
+
     def load_annotations(self, ann_file):
         self.coco = COCO(ann_file)
         self.cat_ids = self.coco.getCatIds()
diff --git a/mmdet/datasets/concat_dataset.py b/mmdet/datasets/concat_dataset.py
index e42b609..195420a 100644
--- a/mmdet/datasets/concat_dataset.py
+++ b/mmdet/datasets/concat_dataset.py
@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
 
 
 class ConcatDataset(_ConcatDataset):
-    """
-    Same as torch.utils.data.dataset.ConcatDataset, but
+    """A wrapper of concatenated dataset.
+
+    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
     concat the group flag for image aspect ratio.
+
+    Args:
+        datasets (list[:obj:`Dataset`]): A list of datasets.
     """
+
     def __init__(self, datasets):
-        """
-        flag: Images with aspect ratio greater than 1 will be set as group 1,
-              otherwise group 0.
-        """
         super(ConcatDataset, self).__init__(datasets)
+        self.CLASSES = datasets[0].CLASSES
         if hasattr(datasets[0], 'flag'):
             flags = []
             for i in range(0, len(datasets)):
diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py
index 3640a83..5cb1771 100644
--- a/mmdet/datasets/custom.py
+++ b/mmdet/datasets/custom.py
@@ -32,6 +32,8 @@ class CustomDataset(Dataset):
     The `ann` field is optional for testing.
     """
 
+    CLASSES = None
+
     def __init__(self,
                  ann_file,
                  img_prefix,
@@ -45,6 +47,8 @@ class CustomDataset(Dataset):
                  with_crowd=True,
                  with_label=True,
                  test_mode=False):
+        # prefix of images path
+        self.img_prefix = img_prefix
         # load annotations (and proposals)
         self.img_infos = self.load_annotations(ann_file)
         if proposal_file is not None:
@@ -58,8 +62,6 @@ class CustomDataset(Dataset):
             if self.proposals is not None:
                 self.proposals = [self.proposals[i] for i in valid_inds]
 
-        # prefix of images path
-        self.img_prefix = img_prefix
         # (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
         self.img_scales = img_scale if isinstance(img_scale,
                                                   list) else [img_scale]
diff --git a/mmdet/datasets/repeat_dataset.py b/mmdet/datasets/repeat_dataset.py
index 5edf33c..7e99293 100644
--- a/mmdet/datasets/repeat_dataset.py
+++ b/mmdet/datasets/repeat_dataset.py
@@ -6,12 +6,14 @@ class RepeatDataset(object):
     def __init__(self, dataset, times):
         self.dataset = dataset
         self.times = times
+        self.CLASSES = dataset.CLASSES
         if hasattr(self.dataset, 'flag'):
             self.flag = np.tile(self.dataset.flag, times)
-        self._original_length = len(self.dataset)
+
+        self._ori_len = len(self.dataset)
 
     def __getitem__(self, idx):
-        return self.dataset[idx % self._original_length]
+        return self.dataset[idx % self._ori_len]
 
     def __len__(self):
-        return self.times * self._original_length
+        return self.times * self._ori_len
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index cbaf434..e23784e 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -99,11 +99,12 @@ class BaseDetector(nn.Module):
 
         if isinstance(dataset, str):
             class_names = get_classes(dataset)
-        elif isinstance(dataset, list):
+        elif isinstance(dataset, (list, tuple)) or dataset is None:
             class_names = dataset
         else:
-            raise TypeError('dataset must be a valid dataset name or a list'
-                            ' of class names, not {}'.format(type(dataset)))
+            raise TypeError(
+                'dataset must be a valid dataset name or a sequence'
+                ' of class names, not {}'.format(type(dataset)))
 
         for img, img_meta in zip(imgs, img_metas):
             h, w, _ = img_meta['img_shape']
diff --git a/tools/test.py b/tools/test.py
index 9a599c2..9130142 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors
 def single_test(model, data_loader, show=False):
     model.eval()
     results = []
-    prog_bar = mmcv.ProgressBar(len(data_loader.dataset))
+    dataset = data_loader.dataset
+    prog_bar = mmcv.ProgressBar(len(dataset))
     for i, data in enumerate(data_loader):
         with torch.no_grad():
             result = model(return_loss=False, rescale=not show, **data)
         results.append(result)
 
         if show:
-            model.module.show_result(data, result,
-                                     data_loader.dataset.img_norm_cfg)
+            model.module.show_result(data, result, dataset.img_norm_cfg,
+                                     dataset.CLASSES)
 
         batch_size = data['img'][0].size(0)
         for _ in range(batch_size):
-- 
GitLab