From 763decc4a2b0e9d85e3858c5870b4f6b50fa6ac4 Mon Sep 17 00:00:00 2001 From: pangjm <pjmzju@gmail.com> Date: Wed, 10 Oct 2018 14:09:13 +0800 Subject: [PATCH] add fast rcnn api & fix minor bugs --- mmdet/models/detectors/__init__.py | 3 ++- mmdet/models/detectors/fast_rcnn.py | 25 +++++++++++++++++++++++++ mmdet/models/detectors/two_stage.py | 3 ++- 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 mmdet/models/detectors/fast_rcnn.py diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index b8914c1..29a64dd 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -1,6 +1,7 @@ from .base import BaseDetector from .rpn import RPN +from .faster_rcnn import FastRCNN from .faster_rcnn import FasterRCNN from .mask_rcnn import MaskRCNN -__all__ = ['BaseDetector', 'RPN', 'FasterRCNN', 'MaskRCNN'] +__all__ = ['BaseDetector', 'RPN', 'FastRCNN', 'FasterRCNN', 'MaskRCNN'] diff --git a/mmdet/models/detectors/fast_rcnn.py b/mmdet/models/detectors/fast_rcnn.py new file mode 100644 index 0000000..0dbf17a --- /dev/null +++ b/mmdet/models/detectors/fast_rcnn.py @@ -0,0 +1,25 @@ +from .two_stage import TwoStageDetector + + +class FastRCNN(TwoStageDetector): + + def __init__(self, + backbone, + neck, + bbox_roi_extractor, + bbox_head, + train_cfg, + test_cfg, + mask_roi_extractor=None, + mask_head=None, + pretrained=None): + super(FastRCNN, self).__init__( + backbone=backbone, + neck=neck, + bbox_roi_extractor=bbox_roi_extractor, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + mask_roi_extractor=mask_roi_extractor, + mask_head=mask_head, + pretrained=pretrained) diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py index 8573d83..3cd6838 100644 --- a/mmdet/models/detectors/two_stage.py +++ b/mmdet/models/detectors/two_stage.py @@ -146,7 +146,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, x = self.extract_feat(img) proposal_list = self.simple_test_rpn( - x, img_meta, self.test_cfg.rpn) if proposals is None else proposals + x, img_meta, + self.test_cfg.rpn) if proposals is None else proposals[0] det_bboxes, det_labels = self.simple_test_bboxes( x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale) -- GitLab