Skip to content
Snippets Groups Projects
Commit 7906bd20 authored by wangg12's avatar wangg12
Browse files

support training on dataset with multiple ann_files

parent 64e310d5
No related branches found
No related tags found
No related merge requests found
......@@ -2,8 +2,9 @@ from .custom import CustomDataset
from .coco import CocoDataset
from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
from .utils import to_tensor, random_scale, show_ann
from .concat_dataset import ConcatDataset
__all__ = [
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'ConcatDataset',
'build_dataloader', 'to_tensor', 'random_scale', 'show_ann'
]
import bisect
import numpy as np
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class ConcatDataset(_ConcatDataset):
    """Same as ``torch.utils.data.dataset.ConcatDataset``, but also
    concatenates the per-image group ``flag`` arrays used for
    aspect-ratio grouping by the group samplers.
    """

    def __init__(self, datasets):
        """
        Args:
            datasets (sequence): datasets to concatenate. If the first
                dataset exposes a ``flag`` array (images with aspect
                ratio > 1 are group 1, otherwise group 0), every dataset
                is assumed to expose one as well.
        """
        super(ConcatDataset, self).__init__(datasets)
        # Concatenate the group flags so samplers see one flat array
        # aligned with the concatenated sample indices.
        if hasattr(datasets[0], 'flag'):
            self.flag = np.concatenate([ds.flag for ds in datasets])

    def get_idxs(self, idx):
        """Map a global sample index to its sub-dataset and local index.

        Args:
            idx (int): index into the concatenated dataset.

        Returns:
            tuple[int, int]: ``(dataset_idx, sample_idx)`` — which
            sub-dataset the sample lives in and its index within it.
        """
        # cumulative_sizes[i] is the total length of datasets[0..i], so
        # bisect_right finds the first sub-dataset whose cumulative end
        # exceeds idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return dataset_idx, sample_idx
from __future__ import division
import argparse
import copy
from mmcv import Config
from mmcv.runner import obj_from_dict
from mmdet import datasets, __version__
from mmdet.datasets import ConcatDataset
from mmdet.apis import (train_detector, init_dist, get_root_logger,
set_random_seed)
from mmdet.models import build_detector
......@@ -36,6 +38,24 @@ def parse_args():
return args
def get_train_dataset(cfg):
    """Build the training dataset described by ``cfg.data.train``.

    Supports ``cfg.data.train['ann_file']`` being either a single
    annotation file or a list/tuple of files; multiple files produce one
    dataset per file, wrapped in a ``ConcatDataset``.

    Args:
        cfg: mmcv ``Config`` whose ``data.train`` dict describes the
            dataset class and its constructor arguments.

    Returns:
        A single dataset instance, or a ``ConcatDataset`` when more than
        one annotation file is given.
    """
    ann_file = cfg.data.train['ann_file']
    if isinstance(ann_file, (list, tuple)):
        train_datasets = []
        for one_ann_file in ann_file:
            # Build each sub-dataset from a deep copy so the shared cfg
            # is never mutated.
            data_info = copy.deepcopy(cfg.data.train)
            data_info['ann_file'] = one_ann_file
            train_datasets.append(obj_from_dict(data_info, datasets))
        if len(train_datasets) > 1:
            return ConcatDataset(train_datasets)
        # A one-element list degenerates to the plain dataset.
        return train_datasets[0]
    return obj_from_dict(cfg.data.train, datasets)
def main():
args = parse_args()
......@@ -67,7 +87,7 @@ def main():
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
train_dataset = obj_from_dict(cfg.data.train, datasets)
train_dataset = get_train_dataset(cfg)
train_detector(
model,
train_dataset,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment