diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py index 93a05c8594eb70e34c9291117f32df42b408bd40..d1b0fce1283b012072e7fb1f864313135eeac940 100644 --- a/mmdet/models/detectors/base.py +++ b/mmdet/models/detectors/base.py @@ -17,6 +17,18 @@ class BaseDetector(nn.Module): def __init__(self): super(BaseDetector, self).__init__() + @property + def with_neck(self): + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_bbox(self): + return hasattr(self, 'bbox_head') and self.bbox_head is not None + + @property + def with_mask(self): + return hasattr(self, 'mask_head') and self.mask_head is not None + @abstractmethod def extract_feat(self, imgs): pass diff --git a/mmdet/models/detectors/rpn.py b/mmdet/models/detectors/rpn.py index a291006fdd58ac4f0f2a970195bc41ae51f73126..9d700fe3e3c3af357256b36f1582c6a8c7249580 100644 --- a/mmdet/models/detectors/rpn.py +++ b/mmdet/models/detectors/rpn.py @@ -26,13 +26,13 @@ class RPN(BaseDetector, RPNTestMixin): def init_weights(self, pretrained=None): super(RPN, self).init_weights(pretrained) self.backbone.init_weights(pretrained=pretrained) - if self.neck is not None: + if self.with_neck: self.neck.init_weights() self.rpn_head.init_weights() def extract_feat(self, img): x = self.backbone(img) - if self.neck is not None: + if self.with_neck: x = self.neck(x) return x diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py index c0e81eec40bc9bc6e0f3f55d74e97fecc7f2a4bc..f69db22ced983e2b41eb693daa5ec1099f6f4a55 100644 --- a/mmdet/models/detectors/two_stage.py +++ b/mmdet/models/detectors/two_stage.py @@ -25,23 +25,19 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, self.backbone = builder.build_backbone(backbone) if neck is not None: - self.with_neck = True self.neck = builder.build_neck(neck) else: raise NotImplementedError - self.with_rpn = True if rpn_head is not None else False - if self.with_rpn: + if rpn_head is not None: self.rpn_head = builder.build_rpn_head(rpn_head) - self.with_bbox = True if bbox_head is not None else False - if self.with_bbox: + if bbox_head is not None: self.bbox_roi_extractor = builder.build_roi_extractor( bbox_roi_extractor) self.bbox_head = builder.build_bbox_head(bbox_head) - self.with_mask = True if mask_head is not None else False - if self.with_mask: + if mask_head is not None: self.mask_roi_extractor = builder.build_roi_extractor( mask_roi_extractor) self.mask_head = builder.build_mask_head(mask_head) @@ -51,6 +47,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, self.init_weights(pretrained=pretrained) + @property + def with_rpn(self): + return hasattr(self, 'rpn_head') and self.rpn_head is not None + def init_weights(self, pretrained=None): super(TwoStageDetector, self).init_weights(pretrained) self.backbone.init_weights(pretrained=pretrained)