diff --git a/mmdet/core/evaluation/class_names.py b/mmdet/core/evaluation/class_names.py index 04f806315b7c6ef47419efa61e38d2f7ec3ebd2a..a8d0a3b9b2c4acf1ade04cafc16f48fd0eb743d7 100644 --- a/mmdet/core/evaluation/class_names.py +++ b/mmdet/core/evaluation/class_names.py @@ -63,18 +63,18 @@ def imagenet_vid_classes(): def coco_classes(): return [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign', + 'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', - 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard', + 'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', - 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', - 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', - 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', - 'scissors', 'teddy bear', 'hair drier', 'toothbrush' + 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush' ] diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py index 4058db90be6b50336af66fa14270d7dd0d16f882..886efbd03d3109b45281fc9015df0f5847739bb5 100644 --- a/mmdet/datasets/coco.py +++ b/mmdet/datasets/coco.py @@ -6,6 +6,21 @@ from .custom import CustomDataset class CocoDataset(CustomDataset): + CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant', + 'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat', + 'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket', + 'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush') + def load_annotations(self, ann_file): self.coco = COCO(ann_file) self.cat_ids = self.coco.getCatIds() diff --git a/mmdet/datasets/concat_dataset.py b/mmdet/datasets/concat_dataset.py index e42b60982578550774ab930832f7ee1014ef672f..195420ad9d3899f07f829e0eda56df370a95aeee 100644 --- a/mmdet/datasets/concat_dataset.py +++ b/mmdet/datasets/concat_dataset.py @@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset class ConcatDataset(_ConcatDataset): - """ - Same as torch.utils.data.dataset.ConcatDataset, but + """A wrapper of concatenated dataset. + + Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but concat the group flag for image aspect ratio. + + Args: + datasets (list[:obj:`Dataset`]): A list of datasets. """ + def __init__(self, datasets): - """ - flag: Images with aspect ratio greater than 1 will be set as group 1, - otherwise group 0. - """ super(ConcatDataset, self).__init__(datasets) + self.CLASSES = datasets[0].CLASSES if hasattr(datasets[0], 'flag'): flags = [] for i in range(0, len(datasets)): diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py index 3640a83db9ac10504ea6076ceff88bf1cdfa3ec8..5cb17716262760f1dee345aca4381ae22c36368b 100644 --- a/mmdet/datasets/custom.py +++ b/mmdet/datasets/custom.py @@ -32,6 +32,8 @@ class CustomDataset(Dataset): The `ann` field is optional for testing. """ + CLASSES = None + def __init__(self, ann_file, img_prefix, @@ -45,6 +47,8 @@ class CustomDataset(Dataset): with_crowd=True, with_label=True, test_mode=False): + # prefix of images path + self.img_prefix = img_prefix # load annotations (and proposals) self.img_infos = self.load_annotations(ann_file) if proposal_file is not None: @@ -58,8 +62,6 @@ class CustomDataset(Dataset): if self.proposals is not None: self.proposals = [self.proposals[i] for i in valid_inds] - # prefix of images path - self.img_prefix = img_prefix # (long_edge, short_edge) or [(long1, short1), (long2, short2), ...] self.img_scales = img_scale if isinstance(img_scale, list) else [img_scale] diff --git a/mmdet/datasets/repeat_dataset.py b/mmdet/datasets/repeat_dataset.py index 5edf33c7417a917287469e36c9c839d882adf8dd..7e9929332c437705a996d7a1109c030df6d63251 100644 --- a/mmdet/datasets/repeat_dataset.py +++ b/mmdet/datasets/repeat_dataset.py @@ -6,12 +6,14 @@ class RepeatDataset(object): def __init__(self, dataset, times): self.dataset = dataset self.times = times + self.CLASSES = dataset.CLASSES if hasattr(self.dataset, 'flag'): self.flag = np.tile(self.dataset.flag, times) - self._original_length = len(self.dataset) + + self._ori_len = len(self.dataset) def __getitem__(self, idx): - return self.dataset[idx % self._original_length] + return self.dataset[idx % self._ori_len] def __len__(self): - return self.times * self._original_length + return self.times * self._ori_len diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py index cbaf4349d2da9c9bec2d7fe3847157c77463faef..e23784ea483e0f7ac30ed22c8c76949dc40424ff 100644 --- a/mmdet/models/detectors/base.py +++ b/mmdet/models/detectors/base.py @@ -99,11 +99,12 @@ class BaseDetector(nn.Module): if isinstance(dataset, str): class_names = get_classes(dataset) - elif isinstance(dataset, list): + elif isinstance(dataset, (list, tuple)) or dataset is None: class_names = dataset else: - raise TypeError('dataset must be a valid dataset name or a list' - ' of class names, not {}'.format(type(dataset))) + raise TypeError( + 'dataset must be a valid dataset name or a sequence' + ' of class names, not {}'.format(type(dataset))) for img, img_meta in zip(imgs, img_metas): h, w, _ = img_meta['img_shape'] diff --git a/tools/test.py b/tools/test.py index 9a599c2d923f5e6d999363d21f120c0b38f71395..91301421b955b500877f539f24e7fcc4750fedda 100644 --- a/tools/test.py +++ b/tools/test.py @@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors def single_test(model, data_loader, show=False): model.eval() results = [] - prog_bar = mmcv.ProgressBar(len(data_loader.dataset)) + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) for i, data in enumerate(data_loader): with torch.no_grad(): result = model(return_loss=False, rescale=not show, **data) results.append(result) if show: - model.module.show_result(data, result, - data_loader.dataset.img_norm_cfg) + model.module.show_result(data, result, dataset.img_norm_cfg, + dataset.CLASSES) batch_size = data['img'][0].size(0) for _ in range(batch_size):