From 763decc4a2b0e9d85e3858c5870b4f6b50fa6ac4 Mon Sep 17 00:00:00 2001
From: pangjm <pjmzju@gmail.com>
Date: Wed, 10 Oct 2018 14:09:13 +0800
Subject: [PATCH] add fast rcnn api & fix minor bugs

---
 mmdet/models/detectors/__init__.py  |  3 ++-
 mmdet/models/detectors/fast_rcnn.py | 25 +++++++++++++++++++++++++
 mmdet/models/detectors/two_stage.py |  3 ++-
 3 files changed, 29 insertions(+), 2 deletions(-)
 create mode 100644 mmdet/models/detectors/fast_rcnn.py

diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index b8914c1..29a64dd 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -1,6 +1,7 @@
 from .base import BaseDetector
 from .rpn import RPN
+from .faster_rcnn import FastRCNN
 from .faster_rcnn import FasterRCNN
 from .mask_rcnn import MaskRCNN
 
-__all__ = ['BaseDetector', 'RPN', 'FasterRCNN', 'MaskRCNN']
+__all__ = ['BaseDetector', 'RPN', 'FastRCNN', 'FasterRCNN', 'MaskRCNN']
diff --git a/mmdet/models/detectors/fast_rcnn.py b/mmdet/models/detectors/fast_rcnn.py
new file mode 100644
index 0000000..0dbf17a
--- /dev/null
+++ b/mmdet/models/detectors/fast_rcnn.py
@@ -0,0 +1,25 @@
+from .two_stage import TwoStageDetector
+
+
+class FastRCNN(TwoStageDetector):
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 bbox_roi_extractor,
+                 bbox_head,
+                 train_cfg,
+                 test_cfg,
+                 mask_roi_extractor=None,
+                 mask_head=None,
+                 pretrained=None):
+        super(FastRCNN, self).__init__(
+                    backbone=backbone,
+                    neck=neck,
+                    bbox_roi_extractor=bbox_roi_extractor,
+                    bbox_head=bbox_head,
+                    train_cfg=train_cfg,
+                    test_cfg=test_cfg,
+                    mask_roi_extractor=mask_roi_extractor,
+                    mask_head=mask_head,
+                    pretrained=pretrained)
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
index 8573d83..3cd6838 100644
--- a/mmdet/models/detectors/two_stage.py
+++ b/mmdet/models/detectors/two_stage.py
@@ -146,7 +146,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         x = self.extract_feat(img)
 
         proposal_list = self.simple_test_rpn(
-            x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
+            x, img_meta,
+            self.test_cfg.rpn) if proposals is None else proposals[0]
 
         det_bboxes, det_labels = self.simple_test_bboxes(
             x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
-- 
GitLab