From 7906bd2064d8511063d917d0d066d346d1edcaf2 Mon Sep 17 00:00:00 2001 From: wangg12 <guwang12@gmail.com> Date: Wed, 28 Nov 2018 11:32:47 +0800 Subject: [PATCH] support training on dataset with multiple ann_files --- mmdet/datasets/__init__.py | 3 ++- mmdet/datasets/concat_dataset.py | 30 ++++++++++++++++++++++++++++++ tools/train.py | 22 +++++++++++++++++++++- 3 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 mmdet/datasets/concat_dataset.py diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 75e0709..a849a2a 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -2,8 +2,9 @@ from .custom import CustomDataset from .coco import CocoDataset from .loader import GroupSampler, DistributedGroupSampler, build_dataloader from .utils import to_tensor, random_scale, show_ann +from .concat_dataset import ConcatDataset __all__ = [ - 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', + 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann' ] diff --git a/mmdet/datasets/concat_dataset.py b/mmdet/datasets/concat_dataset.py new file mode 100644 index 0000000..47e073a --- /dev/null +++ b/mmdet/datasets/concat_dataset.py @@ -0,0 +1,30 @@ +import bisect +import numpy as np +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + + +class ConcatDataset(_ConcatDataset): + """ + Same as torch.utils.data.dataset.ConcatDataset, but + concat the group flag for image aspect ratio. + """ + def __init__(self, datasets): + """ + flag: Images with aspect ratio greater than 1 will be set as group 1, + otherwise group 0. + """ + super(ConcatDataset, self).__init__(datasets) + if hasattr(datasets[0], 'flag'): + flags = [] + for i in range(0, len(datasets)): + flags.append(datasets[i].flag) + self.flag = np.concatenate(flags) + + def get_idxs(self, idx): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return dataset_idx, sample_idx + diff --git a/tools/train.py b/tools/train.py index 8e03628..c1e84b1 100644 --- a/tools/train.py +++ b/tools/train.py @@ -1,10 +1,12 @@ from __future__ import division import argparse +import copy from mmcv import Config from mmcv.runner import obj_from_dict from mmdet import datasets, __version__ +from mmdet.datasets import ConcatDataset from mmdet.apis import (train_detector, init_dist, get_root_logger, set_random_seed) from mmdet.models import build_detector @@ -36,6 +38,24 @@ def parse_args(): return args +def get_train_dataset(cfg): + if isinstance(cfg.data.train['ann_file'], list) or isinstance(cfg.data.train['ann_file'], tuple): + ann_files = cfg.data.train['ann_file'] + train_datasets = [] + for ann_file in ann_files: + data_info = copy.deepcopy(cfg.data.train) + data_info['ann_file'] = ann_file + train_dset = obj_from_dict(data_info, datasets) + train_datasets.append(train_dset) + if len(train_datasets) > 1: + train_dataset = ConcatDataset(train_datasets) + else: + train_dataset = train_datasets[0] + else: + train_dataset = obj_from_dict(cfg.data.train, datasets) + return train_dataset + + def main(): args = parse_args() @@ -67,7 +87,7 @@ def main(): model = build_detector( cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) - train_dataset = obj_from_dict(cfg.data.train, datasets) + train_dataset = get_train_dataset(cfg) train_detector( model, train_dataset, -- GitLab