diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index b478acffebb30d475042259a9b97c8307627b5ec..dda1a1eb4a514b93254623693753ecf8f3839c12 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -5,6 +5,7 @@ from .utils import to_tensor, random_scale, show_ann, get_dataset from .concat_dataset import ConcatDataset __all__ = [ - 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'ConcatDataset', - 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', 'get_dataset' + 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', + 'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale', + 'show_ann', 'get_dataset' ] diff --git a/mmdet/datasets/concat_dataset.py b/mmdet/datasets/concat_dataset.py index 47e073a2bf3e83413c1eb0446c6d009445987b78..605d112f628ba0fc2f411574324b552b59cd9862 100644 --- a/mmdet/datasets/concat_dataset.py +++ b/mmdet/datasets/concat_dataset.py @@ -5,7 +5,7 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset class ConcatDataset(_ConcatDataset): """ - Same as torch.utils.data.dataset.ConcatDataset, but + Same as torch.utils.data.dataset.ConcatDataset, but concat the group flag for image aspect ratio. """ def __init__(self, datasets): @@ -13,7 +13,7 @@ class ConcatDataset(_ConcatDataset): flag: Images with aspect ratio greater than 1 will be set as group 1, otherwise group 0. """ - super(ConcatDataset, self).__init__(datasets) + super(ConcatDataset, self).__init__(datasets) if hasattr(datasets[0], 'flag'): flags = [] for i in range(0, len(datasets)): @@ -27,4 +27,3 @@ class ConcatDataset(_ConcatDataset): else: sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] return dataset_idx, sample_idx - diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py index d13ae892d485946f122cc33516cfbf78367f7582..92b4e5d57b2ce28664298a51cfc85f450b5c63b3 100644 --- a/mmdet/datasets/utils.py +++ b/mmdet/datasets/utils.py @@ -9,6 +9,7 @@ import numpy as np from .concat_dataset import ConcatDataset from .. import datasets + def to_tensor(data): """Convert objects of various python types to :obj:`torch.Tensor`. @@ -72,7 +73,8 @@ def show_ann(coco, img, ann_info): def get_dataset(data_cfg): - if isinstance(data_cfg['ann_file'], list) or isinstance(data_cfg['ann_file'], tuple): + if isinstance(data_cfg['ann_file'], list) or \ + isinstance(data_cfg['ann_file'], tuple): ann_files = data_cfg['ann_file'] dsets = [] for ann_file in ann_files: @@ -81,9 +83,9 @@ def get_dataset(data_cfg): dset = obj_from_dict(data_info, datasets) dsets.append(dset) if len(dsets) > 1: - dset = ConcatDataset(dsets) + dset = ConcatDataset(dsets) else: dset = dsets[0] else: dset = obj_from_dict(data_cfg, datasets) - return dset \ No newline at end of file + return dset diff --git a/tools/train.py b/tools/train.py index 006fc1197f5ea4e1df866c05d8323070b08d51df..49c46f05f88393f24f14a5f3a2f2fc4f6b58bbaf 100644 --- a/tools/train.py +++ b/tools/train.py @@ -2,7 +2,6 @@ from __future__ import division import argparse from mmcv import Config -from mmcv.runner import obj_from_dict from mmdet import datasets, __version__ from mmdet.apis import (train_detector, init_dist, get_root_logger,