diff --git a/README.md b/README.md
index 4eb1f1bf243071f99cf583dc8aee8ce1be2fbbb3..47df2d4c675b5fc3eea2e1e093dfa64e81512e66 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,10 @@ This project is released under the [Apache 2.0 license](LICENSE).

 ## Updates

+v0.5.2 (21/10/2018)
+- Add support for custom datasets.
+- Add a script to convert PASCAL VOC annotations to the expected format.
+
 v0.5.1 (20/10/2018)
 - Add BBoxAssigner and BBoxSampler, the `train_cfg` field in config files are restructured.
 - `ConvFCRoIHead` / `SharedFCRoIHead` are renamed to `ConvFCBBoxHead` / `SharedFCBBoxHead` for consistency.
@@ -209,6 +213,48 @@ Expected results in WORK_DIR:

 > 1. We recommend using distributed training with NCCL2 even on a single machine, which is faster. Non-distributed training is for debugging or other purposes.
 > 2. The default learning rate is for 8 GPUs. If you use less or more than 8 GPUs, you need to set the learning rate proportional to the GPU num. E.g., modify lr to 0.01 for 4 GPUs or 0.04 for 16 GPUs.

+### Train on custom datasets
+
+We define a simple annotation format.
+
+The annotations of a dataset are stored as a list of dicts, where each dict corresponds to one image.
+Three fields are required for testing: `filename` (relative path), `width` and `height`,
+and an additional field `ann` is required for training. `ann` is also a dict containing at least 2 fields:
+`bboxes` and `labels`, both of which are numpy arrays. Some datasets provide
+annotations such as crowd/difficult/ignored bboxes; we use `bboxes_ignore` and `labels_ignore`
+to cover them.
+
+Here is an example.
+```
+[
+    {
+        'filename': 'a.jpg',
+        'width': 1280,
+        'height': 720,
+        'ann': {
+            'bboxes': <np.ndarray> (n, 4),
+            'labels': <np.ndarray> (n, ),
+            'bboxes_ignore': <np.ndarray> (k, 4),
+            'labels_ignore': <np.ndarray> (k, ) (optional field)
+        }
+    },
+    ...
+]
+```
+
+There are two ways to work with custom datasets.
+
+- online conversion
+
+  You can write a new Dataset class that inherits from `CustomDataset`, and override the two methods
+  `load_annotations(self, ann_file)` and `get_ann_info(self, idx)`, like [CocoDataset](mmdet/datasets/coco.py) (see the sketch after this patch hunk).
+
+- offline conversion
+
+  You can convert the annotations to the expected format above and save them to
+  a pickle file, like [pascal_voc.py](tools/convert_datasets/pascal_voc.py).
+  Then you can simply use `CustomDataset`.
+
 ## Technical details

 Some implementation details and project structures are described in the [technical details](TECHNICAL_DETAILS.md).
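For concreteness, here is a minimal sketch of the "online conversion" route described in the README hunk above: a dataset class that inherits from `CustomDataset` and overrides the two required methods. The class name `MyDataset` and the JSON annotation layout are hypothetical illustrations of the interface this patch introduces, not part of the patch itself; it assumes `with_mask=False`, since no masks are provided.

```python
import json

import numpy as np

from mmdet.datasets import CustomDataset


class MyDataset(CustomDataset):
    """Hypothetical dataset whose annotations sit in one JSON file.

    The JSON mirrors the format above, but stores plain lists, which
    `get_ann_info` converts to the numpy arrays the pipeline expects.
    """

    def load_annotations(self, ann_file):
        # one dict per image: 'filename', 'width', 'height' (+ 'ann')
        with open(ann_file) as f:
            return json.load(f)

    def get_ann_info(self, idx):
        raw = self.img_infos[idx]['ann']
        return dict(
            bboxes=np.array(raw['bboxes'], dtype=np.float32).reshape(-1, 4),
            labels=np.array(raw['labels'], dtype=np.int64),
            bboxes_ignore=np.array(raw.get('bboxes_ignore', []),
                                   dtype=np.float32).reshape(-1, 4))
```

Everything else (scale sampling, flipping, aspect-ratio grouping, tensor packing) is inherited from `CustomDataset` unchanged, which is the point of the two-method interface.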
diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 425ea72535a144544f44ebe8b5d63dd31336a54c..75e07097756bd014bbef17294b6803aa83621fd1 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -1,8 +1,9 @@ +from .custom import CustomDataset from .coco import CocoDataset from .loader import GroupSampler, DistributedGroupSampler, build_dataloader from .utils import to_tensor, random_scale, show_ann __all__ = [ - 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', + 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann' ] diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py index fbb14aa285a6fe33be6edaa395134c322f9daef4..4058db90be6b50336af66fa14270d7dd0d16f882 100644 --- a/mmdet/datasets/coco.py +++ b/mmdet/datasets/coco.py @@ -1,113 +1,42 @@ -import os.path as osp - -import mmcv import numpy as np -from mmcv.parallel import DataContainer as DC from pycocotools.coco import COCO -from torch.utils.data import Dataset -from .transforms import (ImageTransform, BboxTransform, MaskTransform, - Numpy2Tensor) -from .utils import to_tensor, show_ann, random_scale +from .custom import CustomDataset -class CocoDataset(Dataset): +class CocoDataset(CustomDataset): - def __init__(self, - ann_file, - img_prefix, - img_scale, - img_norm_cfg, - size_divisor=None, - proposal_file=None, - num_max_proposals=1000, - flip_ratio=0, - with_mask=True, - with_crowd=True, - with_label=True, - test_mode=False, - debug=False): - # path of the data file + def load_annotations(self, ann_file): self.coco = COCO(ann_file) - # filter images with no annotation during training - if not test_mode: - self.img_ids, self.img_infos = self._filter_imgs() - else: - self.img_ids = self.coco.getImgIds() - self.img_infos = [ - self.coco.loadImgs(idx)[0] for idx in self.img_ids - ] - assert len(self.img_ids) == len(self.img_infos) - # get the mapping from original category ids to labels self.cat_ids = self.coco.getCatIds() self.cat2label = { cat_id: i + 1 for i, cat_id in enumerate(self.cat_ids) } - # prefix of images path - self.img_prefix = img_prefix - # (long_edge, short_edge) or [(long1, short1), (long2, short2), ...] 
- self.img_scales = img_scale if isinstance(img_scale, - list) else [img_scale] - assert mmcv.is_list_of(self.img_scales, tuple) - # color channel order and normalize configs - self.img_norm_cfg = img_norm_cfg - # proposals - # TODO: revise _filter_imgs to be more flexible - if proposal_file is not None: - self.proposals = mmcv.load(proposal_file) - ori_ids = self.coco.getImgIds() - sorted_idx = [ori_ids.index(id) for id in self.img_ids] - self.proposals = [self.proposals[idx] for idx in sorted_idx] - else: - self.proposals = None - self.num_max_proposals = num_max_proposals - # flip ratio - self.flip_ratio = flip_ratio - assert flip_ratio >= 0 and flip_ratio <= 1 - # padding border to ensure the image size can be divided by - # size_divisor (used for FPN) - self.size_divisor = size_divisor - # with crowd or not, False when using RetinaNet - self.with_crowd = with_crowd - # with mask or not - self.with_mask = with_mask - # with label is False for RPN - self.with_label = with_label - # in test mode or not - self.test_mode = test_mode - # debug mode or not - self.debug = debug - - # set group flag for the sampler - self._set_group_flag() - # transforms - self.img_transform = ImageTransform( - size_divisor=self.size_divisor, **self.img_norm_cfg) - self.bbox_transform = BboxTransform() - self.mask_transform = MaskTransform() - self.numpy2tensor = Numpy2Tensor() - - def __len__(self): - return len(self.img_ids) + self.img_ids = self.coco.getImgIds() + img_infos = [] + for i in self.img_ids: + info = self.coco.loadImgs([i])[0] + info['filename'] = info['file_name'] + img_infos.append(info) + return img_infos + + def get_ann_info(self, idx): + img_id = self.img_infos[idx]['id'] + ann_ids = self.coco.getAnnIds(imgIds=[img_id]) + ann_info = self.coco.loadAnns(ann_ids) + return self._parse_ann_info(ann_info) def _filter_imgs(self, min_size=32): """Filter images too small or without ground truths.""" - img_ids = list(set([_['image_id'] for _ in self.coco.anns.values()])) - valid_ids = [] - img_infos = [] - for i in img_ids: - info = self.coco.loadImgs(i)[0] - if min(info['width'], info['height']) >= min_size: - valid_ids.append(i) - img_infos.append(info) - return valid_ids, img_infos - - def _load_ann_info(self, idx): - img_id = self.img_ids[idx] - ann_ids = self.coco.getAnnIds(imgIds=img_id) - ann_info = self.coco.loadAnns(ann_ids) - return ann_info + valid_inds = [] + ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) + for i, img_info in enumerate(self.img_infos): + if self.img_ids[i] not in ids_with_ann: + continue + if min(img_info['width'], img_info['height']) >= min_size: + valid_inds.append(i) + return valid_inds def _parse_ann_info(self, ann_info, with_mask=True): """Parse bbox and mask annotation. @@ -172,158 +101,3 @@ class CocoDataset(Dataset): ann['mask_polys'] = gt_mask_polys ann['poly_lens'] = gt_poly_lens return ann - - def _set_group_flag(self): - """Set flag according to image aspect ratio. - - Images with aspect ratio greater than 1 will be set as group 1, - otherwise group 0. 
- """ - self.flag = np.zeros(len(self.img_ids), dtype=np.uint8) - for i in range(len(self.img_ids)): - img_info = self.img_infos[i] - if img_info['width'] / img_info['height'] > 1: - self.flag[i] = 1 - - def _rand_another(self, idx): - pool = np.where(self.flag == self.flag[idx])[0] - return np.random.choice(pool) - - def __getitem__(self, idx): - if self.test_mode: - return self.prepare_test_img(idx) - while True: - img_info = self.img_infos[idx] - ann_info = self._load_ann_info(idx) - - # load image - img = mmcv.imread(osp.join(self.img_prefix, img_info['file_name'])) - if self.debug: - show_ann(self.coco, img, ann_info) - - # load proposals if necessary - if self.proposals is not None: - proposals = self.proposals[idx][:self.num_max_proposals] - # TODO: Handle empty proposals properly. Currently images with - # no proposals are just ignored, but they can be used for - # training in concept. - if len(proposals) == 0: - idx = self._rand_another(idx) - continue - if not (proposals.shape[1] == 4 or proposals.shape[1] == 5): - raise AssertionError( - 'proposals should have shapes (n, 4) or (n, 5), ' - 'but found {}'.format(proposals.shape)) - if proposals.shape[1] == 5: - scores = proposals[:, 4, None] - proposals = proposals[:, :4] - else: - scores = None - - ann = self._parse_ann_info(ann_info, self.with_mask) - gt_bboxes = ann['bboxes'] - gt_labels = ann['labels'] - gt_bboxes_ignore = ann['bboxes_ignore'] - # skip the image if there is no valid gt bbox - if len(gt_bboxes) == 0: - idx = self._rand_another(idx) - continue - - # apply transforms - flip = True if np.random.rand() < self.flip_ratio else False - img_scale = random_scale(self.img_scales) # sample a scale - img, img_shape, pad_shape, scale_factor = self.img_transform( - img, img_scale, flip) - if self.proposals is not None: - proposals = self.bbox_transform(proposals, img_shape, - scale_factor, flip) - proposals = np.hstack( - [proposals, scores]) if scores is not None else proposals - gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor, - flip) - gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape, - scale_factor, flip) - - if self.with_mask: - gt_masks = self.mask_transform(ann['masks'], pad_shape, - scale_factor, flip) - - ori_shape = (img_info['height'], img_info['width'], 3) - img_meta = dict( - ori_shape=ori_shape, - img_shape=img_shape, - pad_shape=pad_shape, - scale_factor=scale_factor, - flip=flip) - - data = dict( - img=DC(to_tensor(img), stack=True), - img_meta=DC(img_meta, cpu_only=True), - gt_bboxes=DC(to_tensor(gt_bboxes))) - if self.proposals is not None: - data['proposals'] = DC(to_tensor(proposals)) - if self.with_label: - data['gt_labels'] = DC(to_tensor(gt_labels)) - if self.with_crowd: - data['gt_bboxes_ignore'] = DC(to_tensor(gt_bboxes_ignore)) - if self.with_mask: - data['gt_masks'] = DC(gt_masks, cpu_only=True) - return data - - def prepare_test_img(self, idx): - """Prepare an image for testing (multi-scale and flipping)""" - img_info = self.img_infos[idx] - img = mmcv.imread(osp.join(self.img_prefix, img_info['file_name'])) - if self.proposals is not None: - proposal = self.proposals[idx][:self.num_max_proposals] - if not (proposal.shape[1] == 4 or proposal.shape[1] == 5): - raise AssertionError( - 'proposals should have shapes (n, 4) or (n, 5), ' - 'but found {}'.format(proposal.shape)) - else: - proposal = None - - def prepare_single(img, scale, flip, proposal=None): - _img, img_shape, pad_shape, scale_factor = self.img_transform( - img, scale, flip) - _img = to_tensor(_img) 
-            _img_meta = dict(
-                ori_shape=(img_info['height'], img_info['width'], 3),
-                img_shape=img_shape,
-                pad_shape=pad_shape,
-                scale_factor=scale_factor,
-                flip=flip)
-            if proposal is not None:
-                if proposal.shape[1] == 5:
-                    score = proposal[:, 4, None]
-                    proposal = proposal[:, :4]
-                else:
-                    score = None
-                _proposal = self.bbox_transform(proposal, img_shape,
-                                                scale_factor, flip)
-                _proposal = np.hstack(
-                    [_proposal, score]) if score is not None else _proposal
-                _proposal = to_tensor(_proposal)
-            else:
-                _proposal = None
-            return _img, _img_meta, _proposal
-
-        imgs = []
-        img_metas = []
-        proposals = []
-        for scale in self.img_scales:
-            _img, _img_meta, _proposal = prepare_single(
-                img, scale, False, proposal)
-            imgs.append(_img)
-            img_metas.append(DC(_img_meta, cpu_only=True))
-            proposals.append(_proposal)
-            if self.flip_ratio > 0:
-                _img, _img_meta, _proposal = prepare_single(
-                    img, scale, True, proposal)
-                imgs.append(_img)
-                img_metas.append(DC(_img_meta, cpu_only=True))
-                proposals.append(_proposal)
-        data = dict(img=imgs, img_meta=img_metas)
-        if self.proposals is not None:
-            data['proposals'] = proposals
-        return data
diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..3640a83db9ac10504ea6076ceff88bf1cdfa3ec8
--- /dev/null
+++ b/mmdet/datasets/custom.py
@@ -0,0 +1,274 @@
+import os.path as osp
+
+import mmcv
+import numpy as np
+from mmcv.parallel import DataContainer as DC
+from torch.utils.data import Dataset
+
+from .transforms import (ImageTransform, BboxTransform, MaskTransform,
+                         Numpy2Tensor)
+from .utils import to_tensor, random_scale
+
+
+class CustomDataset(Dataset):
+    """Custom dataset for detection.
+
+    Annotation format:
+    [
+        {
+            'filename': 'a.jpg',
+            'width': 1280,
+            'height': 720,
+            'ann': {
+                'bboxes': <np.ndarray> (n, 4),
+                'labels': <np.ndarray> (n, ),
+                'bboxes_ignore': <np.ndarray> (k, 4),
+                'labels_ignore': <np.ndarray> (k, ) (optional field)
+            }
+        },
+        ...
+    ]
+
+    The `ann` field is optional for testing.
+    """
+
+    def __init__(self,
+                 ann_file,
+                 img_prefix,
+                 img_scale,
+                 img_norm_cfg,
+                 size_divisor=None,
+                 proposal_file=None,
+                 num_max_proposals=1000,
+                 flip_ratio=0,
+                 with_mask=True,
+                 with_crowd=True,
+                 with_label=True,
+                 test_mode=False):
+        # load annotations (and proposals)
+        self.img_infos = self.load_annotations(ann_file)
+        if proposal_file is not None:
+            self.proposals = self.load_proposals(proposal_file)
+        else:
+            self.proposals = None
+        # filter images with no annotation during training
+        if not test_mode:
+            valid_inds = self._filter_imgs()
+            self.img_infos = [self.img_infos[i] for i in valid_inds]
+            if self.proposals is not None:
+                self.proposals = [self.proposals[i] for i in valid_inds]
+
+        # prefix of images path
+        self.img_prefix = img_prefix
+        # (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
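+        # a list of scales enables multi-scale training (one is sampled per image)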
+        self.img_scales = img_scale if isinstance(img_scale,
+                                                  list) else [img_scale]
+        assert mmcv.is_list_of(self.img_scales, tuple)
+        # normalization configs
+        self.img_norm_cfg = img_norm_cfg
+
+        # max proposals per image
+        self.num_max_proposals = num_max_proposals
+        # flip ratio
+        self.flip_ratio = flip_ratio
+        assert flip_ratio >= 0 and flip_ratio <= 1
+        # padding border to ensure the image size can be divided by
+        # size_divisor (used for FPN)
+        self.size_divisor = size_divisor
+
+        # with mask or not; if True, the `ann` dict must also provide masks
+        self.with_mask = with_mask
+        # some datasets provide bbox annotations as ignore/crowd/difficult;
+        # if `with_crowd` is True, this info is returned as well
+        self.with_crowd = with_crowd
+        # `with_label` is False for RPN
+        self.with_label = with_label
+        # in test mode or not
+        self.test_mode = test_mode
+
+        # set group flag for the sampler
+        if not self.test_mode:
+            self._set_group_flag()
+        # transforms
+        self.img_transform = ImageTransform(
+            size_divisor=self.size_divisor, **self.img_norm_cfg)
+        self.bbox_transform = BboxTransform()
+        self.mask_transform = MaskTransform()
+        self.numpy2tensor = Numpy2Tensor()
+
+    def __len__(self):
+        return len(self.img_infos)
+
+    def load_annotations(self, ann_file):
+        return mmcv.load(ann_file)
+
+    def load_proposals(self, proposal_file):
+        return mmcv.load(proposal_file)
+
+    def get_ann_info(self, idx):
+        return self.img_infos[idx]['ann']
+
+    def _filter_imgs(self, min_size=32):
+        """Filter images too small."""
+        valid_inds = []
+        for i, img_info in enumerate(self.img_infos):
+            if min(img_info['width'], img_info['height']) >= min_size:
+                valid_inds.append(i)
+        return valid_inds
+
+    def _set_group_flag(self):
+        """Set flag according to image aspect ratio.
+
+        Images with aspect ratio greater than 1 will be set as group 1,
+        otherwise group 0.
+        """
+        self.flag = np.zeros(len(self), dtype=np.uint8)
+        for i in range(len(self)):
+            img_info = self.img_infos[i]
+            if img_info['width'] / img_info['height'] > 1:
+                self.flag[i] = 1
+
+    def _rand_another(self, idx):
+        pool = np.where(self.flag == self.flag[idx])[0]
+        return np.random.choice(pool)
+
+    def __getitem__(self, idx):
+        if self.test_mode:
+            return self.prepare_test_img(idx)
+        while True:
+            data = self.prepare_train_img(idx)
+            if data is None:
+                idx = self._rand_another(idx)
+                continue
+            return data
+
+    def prepare_train_img(self, idx):
+        img_info = self.img_infos[idx]
+        # load image
+        img = mmcv.imread(osp.join(self.img_prefix, img_info['filename']))
+        # load proposals if necessary
+        if self.proposals is not None:
+            proposals = self.proposals[idx][:self.num_max_proposals]
+            # TODO: Handle empty proposals properly. Currently images with
+            # no proposals are simply skipped, but they could in principle
+            # be used for training.
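+            # NOTE: returning None below makes __getitem__ resample via _rand_another()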
+ if len(proposals) == 0: + return None + if not (proposals.shape[1] == 4 or proposals.shape[1] == 5): + raise AssertionError( + 'proposals should have shapes (n, 4) or (n, 5), ' + 'but found {}'.format(proposals.shape)) + if proposals.shape[1] == 5: + scores = proposals[:, 4, None] + proposals = proposals[:, :4] + else: + scores = None + + ann = self.get_ann_info(idx) + gt_bboxes = ann['bboxes'] + gt_labels = ann['labels'] + if self.with_crowd: + gt_bboxes_ignore = ann['bboxes_ignore'] + + # skip the image if there is no valid gt bbox + if len(gt_bboxes) == 0: + return None + + # apply transforms + flip = True if np.random.rand() < self.flip_ratio else False + img_scale = random_scale(self.img_scales) # sample a scale + img, img_shape, pad_shape, scale_factor = self.img_transform( + img, img_scale, flip) + if self.proposals is not None: + proposals = self.bbox_transform(proposals, img_shape, scale_factor, + flip) + proposals = np.hstack( + [proposals, scores]) if scores is not None else proposals + gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor, + flip) + if self.with_crowd: + gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape, + scale_factor, flip) + if self.with_mask: + gt_masks = self.mask_transform(ann['masks'], pad_shape, + scale_factor, flip) + + ori_shape = (img_info['height'], img_info['width'], 3) + img_meta = dict( + ori_shape=ori_shape, + img_shape=img_shape, + pad_shape=pad_shape, + scale_factor=scale_factor, + flip=flip) + + data = dict( + img=DC(to_tensor(img), stack=True), + img_meta=DC(img_meta, cpu_only=True), + gt_bboxes=DC(to_tensor(gt_bboxes))) + if self.proposals is not None: + data['proposals'] = DC(to_tensor(proposals)) + if self.with_label: + data['gt_labels'] = DC(to_tensor(gt_labels)) + if self.with_crowd: + data['gt_bboxes_ignore'] = DC(to_tensor(gt_bboxes_ignore)) + if self.with_mask: + data['gt_masks'] = DC(gt_masks, cpu_only=True) + return data + + def prepare_test_img(self, idx): + """Prepare an image for testing (multi-scale and flipping)""" + img_info = self.img_infos[idx] + img = mmcv.imread(osp.join(self.img_prefix, img_info['filename'])) + if self.proposals is not None: + proposal = self.proposals[idx][:self.num_max_proposals] + if not (proposal.shape[1] == 4 or proposal.shape[1] == 5): + raise AssertionError( + 'proposals should have shapes (n, 4) or (n, 5), ' + 'but found {}'.format(proposal.shape)) + else: + proposal = None + + def prepare_single(img, scale, flip, proposal=None): + _img, img_shape, pad_shape, scale_factor = self.img_transform( + img, scale, flip) + _img = to_tensor(_img) + _img_meta = dict( + ori_shape=(img_info['height'], img_info['width'], 3), + img_shape=img_shape, + pad_shape=pad_shape, + scale_factor=scale_factor, + flip=flip) + if proposal is not None: + if proposal.shape[1] == 5: + score = proposal[:, 4, None] + proposal = proposal[:, :4] + else: + score = None + _proposal = self.bbox_transform(proposal, img_shape, + scale_factor, flip) + _proposal = np.hstack( + [_proposal, score]) if score is not None else _proposal + _proposal = to_tensor(_proposal) + else: + _proposal = None + return _img, _img_meta, _proposal + + imgs = [] + img_metas = [] + proposals = [] + for scale in self.img_scales: + _img, _img_meta, _proposal = prepare_single( + img, scale, False, proposal) + imgs.append(_img) + img_metas.append(DC(_img_meta, cpu_only=True)) + proposals.append(_proposal) + if self.flip_ratio > 0: + _img, _img_meta, _proposal = prepare_single( + img, scale, True, proposal) + imgs.append(_img) + 
img_metas.append(DC(_img_meta, cpu_only=True)) + proposals.append(_proposal) + data = dict(img=imgs, img_meta=img_metas) + if self.proposals is not None: + data['proposals'] = proposals + return data diff --git a/mmdet/datasets/loader/build_loader.py b/mmdet/datasets/loader/build_loader.py index 761d9aea1884c8741a6653d5e9405ff5acc530a9..3e10e2399d5a20c30500e9a28db662eb78211c84 100644 --- a/mmdet/datasets/loader/build_loader.py +++ b/mmdet/datasets/loader/build_loader.py @@ -25,13 +25,13 @@ def build_dataloader(dataset, batch_size = imgs_per_gpu num_workers = workers_per_gpu else: - sampler = GroupSampler(dataset, imgs_per_gpu) + if not kwargs.get('shuffle', True): + sampler = None + else: + sampler = GroupSampler(dataset, imgs_per_gpu) batch_size = num_gpus * imgs_per_gpu num_workers = num_gpus * workers_per_gpu - if not kwargs.get('shuffle', True): - sampler = None - data_loader = DataLoader( dataset, batch_size=batch_size, diff --git a/mmdet/models/bbox_heads/convfc_bbox_head.py b/mmdet/models/bbox_heads/convfc_bbox_head.py index 4de176fb109a0dafe054d999a048f783b68f7cd4..2471b58233ebbf5c08788127a95c66fe252b7181 100644 --- a/mmdet/models/bbox_heads/convfc_bbox_head.py +++ b/mmdet/models/bbox_heads/convfc_bbox_head.py @@ -11,7 +11,7 @@ class ConvFCBBoxHead(BBoxHead): /-> cls convs -> cls fcs -> cls shared convs -> shared fcs \-> reg convs -> reg fcs -> reg - """ + """ # noqa: W605 def __init__(self, num_shared_convs=0, diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py index 064ee0e7d4ea77cddfa7bae5bc329e0422ba8fa2..7330010a021d65fd6a2fe2485b2985e4f96e0f7b 100644 --- a/mmdet/models/detectors/two_stage.py +++ b/mmdet/models/detectors/two_stage.py @@ -65,6 +65,9 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, if self.with_bbox: self.bbox_roi_extractor.init_weights() self.bbox_head.init_weights() + if self.with_mask: + self.mask_roi_extractor.init_weights() + self.mask_head.init_weights() def extract_feat(self, img): x = self.backbone(img) diff --git a/mmdet/models/rpn_heads/rpn_head.py b/mmdet/models/rpn_heads/rpn_head.py index 61e6e199ac0407bd23226701e3117c02ec16171d..ad06e40c8fcee1b2991029d8121b9ffad54cbf32 100644 --- a/mmdet/models/rpn_heads/rpn_head.py +++ b/mmdet/models/rpn_heads/rpn_head.py @@ -30,7 +30,7 @@ class RPNHead(nn.Module): target_stds (Iterable): Std values of regression targets. use_sigmoid_cls (bool): Whether to use sigmoid loss for classification. 
            (softmax by default)
-    """
+    """  # noqa: W605

     def __init__(self,
                  in_channels,
diff --git a/mmdet/ops/nms/nms_wrapper.py b/mmdet/ops/nms/nms_wrapper.py
index 43d5e5c6e5c038467f2084d46d85b97bb2a943f1..3978773b842a8236dedc8b942ad1d3100061813c 100644
--- a/mmdet/ops/nms/nms_wrapper.py
+++ b/mmdet/ops/nms/nms_wrapper.py
@@ -9,7 +9,9 @@ from .cpu_soft_nms import cpu_soft_nms

 def nms(dets, thresh, device_id=None):
     """Dispatch to either CPU or GPU NMS implementations."""
+    tensor_device = None
     if isinstance(dets, torch.Tensor):
+        tensor_device = dets.device
         if dets.is_cuda:
             device_id = dets.get_device()
         dets = dets.detach().cpu().numpy()
@@ -21,8 +23,8 @@ def nms(dets, thresh, device_id=None):
     inds = (gpu_nms(dets, thresh, device_id=device_id)
             if device_id is not None else cpu_nms(dets, thresh))

-    if isinstance(dets, torch.Tensor):
-        return dets.new_tensor(inds, dtype=torch.long)
+    if tensor_device is not None:
+        return torch.Tensor(inds).long().to(tensor_device)
     else:
         return np.array(inds, dtype=np.int)
diff --git a/setup.py b/setup.py
index 1803b7344593d95e37b868c77cd8b7352b1487fc..911519597e6e3b50bc4e5e8aa3a61461d9947daf 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,7 @@ def readme():

 MAJOR = 0
 MINOR = 5
-PATCH = 1
+PATCH = 2
 SUFFIX = ''
 SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX)

diff --git a/tools/convert_datasets/pascal_voc.py b/tools/convert_datasets/pascal_voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fb5cb4b7080f134287494f7f0283bed42b351cb
--- /dev/null
+++ b/tools/convert_datasets/pascal_voc.py
@@ -0,0 +1,140 @@
+import argparse
+import os.path as osp
+import xml.etree.ElementTree as ET
+
+import mmcv
+import numpy as np
+
+from mmdet.core import voc_classes
+
+label_ids = {name: i + 1 for i, name in enumerate(voc_classes())}
+
+
+def parse_xml(args):
+    xml_path, img_path = args
+    tree = ET.parse(xml_path)
+    root = tree.getroot()
+    size = root.find('size')
+    w = int(size.find('width').text)
+    h = int(size.find('height').text)
+    bboxes = []
+    labels = []
+    bboxes_ignore = []
+    labels_ignore = []
+    for obj in root.findall('object'):
+        name = obj.find('name').text
+        label = label_ids[name]
+        difficult = int(obj.find('difficult').text)
+        bnd_box = obj.find('bndbox')
+        bbox = [
+            int(bnd_box.find('xmin').text),
+            int(bnd_box.find('ymin').text),
+            int(bnd_box.find('xmax').text),
+            int(bnd_box.find('ymax').text)
+        ]
+        if difficult:
+            bboxes_ignore.append(bbox)
+            labels_ignore.append(label)
+        else:
+            bboxes.append(bbox)
+            labels.append(label)
+    if not bboxes:
+        bboxes = np.zeros((0, 4))
+        labels = np.zeros((0, ))
+    else:
+        bboxes = np.array(bboxes, ndmin=2) - 1
+        labels = np.array(labels)
+    if not bboxes_ignore:
+        bboxes_ignore = np.zeros((0, 4))
+        labels_ignore = np.zeros((0, ))
+    else:
+        bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
+        labels_ignore = np.array(labels_ignore)
+    annotation = {
+        'filename': img_path,
+        'width': w,
+        'height': h,
+        'ann': {
+            'bboxes': bboxes.astype(np.float32),
+            'labels': labels.astype(np.int64),
+            'bboxes_ignore': bboxes_ignore.astype(np.float32),
+            'labels_ignore': labels_ignore.astype(np.int64)
+        }
+    }
+    return annotation
+
+
+def cvt_annotations(devkit_path, years, split, out_file):
+    if not isinstance(years, list):
+        years = [years]
+    annotations = []
+    for year in years:
+        filelist = osp.join(devkit_path, 'VOC{}/ImageSets/Main/{}.txt'.format(
+            year, split))
+        if not osp.isfile(filelist):
+            print('filelist does not exist: {}, skip voc{} {}'.format(
+                filelist, year, split))
+            return
+        img_names = mmcv.list_from_file(filelist)
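+        # each line of the ImageSets split file is a bare image stem (no extension)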
+        xml_paths = [
+            osp.join(devkit_path, 'VOC{}/Annotations/{}.xml'.format(
+                year, img_name)) for img_name in img_names
+        ]
+        img_paths = [
+            'VOC{}/JPEGImages/{}.jpg'.format(year, img_name)
+            for img_name in img_names
+        ]
+        part_annotations = mmcv.track_progress(parse_xml,
+                                               list(zip(xml_paths, img_paths)))
+        annotations.extend(part_annotations)
+    mmcv.dump(annotations, out_file)
+    return annotations
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert PASCAL VOC annotations to mmdetection format')
+    parser.add_argument('devkit_path', help='pascal voc devkit path')
+    parser.add_argument('-o', '--out-dir', help='output path')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    devkit_path = args.devkit_path
+    out_dir = args.out_dir if args.out_dir else devkit_path
+    mmcv.mkdir_or_exist(out_dir)
+
+    years = []
+    if osp.isdir(osp.join(devkit_path, 'VOC2007')):
+        years.append('2007')
+    if osp.isdir(osp.join(devkit_path, 'VOC2012')):
+        years.append('2012')
+    if '2007' in years and '2012' in years:
+        years.append(['2007', '2012'])
+    if not years:
+        raise IOError('The devkit path {} contains neither a "VOC2007" nor a '
+                      '"VOC2012" subfolder'.format(devkit_path))
+    for year in years:
+        if year == '2007':
+            prefix = 'voc07'
+        elif year == '2012':
+            prefix = 'voc12'
+        elif year == ['2007', '2012']:
+            prefix = 'voc0712'
+        for split in ['train', 'val', 'trainval']:
+            dataset_name = prefix + '_' + split
+            print('processing {} ...'.format(dataset_name))
+            cvt_annotations(devkit_path, year, split,
+                            osp.join(out_dir, dataset_name + '.pkl'))
+        if not isinstance(year, list):
+            dataset_name = prefix + '_test'
+            print('processing {} ...'.format(dataset_name))
+            cvt_annotations(devkit_path, year, 'test',
+                            osp.join(out_dir, dataset_name + '.pkl'))
+    print('Done!')
+
+
+if __name__ == '__main__':
+    main()
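To close the loop on the "offline conversion" route: running this script (for example `python tools/convert_datasets/pascal_voc.py data/VOCdevkit -o data/voc`) produces pickle files such as `voc0712_trainval.pkl` that `CustomDataset` can load directly. The config fragment below is a minimal sketch of how such a file might be consumed; the paths, image scale, and normalization values are illustrative assumptions, not values fixed by this patch.

```python
# Sketch of a dataset config consuming the converted VOC annotations.
# All paths and hyper-parameter values here are illustrative assumptions.
dataset_type = 'CustomDataset'
data_root = 'data/VOCdevkit/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file='data/voc/voc0712_trainval.pkl',
        img_prefix=data_root,   # 'filename' in the pkl is relative to this
        img_scale=(1000, 600),  # (long_edge, short_edge)
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0.5,
        with_mask=False,   # VOC provides no mask annotations
        with_crowd=True,   # difficult boxes become gt_bboxes_ignore
        with_label=True))
```

With `with_crowd=True`, the difficult boxes that the converter stores under `bboxes_ignore` are passed to the model as `gt_bboxes_ignore` during training.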