From 7cbdbc78111c366a6dd615657c2bc2f28394df04 Mon Sep 17 00:00:00 2001 From: wangg12 <guwang12@gmail.com> Date: Wed, 28 Nov 2018 11:54:08 +0800 Subject: [PATCH] move the function to datasets.utils --- mmdet/datasets/__init__.py | 4 ++-- mmdet/datasets/utils.py | 24 ++++++++++++++++++++++-- tools/train.py | 22 +--------------------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index a849a2a..b478acf 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -1,10 +1,10 @@ from .custom import CustomDataset from .coco import CocoDataset from .loader import GroupSampler, DistributedGroupSampler, build_dataloader -from .utils import to_tensor, random_scale, show_ann +from .utils import to_tensor, random_scale, show_ann, get_dataset from .concat_dataset import ConcatDataset __all__ = [ 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'ConcatDataset', - 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann' + 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', 'get_dataset' ] diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py index 5a248ef..d13ae89 100644 --- a/mmdet/datasets/utils.py +++ b/mmdet/datasets/utils.py @@ -1,11 +1,13 @@ from collections import Sequence - +import copy import mmcv +from mmcv.runner import obj_from_dict import torch import matplotlib.pyplot as plt import numpy as np - +from .concat_dataset import ConcatDataset +from .. import datasets def to_tensor(data): """Convert objects of various python types to :obj:`torch.Tensor`. @@ -67,3 +69,21 @@ def show_ann(coco, img, ann_info): plt.axis('off') coco.showAnns(ann_info) plt.show() + + +def get_dataset(data_cfg): + if isinstance(data_cfg['ann_file'], list) or isinstance(data_cfg['ann_file'], tuple): + ann_files = data_cfg['ann_file'] + dsets = [] + for ann_file in ann_files: + data_info = copy.deepcopy(data_cfg) + data_info['ann_file'] = ann_file + dset = obj_from_dict(data_info, datasets) + dsets.append(dset) + if len(dsets) > 1: + dset = ConcatDataset(dsets) + else: + dset = dsets[0] + else: + dset = obj_from_dict(data_cfg, datasets) + return dset \ No newline at end of file diff --git a/tools/train.py b/tools/train.py index c1e84b1..006fc11 100644 --- a/tools/train.py +++ b/tools/train.py @@ -1,12 +1,10 @@ from __future__ import division import argparse -import copy from mmcv import Config from mmcv.runner import obj_from_dict from mmdet import datasets, __version__ -from mmdet.datasets import ConcatDataset from mmdet.apis import (train_detector, init_dist, get_root_logger, set_random_seed) from mmdet.models import build_detector @@ -38,24 +36,6 @@ def parse_args(): return args -def get_train_dataset(cfg): - if isinstance(cfg.data.train['ann_file'], list) or isinstance(cfg.data.train['ann_file'], tuple): - ann_files = cfg.data.train['ann_file'] - train_datasets = [] - for ann_file in ann_files: - data_info = copy.deepcopy(cfg.data.train) - data_info['ann_file'] = ann_file - train_dset = obj_from_dict(data_info, datasets) - train_datasets.append(train_dset) - if len(train_datasets) > 1: - train_dataset = ConcatDataset(train_datasets) - else: - train_dataset = train_datasets[0] - else: - train_dataset = obj_from_dict(cfg.data.train, datasets) - return train_dataset - - def main(): args = parse_args() @@ -87,7 +67,7 @@ def main(): model = build_detector( cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) - train_dataset = get_train_dataset(cfg) + train_dataset = datasets.get_dataset(cfg.data.train) train_detector( model, train_dataset, -- GitLab