Commit 7cbdbc78 authored by wangg12

move the function to datasets.utils

parent 7906bd20
mmdet/datasets/__init__.py
 from .custom import CustomDataset
 from .coco import CocoDataset
 from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
-from .utils import to_tensor, random_scale, show_ann
+from .utils import to_tensor, random_scale, show_ann, get_dataset
 from .concat_dataset import ConcatDataset

 __all__ = [
     'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'ConcatDataset',
-    'build_dataloader', 'to_tensor', 'random_scale', 'show_ann'
+    'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', 'get_dataset'
 ]
mmdet/datasets/utils.py
 from collections import Sequence
+import copy

 import mmcv
+from mmcv.runner import obj_from_dict
 import torch
 import matplotlib.pyplot as plt
 import numpy as np
+
+from .concat_dataset import ConcatDataset
+from .. import datasets


 def to_tensor(data):
     """Convert objects of various python types to :obj:`torch.Tensor`.
@@ -67,3 +69,21 @@ def show_ann(coco, img, ann_info):
     plt.axis('off')
     coco.showAnns(ann_info)
     plt.show()
+
+
+def get_dataset(data_cfg):
+    if isinstance(data_cfg['ann_file'], list) or isinstance(data_cfg['ann_file'], tuple):
+        ann_files = data_cfg['ann_file']
+        dsets = []
+        for ann_file in ann_files:
+            data_info = copy.deepcopy(data_cfg)
+            data_info['ann_file'] = ann_file
+            dset = obj_from_dict(data_info, datasets)
+            dsets.append(dset)
+        if len(dsets) > 1:
+            dset = ConcatDataset(dsets)
+        else:
+            dset = dsets[0]
+    else:
+        dset = obj_from_dict(data_cfg, datasets)
+    return dset
\ No newline at end of file
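
For reference, a minimal usage sketch of the relocated helper. The dataset type, annotation paths, and field values below are placeholders (not part of this commit) and assume a CocoDataset-style config whose keys match the dataset constructor:

from mmdet.datasets import get_dataset

# Hypothetical training-data config: every key except 'type' is passed
# straight to the dataset constructor by obj_from_dict, so the keys must
# match its signature. The paths are placeholders.
data_train = dict(
    type='CocoDataset',
    ann_file=['annotations/instances_train_part1.json',
              'annotations/instances_train_part2.json'],
    img_prefix='train2017/',
    img_scale=(1333, 800),
    img_norm_cfg=dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    flip_ratio=0.5)

# Because ann_file is a list, one dataset is built per annotation file and
# the results are wrapped in ConcatDataset; a single string would yield a
# single dataset object instead.
train_dataset = get_dataset(data_train)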
tools/train.py
 from __future__ import division
 import argparse
-import copy

 from mmcv import Config
 from mmcv.runner import obj_from_dict

 from mmdet import datasets, __version__
-from mmdet.datasets import ConcatDataset
 from mmdet.apis import (train_detector, init_dist, get_root_logger,
                         set_random_seed)
 from mmdet.models import build_detector
@@ -38,24 +36,6 @@ def parse_args():
     return args


-def get_train_dataset(cfg):
-    if isinstance(cfg.data.train['ann_file'], list) or isinstance(cfg.data.train['ann_file'], tuple):
-        ann_files = cfg.data.train['ann_file']
-        train_datasets = []
-        for ann_file in ann_files:
-            data_info = copy.deepcopy(cfg.data.train)
-            data_info['ann_file'] = ann_file
-            train_dset = obj_from_dict(data_info, datasets)
-            train_datasets.append(train_dset)
-        if len(train_datasets) > 1:
-            train_dataset = ConcatDataset(train_datasets)
-        else:
-            train_dataset = train_datasets[0]
-    else:
-        train_dataset = obj_from_dict(cfg.data.train, datasets)
-    return train_dataset
-

 def main():
     args = parse_args()
@@ -87,7 +67,7 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    train_dataset = get_train_dataset(cfg)
+    train_dataset = datasets.get_dataset(cfg.data.train)
     train_detector(
         model,
         train_dataset,
    ...
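
At the call site, the deleted get_train_dataset(cfg) helper becomes a one-liner. A sketch of the new flow, assuming a standard config file; the config path below is illustrative only:

from mmcv import Config
from mmdet import datasets

# Illustrative config path; any config whose data.train section follows the
# usual layout (type, ann_file, img_prefix, ...) behaves the same way.
cfg = Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')

# Replaces the removed get_train_dataset(cfg): a list/tuple of ann_file
# entries yields a concatenated dataset, a single path yields one dataset.
train_dataset = datasets.get_dataset(cfg.data.train)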