Add a FastRCNN detector that consumes externally supplied proposals.

Files touched by this patch:
  * mmdet/datasets/coco.py
      - when a proposal file is given, reorder the loaded proposals so that
        they follow the order of self.img_ids;
      - pass img_shape to bbox_transform when transforming test proposals.
  * mmdet/models/detectors/__init__.py
      - export FastRCNN.
  * mmdet/models/detectors/fast_rcnn.py (new)
      - FastRCNN subclasses TwoStageDetector; forward_test takes a
        `proposals` argument and forwards proposals[0] to simple_test for
        the single-augmentation case.
  * mmdet/models/detectors/test_mixins.py
      - call get_seg_masks with scale_factor=1.0 and rescale=False.
  * mmdet/models/detectors/two_stage.py
      - re-wrap one expression (no token change).
  * tools/test.py
      - import collate from mmcv.parallel instead of mmdet.datasets.

Review notes on the added code (not applied here, so that the patch keeps
matching the blob hashes recorded on its index lines):
  * coco.py: the loop variable `id` shadows the builtin, and
    `ori_ids.index(...)` performs a linear search for every image id.
  * fast_rcnn.py: `assert imgs_per_gpu == 1` is used as a runtime guard;
    assert statements are removed when Python runs with -O.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Leading whitespace and blank context lines
were reconstructed from Python block structure and from the line counts in
each hunk header - verify against the repository before applying.

diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py
index 3cd0a6d5ca20dbeba11f96135b570635348c74d9..9049f1af9703a5d97a6a6f53c33eac3190468c97 100644
--- a/mmdet/datasets/coco.py
+++ b/mmdet/datasets/coco.py
@@ -53,8 +53,14 @@ class CocoDataset(Dataset):
         # color channel order and normalize configs
         self.img_norm_cfg = img_norm_cfg
         # proposals
-        self.proposals = mmcv.load(
-            proposal_file) if proposal_file is not None else None
+        # TODO: revise _filter_imgs to be more flexible
+        if proposal_file is not None:
+            self.proposals = mmcv.load(proposal_file)
+            ori_ids = self.coco.getImgIds()
+            sorted_idx = [ori_ids.index(id) for id in self.img_ids]
+            self.proposals = [self.proposals[idx] for idx in sorted_idx]
+        else:
+            self.proposals = None
         self.num_max_proposals = num_max_proposals
         # flip ratio
         self.flip_ratio = flip_ratio
@@ -271,7 +277,8 @@ class CocoDataset(Dataset):
                 scale_factor=scale_factor,
                 flip=flip)
             if proposal is not None:
-                _proposal = self.bbox_transform(proposal, scale_factor, flip)
+                _proposal = self.bbox_transform(proposal, img_shape,
+                                                scale_factor, flip)
                 _proposal = to_tensor(_proposal)
             else:
                 _proposal = None
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index b8914c1e5d3c834a1373b2a2e8360183a41de4da..c911d1723d161ae18e78198a53d47921d7937012 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -1,6 +1,7 @@
 from .base import BaseDetector
 from .rpn import RPN
+from .fast_rcnn import FastRCNN
 from .faster_rcnn import FasterRCNN
 from .mask_rcnn import MaskRCNN
 
-__all__ = ['BaseDetector', 'RPN', 'FasterRCNN', 'MaskRCNN']
+__all__ = ['BaseDetector', 'RPN', 'FastRCNN', 'FasterRCNN', 'MaskRCNN']
diff --git a/mmdet/models/detectors/fast_rcnn.py b/mmdet/models/detectors/fast_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd80a87f69d67a2c77378c926a39c2ddb3208ac0
--- /dev/null
+++ b/mmdet/models/detectors/fast_rcnn.py
@@ -0,0 +1,46 @@
+from .two_stage import TwoStageDetector
+
+
+class FastRCNN(TwoStageDetector):
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 bbox_roi_extractor,
+                 bbox_head,
+                 train_cfg,
+                 test_cfg,
+                 mask_roi_extractor=None,
+                 mask_head=None,
+                 pretrained=None):
+        super(FastRCNN, self).__init__(
+            backbone=backbone,
+            neck=neck,
+            bbox_roi_extractor=bbox_roi_extractor,
+            bbox_head=bbox_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            mask_roi_extractor=mask_roi_extractor,
+            mask_head=mask_head,
+            pretrained=pretrained)
+
+    def forward_test(self, imgs, img_metas, proposals, **kwargs):
+        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+            if not isinstance(var, list):
+                raise TypeError('{} must be a list, but got {}'.format(
+                    name, type(var)))
+
+        num_augs = len(imgs)
+        if num_augs != len(img_metas):
+            raise ValueError(
+                'num of augmentations ({}) != num of image meta ({})'.format(
+                    len(imgs), len(img_metas)))
+        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
+        imgs_per_gpu = imgs[0].size(0)
+        assert imgs_per_gpu == 1
+
+        if num_augs == 1:
+            return self.simple_test(imgs[0], img_metas[0], proposals[0],
+                                    **kwargs)
+        else:
+            return self.aug_test(imgs, img_metas, proposals, **kwargs)
diff --git a/mmdet/models/detectors/test_mixins.py b/mmdet/models/detectors/test_mixins.py
index 77ba244f1a3fa107bfb6828110eaa344f4a0ba8a..38136f47545c49d88253fee321c91f9408058ca9 100644
--- a/mmdet/models/detectors/test_mixins.py
+++ b/mmdet/models/detectors/test_mixins.py
@@ -135,6 +135,11 @@ class MaskTestMixin(object):
 
             ori_shape = img_metas[0][0]['ori_shape']
             segm_result = self.mask_head.get_seg_masks(
-                merged_masks, det_bboxes, det_labels, self.test_cfg.rcnn,
-                ori_shape)
+                merged_masks,
+                det_bboxes,
+                det_labels,
+                self.test_cfg.rcnn,
+                ori_shape,
+                scale_factor=1.0,
+                rescale=False)
         return segm_result
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
index 8573d83215f120ba392a2f6b45cb9b6b93ca0519..b2f2839f93c1fc8e0e6451002f9724cef2c036a0 100644
--- a/mmdet/models/detectors/two_stage.py
+++ b/mmdet/models/detectors/two_stage.py
@@ -146,7 +146,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         x = self.extract_feat(img)
 
         proposal_list = self.simple_test_rpn(
-            x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
+            x, img_meta,
+            self.test_cfg.rpn) if proposals is None else proposals
 
         det_bboxes, det_labels = self.simple_test_bboxes(
             x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
diff --git a/tools/test.py b/tools/test.py
index 3b1ce2d2e04859fdcce4c977556be89298d1953d..2552e7af78779104c853c880b94d40266dbb7d54 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -3,11 +3,11 @@ import argparse
 import torch
 import mmcv
 from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
-from mmcv.parallel import scatter, MMDataParallel
+from mmcv.parallel import scatter, collate, MMDataParallel
 
 from mmdet import datasets
 from mmdet.core import results2json, coco_eval
-from mmdet.datasets import collate, build_dataloader
+from mmdet.datasets import build_dataloader
 from mmdet.models import build_detector, detectors
 
 