diff --git a/.gitignore b/.gitignore index 01c47d6e277dba0d7b880dff88f9695f9a8eec50..f189e1d5b6851047dc4230032c134f8134a8d73a 100644 --- a/.gitignore +++ b/.gitignore @@ -107,3 +107,5 @@ venv.bak/ mmdet/ops/nms/*.cpp mmdet/version.py data +.vscode +.idea diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 75e07097756bd014bbef17294b6803aa83621fd1..dda1a1eb4a514b93254623693753ecf8f3839c12 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -1,9 +1,11 @@ from .custom import CustomDataset from .coco import CocoDataset from .loader import GroupSampler, DistributedGroupSampler, build_dataloader -from .utils import to_tensor, random_scale, show_ann +from .utils import to_tensor, random_scale, show_ann, get_dataset +from .concat_dataset import ConcatDataset __all__ = [ 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', - 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann' + 'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale', + 'show_ann', 'get_dataset' ] diff --git a/mmdet/datasets/concat_dataset.py b/mmdet/datasets/concat_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e42b60982578550774ab930832f7ee1014ef672f --- /dev/null +++ b/mmdet/datasets/concat_dataset.py @@ -0,0 +1,20 @@ +import numpy as np +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + + +class ConcatDataset(_ConcatDataset): + """ + Same as torch.utils.data.dataset.ConcatDataset, but + concat the group flag for image aspect ratio. + """ + def __init__(self, datasets): + """ + flag: Images with aspect ratio greater than 1 will be set as group 1, + otherwise group 0. + """ + super(ConcatDataset, self).__init__(datasets) + if hasattr(datasets[0], 'flag'): + flags = [] + for i in range(0, len(datasets)): + flags.append(datasets[i].flag) + self.flag = np.concatenate(flags) diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py index 5a248ef6890ea348ea7ad98154cc163ae1e035c5..6f6a0b514d141889c8a5701214ca963efe6b0a61 100644 --- a/mmdet/datasets/utils.py +++ b/mmdet/datasets/utils.py @@ -1,10 +1,14 @@ +import copy from collections import Sequence import mmcv +from mmcv.runner import obj_from_dict import torch import matplotlib.pyplot as plt import numpy as np +from .concat_dataset import ConcatDataset +from .. import datasets def to_tensor(data): @@ -67,3 +71,41 @@ def show_ann(coco, img, ann_info): plt.axis('off') coco.showAnns(ann_info) plt.show() + + +def get_dataset(data_cfg): + if isinstance(data_cfg['ann_file'], (list, tuple)): + ann_files = data_cfg['ann_file'] + num_dset = len(ann_files) + else: + ann_files = [data_cfg['ann_file']] + num_dset = 1 + + if 'proposal_file' in data_cfg.keys(): + if isinstance(data_cfg['proposal_file'], (list, tuple)): + proposal_files = data_cfg['proposal_file'] + else: + proposal_files = [data_cfg['proposal_file']] + else: + proposal_files = [None] * num_dset + assert len(proposal_files) == num_dset + + if isinstance(data_cfg['img_prefix'], (list, tuple)): + img_prefixes = data_cfg['img_prefix'] + else: + img_prefixes = [data_cfg['img_prefix']] * num_dset + assert len(img_prefixes) == num_dset + + dsets = [] + for i in range(num_dset): + data_info = copy.deepcopy(data_cfg) + data_info['ann_file'] = ann_files[i] + data_info['proposal_file'] = proposal_files[i] + data_info['img_prefix'] = img_prefixes[i] + dset = obj_from_dict(data_info, datasets) + dsets.append(dset) + if len(dsets) > 1: + dset = ConcatDataset(dsets) + else: + dset = dsets[0] + return dset diff --git a/tools/train.py b/tools/train.py index 8e03628db5ea28d027ccdc3939c72bace482be93..bd47e66bed121629368a86cd9d809a6cfe2b9363 100644 --- a/tools/train.py +++ b/tools/train.py @@ -2,9 +2,9 @@ from __future__ import division import argparse from mmcv import Config -from mmcv.runner import obj_from_dict -from mmdet import datasets, __version__ +from mmdet import __version__ +from mmdet.datasets import get_dataset from mmdet.apis import (train_detector, init_dist, get_root_logger, set_random_seed) from mmdet.models import build_detector @@ -67,7 +67,7 @@ def main(): model = build_detector( cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) - train_dataset = obj_from_dict(cfg.data.train, datasets) + train_dataset = get_dataset(cfg.data.train) train_detector( model, train_dataset,