Unverified commit b8bcda67 authored by Kai Chen, committed by GitHub
Use registry to manage datasets (#924)

* use registry to manage datasets

* bug fix for concat dataset

* update documentation to fit the new api
parent 2fb15316
Showing with 210 additions and 141 deletions
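For reviewers, a minimal before/after sketch of the API change this commit makes (the config dict below is illustrative, not taken from the diff; only the keys that matter for dispatch are shown):

```python
# Before: datasets were assembled by a hard-coded factory function.
#   from mmdet.datasets import get_dataset
#   dataset = get_dataset(cfg.data.train)

# After: dataset classes register themselves in a DATASETS registry and a
# generic builder instantiates whatever class the config names.
from mmdet.datasets import build_dataset

cfg = dict(
    type='CocoDataset',  # resolved via the DATASETS registry
    ann_file='data/coco/annotations/instances_train2017.json',
    img_prefix='data/coco/train2017/',
    # ...remaining CustomDataset kwargs (img_scale, img_norm_cfg, ...) omitted
)
dataset = build_dataset(cfg)
```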
docs/GETTING_STARTED.md
@@ -150,8 +150,10 @@ In `mmdet/datasets/my_dataset.py`:
 ```python
 from .coco import CocoDataset
+from .registry import DATASETS


+@DATASETS.register_module
 class MyDataset(CocoDataset):

     CLASSES = ('a', 'b', 'c', 'd', 'e')
@@ -228,7 +230,7 @@ import torch.nn as nn
 from ..registry import BACKBONES

-@BACKBONES.register
+@BACKBONES.register_module
 class MobileNet(nn.Module):

     def __init__(self, arg1, arg2):
     ...
mmdet/datasets/__init__.py
@@ -4,14 +4,15 @@ from .coco import CocoDataset
 from .voc import VOCDataset
 from .wider_face import WIDERFaceDataset
 from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
-from .utils import to_tensor, random_scale, show_ann, get_dataset
-from .concat_dataset import ConcatDataset
-from .repeat_dataset import RepeatDataset
+from .utils import to_tensor, random_scale, show_ann
+from .dataset_wrappers import ConcatDataset, RepeatDataset
 from .extra_aug import ExtraAugmentation
+from .registry import DATASETS
+from .builder import build_dataset

 __all__ = [
     'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler',
     'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale',
-    'show_ann', 'get_dataset', 'ConcatDataset', 'RepeatDataset',
-    'ExtraAugmentation', 'WIDERFaceDataset'
+    'show_ann', 'ConcatDataset', 'RepeatDataset', 'ExtraAugmentation',
+    'WIDERFaceDataset', 'DATASETS', 'build_dataset'
 ]
mmdet/datasets/builder.py (new file)
import copy

from mmdet.utils import build_from_cfg
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .registry import DATASETS


def _concat_dataset(cfg):
    ann_files = cfg['ann_file']
    img_prefixes = cfg.get('img_prefix', None)
    seg_prefixes = cfg.get('seg_prefix', None)
    proposal_files = cfg.get('proposal_file', None)

    datasets = []
    num_dset = len(ann_files)
    for i in range(num_dset):
        data_cfg = copy.deepcopy(cfg)
        data_cfg['ann_file'] = ann_files[i]
        if isinstance(img_prefixes, (list, tuple)):
            data_cfg['img_prefix'] = img_prefixes[i]
        if isinstance(seg_prefixes, (list, tuple)):
            data_cfg['seg_prefix'] = seg_prefixes[i]
        if isinstance(proposal_files, (list, tuple)):
            data_cfg['proposal_file'] = proposal_files[i]
        datasets.append(build_dataset(data_cfg))

    return ConcatDataset(datasets)


def build_dataset(cfg):
    if cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(build_dataset(cfg['dataset']), cfg['times'])
    elif isinstance(cfg['ann_file'], (list, tuple)):
        dataset = _concat_dataset(cfg)
    else:
        dataset = build_from_cfg(cfg, DATASETS)
    return dataset
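To make the three dispatch paths of `build_dataset` concrete, a hedged sketch (the VOC paths are placeholders, and the transform-related kwargs that `CustomDataset` also requires are omitted for brevity):

```python
from mmdet.datasets import build_dataset

# 1) Plain dataset: 'type' is looked up in the DATASETS registry.
voc07 = dict(
    type='VOCDataset',
    ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
    img_prefix='data/VOCdevkit/VOC2007/')

# 2) Concatenation: a list of ann_files routes through _concat_dataset,
#    which builds one dataset per file and wraps them in ConcatDataset.
voc0712 = dict(voc07)
voc0712['ann_file'] = [
    'data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
    'data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt',
]
voc0712['img_prefix'] = ['data/VOCdevkit/VOC2007/', 'data/VOCdevkit/VOC2012/']

# 3) Repetition: a RepeatDataset config wraps another dataset config.
repeated = dict(type='RepeatDataset', times=3, dataset=voc07)

dataset = build_dataset(repeated)  # len(dataset) == 3 * len(voc07 dataset)
```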
mmdet/datasets/coco.py
@@ -2,8 +2,10 @@ import numpy as np
 from pycocotools.coco import COCO

 from .custom import CustomDataset
+from .registry import DATASETS


+@DATASETS.register_module
 class CocoDataset(CustomDataset):

     CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
     ...
mmdet/datasets/custom.py
@@ -5,12 +5,14 @@ import numpy as np
 from mmcv.parallel import DataContainer as DC
 from torch.utils.data import Dataset

+from .registry import DATASETS
 from .transforms import (ImageTransform, BboxTransform, MaskTransform,
                          SegMapTransform, Numpy2Tensor)
 from .utils import to_tensor, random_scale
 from .extra_aug import ExtraAugmentation


+@DATASETS.register_module
 class CustomDataset(Dataset):
     """Custom dataset for detection.
     ...
mmdet/datasets/dataset_wrappers.py (renamed from concat_dataset.py)
 import numpy as np
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

+from .registry import DATASETS


+@DATASETS.register_module
 class ConcatDataset(_ConcatDataset):
     """A wrapper of concatenated dataset.
@@ -20,3 +23,33 @@ class ConcatDataset(_ConcatDataset):
         for i in range(0, len(datasets)):
             flags.append(datasets[i].flag)
         self.flag = np.concatenate(flags)
+
+
+@DATASETS.register_module
+class RepeatDataset(object):
+    """A wrapper of repeated dataset.
+
+    The length of repeated dataset will be `times` larger than the original
+    dataset. This is useful when the data loading time is long but the dataset
+    is small. Using RepeatDataset can reduce the data loading time between
+    epochs.
+
+    Args:
+        dataset (:obj:`Dataset`): The dataset to be repeated.
+        times (int): Repeat times.
+    """
+
+    def __init__(self, dataset, times):
+        self.dataset = dataset
+        self.times = times
+        self.CLASSES = dataset.CLASSES
+        if hasattr(self.dataset, 'flag'):
+            self.flag = np.tile(self.dataset.flag, times)
+
+        self._ori_len = len(self.dataset)
+
+    def __getitem__(self, idx):
+        return self.dataset[idx % self._ori_len]
+
+    def __len__(self):
+        return self.times * self._ori_len
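A small self-contained check of the wrapper semantics above, using a stub in place of a real dataset:

```python
import numpy as np

from mmdet.datasets import RepeatDataset


class StubDataset(object):
    """Stand-in exposing the attributes RepeatDataset reads."""
    CLASSES = ('face', )

    def __init__(self):
        self.flag = np.array([0, 1, 0], dtype=np.uint8)  # aspect-ratio groups

    def __getitem__(self, idx):
        return idx

    def __len__(self):
        return 3


repeated = RepeatDataset(StubDataset(), times=2)
assert len(repeated) == 6
assert repeated[4] == 1            # indices wrap: 4 % 3 == 1
assert len(repeated.flag) == 6     # group flags are tiled to match
```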
mmdet/datasets/registry.py (new file)
from mmdet.utils import Registry

DATASETS = Registry('dataset')
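A quick sketch of what this two-line module enables; `TinyDataset` is a hypothetical class for illustration:

```python
from mmdet.datasets.registry import DATASETS


@DATASETS.register_module  # returns the class, so it composes as a decorator
class TinyDataset(object):
    CLASSES = ('thing', )


assert DATASETS.get('TinyDataset') is TinyDataset
print(DATASETS)  # e.g. Registry(name=dataset, items=[..., 'TinyDataset'])
```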
mmdet/datasets/repeat_dataset.py (deleted; RepeatDataset now lives in dataset_wrappers.py)
-import numpy as np
-
-
-class RepeatDataset(object):
-
-    def __init__(self, dataset, times):
-        self.dataset = dataset
-        self.times = times
-        self.CLASSES = dataset.CLASSES
-        if hasattr(self.dataset, 'flag'):
-            self.flag = np.tile(self.dataset.flag, times)
-
-        self._ori_len = len(self.dataset)
-
-    def __getitem__(self, idx):
-        return self.dataset[idx % self._ori_len]
-
-    def __len__(self):
-        return self.times * self._ori_len
mmdet/datasets/utils.py
-import copy
 from collections import Sequence

-import mmcv
-from mmcv.runner import obj_from_dict
-import torch
 import matplotlib.pyplot as plt
+import mmcv
 import numpy as np
-from .concat_dataset import ConcatDataset
-from .repeat_dataset import RepeatDataset
-from .. import datasets
+import torch


 def to_tensor(data):
@@ -72,45 +66,3 @@ def show_ann(coco, img, ann_info):
     plt.axis('off')
     coco.showAnns(ann_info)
     plt.show()
-
-
-def get_dataset(data_cfg):
-    if data_cfg['type'] == 'RepeatDataset':
-        return RepeatDataset(
-            get_dataset(data_cfg['dataset']), data_cfg['times'])
-
-    if isinstance(data_cfg['ann_file'], (list, tuple)):
-        ann_files = data_cfg['ann_file']
-        num_dset = len(ann_files)
-    else:
-        ann_files = [data_cfg['ann_file']]
-        num_dset = 1
-
-    if 'proposal_file' in data_cfg.keys():
-        if isinstance(data_cfg['proposal_file'], (list, tuple)):
-            proposal_files = data_cfg['proposal_file']
-        else:
-            proposal_files = [data_cfg['proposal_file']]
-    else:
-        proposal_files = [None] * num_dset
-    assert len(proposal_files) == num_dset
-
-    if isinstance(data_cfg['img_prefix'], (list, tuple)):
-        img_prefixes = data_cfg['img_prefix']
-    else:
-        img_prefixes = [data_cfg['img_prefix']] * num_dset
-    assert len(img_prefixes) == num_dset
-
-    dsets = []
-    for i in range(num_dset):
-        data_info = copy.deepcopy(data_cfg)
-        data_info['ann_file'] = ann_files[i]
-        data_info['proposal_file'] = proposal_files[i]
-        data_info['img_prefix'] = img_prefixes[i]
-        dset = obj_from_dict(data_info, datasets)
-        dsets.append(dset)
-
-    if len(dsets) > 1:
-        dset = ConcatDataset(dsets)
-    else:
-        dset = dsets[0]
-
-    return dset
mmdet/datasets/voc.py
+from .registry import DATASETS
 from .xml_style import XMLDataset


+@DATASETS.register_module
 class VOCDataset(XMLDataset):

     CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
     ...
mmdet/datasets/wider_face.py
@@ -3,16 +3,18 @@ import xml.etree.ElementTree as ET
 import mmcv

+from .registry import DATASETS
 from .xml_style import XMLDataset


+@DATASETS.register_module
 class WIDERFaceDataset(XMLDataset):
     """
     Reader for the WIDER Face dataset in PASCAL VOC format.

     Conversion scripts can be found in
     https://github.com/sovrasov/wider-face-pascal-voc-annotations
     """

-    CLASSES = ('face',)
+    CLASSES = ('face', )

     def __init__(self, **kwargs):
         super(WIDERFaceDataset, self).__init__(**kwargs)
@@ -31,7 +33,10 @@ class WIDERFaceDataset(XMLDataset):
             height = int(size.find('height').text)
             folder = root.find('folder').text
             img_infos.append(
-                dict(id=img_id, filename=osp.join(folder, filename),
-                     width=width, height=height))
+                dict(
+                    id=img_id,
+                    filename=osp.join(folder, filename),
+                    width=width,
+                    height=height))

         return img_infos
mmdet/datasets/xml_style.py
@@ -5,8 +5,10 @@ import mmcv
 import numpy as np

 from .custom import CustomDataset
+from .registry import DATASETS


+@DATASETS.register_module
 class XMLDataset(CustomDataset):

     def __init__(self, min_size=None, **kwargs):
     ...
mmdet/models/builder.py
-import mmcv
 from torch import nn

+from mmdet.utils import build_from_cfg
 from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
                        LOSSES, DETECTORS)


-def _build_module(cfg, registry, default_args):
-    assert isinstance(cfg, dict) and 'type' in cfg
-    assert isinstance(default_args, dict) or default_args is None
-    args = cfg.copy()
-    obj_type = args.pop('type')
-    if mmcv.is_str(obj_type):
-        if obj_type not in registry.module_dict:
-            raise KeyError('{} is not in the {} registry'.format(
-                obj_type, registry.name))
-        obj_type = registry.module_dict[obj_type]
-    elif not isinstance(obj_type, type):
-        raise TypeError('type must be a str or valid type, but got {}'.format(
-            type(obj_type)))
-    if default_args is not None:
-        for name, value in default_args.items():
-            args.setdefault(name, value)
-    return obj_type(**args)
-
-
 def build(cfg, registry, default_args=None):
     if isinstance(cfg, list):
-        modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
+        modules = [
+            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+        ]
         return nn.Sequential(*modules)
     else:
-        return _build_module(cfg, registry, default_args)
+        return build_from_cfg(cfg, registry, default_args)


 def build_backbone(cfg):
 ...
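To illustrate the `build` behavior kept above: a single config dict builds one module, while a list of configs yields an `nn.Sequential`. A hedged sketch using a throwaway registry rather than mmdet's real ones:

```python
import torch.nn as nn

from mmdet.models.builder import build
from mmdet.utils import Registry

TOY = Registry('toy')


@TOY.register_module
class ToyLinear(nn.Module):

    def __init__(self, in_features, out_features):
        super(ToyLinear, self).__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


# A single dict builds one module; a list of dicts builds nn.Sequential.
one = build(dict(type='ToyLinear', in_features=8, out_features=4), TOY)
seq = build([
    dict(type='ToyLinear', in_features=8, out_features=8),
    dict(type='ToyLinear', in_features=8, out_features=4),
], TOY)
assert isinstance(seq, nn.Sequential) and len(seq) == 2
```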
mmdet/models/registry.py
-import torch.nn as nn
-
-
-class Registry(object):
-
-    def __init__(self, name):
-        self._name = name
-        self._module_dict = dict()
-
-    @property
-    def name(self):
-        return self._name
-
-    @property
-    def module_dict(self):
-        return self._module_dict
-
-    def _register_module(self, module_class):
-        """Register a module.
-
-        Args:
-            module (:obj:`nn.Module`): Module to be registered.
-        """
-        if not issubclass(module_class, nn.Module):
-            raise TypeError(
-                'module must be a child of nn.Module, but got {}'.format(
-                    module_class))
-        module_name = module_class.__name__
-        if module_name in self._module_dict:
-            raise KeyError('{} is already registered in {}'.format(
-                module_name, self.name))
-        self._module_dict[module_name] = module_class
-
-    def register_module(self, cls):
-        self._register_module(cls)
-        return cls
-
+from mmdet.utils import Registry

 BACKBONES = Registry('backbone')
 NECKS = Registry('neck')
 ...
mmdet/utils/__init__.py (new file)
from .registry import Registry, build_from_cfg

__all__ = ['Registry', 'build_from_cfg']
mmdet/utils/registry.py (new file)
import inspect

import mmcv


class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class):
        """Register a module.

        Args:
            module_class (type): Class to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        # resolve the string in the registry; keep the original name around
        # so the error message stays informative when the lookup fails
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
        obj_type = obj_cls
    elif not inspect.isclass(obj_type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)
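And a self-contained check of `build_from_cfg` semantics (the registry and class below are toys for illustration, not part of the diff):

```python
from mmdet.utils import Registry, build_from_cfg

OPTIMIZERS = Registry('optimizer')  # hypothetical registry for illustration


@OPTIMIZERS.register_module
class ToySGD(object):

    def __init__(self, lr, momentum=0.0):
        self.lr = lr
        self.momentum = momentum


# 'type' is popped from a copy of the cfg and resolved in the registry;
# default_args only fill keys the cfg does not already set.
opt = build_from_cfg(
    dict(type='ToySGD', lr=0.1), OPTIMIZERS, default_args=dict(momentum=0.9))
assert opt.lr == 0.1 and opt.momentum == 0.9

# Passing the class object itself as 'type' also works.
opt2 = build_from_cfg(dict(type=ToySGD, lr=0.01), OPTIMIZERS)
assert opt2.momentum == 0.0
```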
tools/test.py
@@ -12,7 +12,7 @@ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmdet.apis import init_dist
 from mmdet.core import results2json, coco_eval, wrap_fp16_model
-from mmdet.datasets import build_dataloader, get_dataset
+from mmdet.datasets import build_dataloader, build_dataset
 from mmdet.models import build_detector
@@ -147,7 +147,7 @@ def main():
     # build the dataloader
     # TODO: support multiple images per gpu (only minor changes are needed)
-    dataset = get_dataset(cfg.data.test)
+    dataset = build_dataset(cfg.data.test)
     data_loader = build_dataloader(
         dataset,
         imgs_per_gpu=1,
         ...
tools/train.py
@@ -5,7 +5,7 @@ import os
 from mmcv import Config

 from mmdet import __version__
-from mmdet.datasets import get_dataset
+from mmdet.datasets import build_dataset
 from mmdet.apis import (train_detector, init_dist, get_root_logger,
                         set_random_seed)
 from mmdet.models import build_detector
@@ -75,7 +75,7 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    train_dataset = get_dataset(cfg.data.train)
+    train_dataset = build_dataset(cfg.data.train)
     if cfg.checkpoint_config is not None:
         # save mmdet version, config file content and class names in
         # checkpoints as meta data
         ...