From 825cfa0ce2d29ce7e42ad38f6c068c051bc3e940 Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Tue, 11 Dec 2018 21:52:02 +0800 Subject: [PATCH] add class attribute CLASSES to Dataset --- mmdet/core/evaluation/class_names.py | 16 ++++++++-------- mmdet/datasets/coco.py | 15 +++++++++++++++ mmdet/datasets/concat_dataset.py | 14 ++++++++------ mmdet/datasets/custom.py | 6 ++++-- mmdet/datasets/repeat_dataset.py | 8 +++++--- mmdet/models/detectors/base.py | 7 ++++--- tools/test.py | 7 ++++--- 7 files changed, 48 insertions(+), 25 deletions(-) diff --git a/mmdet/core/evaluation/class_names.py b/mmdet/core/evaluation/class_names.py index 04f8063..a8d0a3b 100644 --- a/mmdet/core/evaluation/class_names.py +++ b/mmdet/core/evaluation/class_names.py @@ -63,18 +63,18 @@ def imagenet_vid_classes(): def coco_classes(): return [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign', + 'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', - 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard', + 'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', - 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', - 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', - 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', - 'scissors', 'teddy bear', 'hair drier', 'toothbrush' + 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush' ] diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py index 4058db9..886efbd 100644 --- a/mmdet/datasets/coco.py +++ b/mmdet/datasets/coco.py @@ -6,6 +6,21 @@ from .custom import CustomDataset class CocoDataset(CustomDataset): + CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant', + 'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat', + 'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket', + 'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush') + def load_annotations(self, ann_file): self.coco = COCO(ann_file) self.cat_ids = self.coco.getCatIds() diff --git a/mmdet/datasets/concat_dataset.py b/mmdet/datasets/concat_dataset.py index e42b609..195420a 100644 --- a/mmdet/datasets/concat_dataset.py +++ b/mmdet/datasets/concat_dataset.py @@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset class ConcatDataset(_ConcatDataset): - """ - Same as torch.utils.data.dataset.ConcatDataset, but + """A wrapper of concatenated dataset. + + Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but concat the group flag for image aspect ratio. + + Args: + datasets (list[:obj:`Dataset`]): A list of datasets. """ + def __init__(self, datasets): - """ - flag: Images with aspect ratio greater than 1 will be set as group 1, - otherwise group 0. - """ super(ConcatDataset, self).__init__(datasets) + self.CLASSES = datasets[0].CLASSES if hasattr(datasets[0], 'flag'): flags = [] for i in range(0, len(datasets)): diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py index 3640a83..5cb1771 100644 --- a/mmdet/datasets/custom.py +++ b/mmdet/datasets/custom.py @@ -32,6 +32,8 @@ class CustomDataset(Dataset): The `ann` field is optional for testing. """ + CLASSES = None + def __init__(self, ann_file, img_prefix, @@ -45,6 +47,8 @@ class CustomDataset(Dataset): with_crowd=True, with_label=True, test_mode=False): + # prefix of images path + self.img_prefix = img_prefix # load annotations (and proposals) self.img_infos = self.load_annotations(ann_file) if proposal_file is not None: @@ -58,8 +62,6 @@ class CustomDataset(Dataset): if self.proposals is not None: self.proposals = [self.proposals[i] for i in valid_inds] - # prefix of images path - self.img_prefix = img_prefix # (long_edge, short_edge) or [(long1, short1), (long2, short2), ...] self.img_scales = img_scale if isinstance(img_scale, list) else [img_scale] diff --git a/mmdet/datasets/repeat_dataset.py b/mmdet/datasets/repeat_dataset.py index 5edf33c..7e99293 100644 --- a/mmdet/datasets/repeat_dataset.py +++ b/mmdet/datasets/repeat_dataset.py @@ -6,12 +6,14 @@ class RepeatDataset(object): def __init__(self, dataset, times): self.dataset = dataset self.times = times + self.CLASSES = dataset.CLASSES if hasattr(self.dataset, 'flag'): self.flag = np.tile(self.dataset.flag, times) - self._original_length = len(self.dataset) + + self._ori_len = len(self.dataset) def __getitem__(self, idx): - return self.dataset[idx % self._original_length] + return self.dataset[idx % self._ori_len] def __len__(self): - return self.times * self._original_length + return self.times * self._ori_len diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py index cbaf434..e23784e 100644 --- a/mmdet/models/detectors/base.py +++ b/mmdet/models/detectors/base.py @@ -99,11 +99,12 @@ class BaseDetector(nn.Module): if isinstance(dataset, str): class_names = get_classes(dataset) - elif isinstance(dataset, list): + elif isinstance(dataset, (list, tuple)) or dataset is None: class_names = dataset else: - raise TypeError('dataset must be a valid dataset name or a list' - ' of class names, not {}'.format(type(dataset))) + raise TypeError( + 'dataset must be a valid dataset name or a sequence' + ' of class names, not {}'.format(type(dataset))) for img, img_meta in zip(imgs, img_metas): h, w, _ = img_meta['img_shape'] diff --git a/tools/test.py b/tools/test.py index 9a599c2..9130142 100644 --- a/tools/test.py +++ b/tools/test.py @@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors def single_test(model, data_loader, show=False): model.eval() results = [] - prog_bar = mmcv.ProgressBar(len(data_loader.dataset)) + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) for i, data in enumerate(data_loader): with torch.no_grad(): result = model(return_loss=False, rescale=not show, **data) results.append(result) if show: - model.module.show_result(data, result, - data_loader.dataset.img_norm_cfg) + model.module.show_result(data, result, dataset.img_norm_cfg, + dataset.CLASSES) batch_size = data['img'][0].size(0) for _ in range(batch_size): -- GitLab