From f1d06cdc70a81fbeee219f7c99654c319131ab8a Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Wed, 13 Feb 2019 20:23:39 +0800
Subject: [PATCH] support gt_bboxes_ignore for anchor heads

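---
Review notes (text in this region is ignored by `git am` and is not part of
the commit message):

 * Parameter order: in TwoStageDetector.forward_train and
   CascadeRCNN.forward_train, `gt_bboxes_ignore` moves from the 4th
   positional slot to after `gt_labels` and gains a default of None. Any
   caller that passes these arguments positionally will now silently swap
   `gt_labels` and `gt_bboxes_ignore`; call sites should pass them by keyword.
 * anchor_target() gains `gt_bboxes_ignore_list=None` ahead of
   `gt_labels_list`, and anchor_target_single() gains a required
   `gt_bboxes_ignore` ahead of `gt_labels`. The callers changed in this patch
   use keywords, but any out-of-tree positional caller would shift by one.
 * two_stage.py normalises `gt_bboxes_ignore=None` to a per-image list of
   None before the RCNN assign loop; the cascade_rcnn.py hunks contain no
   equivalent guard. If CascadeRCNN.forward_train indexes `gt_bboxes_ignore`
   per image (not visible in this patch -- to be confirmed), the new None
   default would fail there.
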
 mmdet/core/anchor/anchor_target.py       | 11 ++++++++---
 mmdet/models/anchor_heads/anchor_head.py | 11 +++++++++--
 mmdet/models/anchor_heads/rpn_head.py    | 18 +++++++++++++++---
 mmdet/models/anchor_heads/ssd_head.py    | 11 +++++++++--
 mmdet/models/detectors/cascade_rcnn.py   |  5 +++--
 mmdet/models/detectors/rpn.py            |  9 +++++++--
 mmdet/models/detectors/single_stage.py   | 10 ++++++++--
 mmdet/models/detectors/two_stage.py      |  7 +++++--
 8 files changed, 64 insertions(+), 18 deletions(-)

diff --git a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py
index 2dae831..7a5bf4e 100644
--- a/mmdet/core/anchor/anchor_target.py
+++ b/mmdet/core/anchor/anchor_target.py
@@ -11,6 +11,7 @@ def anchor_target(anchor_list,
                   target_means,
                   target_stds,
                   cfg,
+                  gt_bboxes_ignore_list=None,
                   gt_labels_list=None,
                   label_channels=1,
                   sampling=True,
@@ -41,6 +42,8 @@ def anchor_target(anchor_list,
         valid_flag_list[i] = torch.cat(valid_flag_list[i])
 
     # compute targets for each image
+    if gt_bboxes_ignore_list is None:
+        gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
     if gt_labels_list is None:
         gt_labels_list = [None for _ in range(num_imgs)]
     (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
@@ -49,6 +52,7 @@ def anchor_target(anchor_list,
          anchor_list,
          valid_flag_list,
          gt_bboxes_list,
+         gt_bboxes_ignore_list,
          gt_labels_list,
          img_metas,
          target_means=target_means,
@@ -90,6 +94,7 @@ def images_to_levels(target, num_level_anchors):
 def anchor_target_single(flat_anchors,
                          valid_flags,
                          gt_bboxes,
+                         gt_bboxes_ignore,
                          gt_labels,
                          img_meta,
                          target_means,
@@ -108,11 +113,11 @@ def anchor_target_single(flat_anchors,
 
     if sampling:
         assign_result, sampling_result = assign_and_sample(
-            anchors, gt_bboxes, None, None, cfg)
+            anchors, gt_bboxes, gt_bboxes_ignore, None, cfg)
     else:
         bbox_assigner = build_assigner(cfg.assigner)
-        assign_result = bbox_assigner.assign(anchors, gt_bboxes, None,
-                                             gt_labels)
+        assign_result = bbox_assigner.assign(anchors, gt_bboxes,
+                                             gt_bboxes_ignore, gt_labels)
         bbox_sampler = PseudoSampler()
         sampling_result = bbox_sampler.sample(assign_result, anchors,
                                               gt_bboxes)
diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py
index 57a7d64..d3ab2d2 100644
--- a/mmdet/models/anchor_heads/anchor_head.py
+++ b/mmdet/models/anchor_heads/anchor_head.py
@@ -169,8 +169,14 @@ class AnchorHead(nn.Module):
             avg_factor=num_total_samples)
         return loss_cls, loss_reg
 
-    def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
-             cfg):
+    def loss(self,
+             cls_scores,
+             bbox_preds,
+             gt_bboxes,
+             gt_labels,
+             img_metas,
+             cfg,
+             gt_bboxes_ignore=None):
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
         assert len(featmap_sizes) == len(self.anchor_generators)
 
@@ -186,6 +192,7 @@ class AnchorHead(nn.Module):
             self.target_means,
             self.target_stds,
             cfg,
+            gt_bboxes_ignore_list=gt_bboxes_ignore,
             gt_labels_list=gt_labels,
             label_channels=label_channels,
             sampling=sampling)
diff --git a/mmdet/models/anchor_heads/rpn_head.py b/mmdet/models/anchor_heads/rpn_head.py
index b913fc7..fe9d5c3 100644
--- a/mmdet/models/anchor_heads/rpn_head.py
+++ b/mmdet/models/anchor_heads/rpn_head.py
@@ -34,9 +34,21 @@ class RPNHead(AnchorHead):
         rpn_bbox_pred = self.rpn_reg(x)
         return rpn_cls_score, rpn_bbox_pred
 
-    def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, cfg):
-        losses = super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes,
-                                           None, img_metas, cfg)
+    def loss(self,
+             cls_scores,
+             bbox_preds,
+             gt_bboxes,
+             img_metas,
+             cfg,
+             gt_bboxes_ignore=None):
+        losses = super(RPNHead, self).loss(
+            cls_scores,
+            bbox_preds,
+            gt_bboxes,
+            None,
+            img_metas,
+            cfg,
+            gt_bboxes_ignore=gt_bboxes_ignore)
         return dict(
             loss_rpn_cls=losses['loss_cls'], loss_rpn_reg=losses['loss_reg'])
 
diff --git a/mmdet/models/anchor_heads/ssd_head.py b/mmdet/models/anchor_heads/ssd_head.py
index ed1b27b..1f704f7 100644
--- a/mmdet/models/anchor_heads/ssd_head.py
+++ b/mmdet/models/anchor_heads/ssd_head.py
@@ -130,8 +130,14 @@ class SSDHead(AnchorHead):
             avg_factor=num_total_samples)
         return loss_cls[None], loss_reg
 
-    def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
-             cfg):
+    def loss(self,
+             cls_scores,
+             bbox_preds,
+             gt_bboxes,
+             gt_labels,
+             img_metas,
+             cfg,
+             gt_bboxes_ignore=None):
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
         assert len(featmap_sizes) == len(self.anchor_generators)
 
@@ -145,6 +151,7 @@ class SSDHead(AnchorHead):
             self.target_means,
             self.target_stds,
             cfg,
+            gt_bboxes_ignore_list=gt_bboxes_ignore,
             gt_labels_list=gt_labels,
             label_channels=1,
             sampling=False,
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
index 705a078..6df9e22 100644
--- a/mmdet/models/detectors/cascade_rcnn.py
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -109,8 +109,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
                       img,
                       img_meta,
                       gt_bboxes,
-                      gt_bboxes_ignore,
                       gt_labels,
+                      gt_bboxes_ignore=None,
                       gt_masks=None,
                       proposals=None):
         x = self.extract_feat(img)
@@ -121,7 +121,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             rpn_outs = self.rpn_head(x)
             rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                           self.train_cfg.rpn)
-            rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
+            rpn_losses = self.rpn_head.loss(
+                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
             losses.update(rpn_losses)
 
             proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
diff --git a/mmdet/models/detectors/rpn.py b/mmdet/models/detectors/rpn.py
index b882024..36b92d0 100644
--- a/mmdet/models/detectors/rpn.py
+++ b/mmdet/models/detectors/rpn.py
@@ -38,7 +38,11 @@ class RPN(BaseDetector, RPNTestMixin):
             x = self.neck(x)
         return x
 
-    def forward_train(self, img, img_meta, gt_bboxes=None):
+    def forward_train(self,
+                      img,
+                      img_meta,
+                      gt_bboxes=None,
+                      gt_bboxes_ignore=None):
         if self.train_cfg.rpn.get('debug', False):
             self.rpn_head.debug_imgs = tensor2imgs(img)
 
@@ -46,7 +50,8 @@ class RPN(BaseDetector, RPNTestMixin):
         rpn_outs = self.rpn_head(x)
 
         rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)
-        losses = self.rpn_head.loss(*rpn_loss_inputs)
+        losses = self.rpn_head.loss(
+            *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
         return losses
 
     def simple_test(self, img, img_meta, rescale=False):
diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py
index 17fe6fb..6f73b34 100644
--- a/mmdet/models/detectors/single_stage.py
+++ b/mmdet/models/detectors/single_stage.py
@@ -42,11 +42,17 @@ class SingleStageDetector(BaseDetector):
             x = self.neck(x)
         return x
 
-    def forward_train(self, img, img_metas, gt_bboxes, gt_labels):
+    def forward_train(self,
+                      img,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None):
         x = self.extract_feat(img)
         outs = self.bbox_head(x)
         loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg)
-        losses = self.bbox_head.loss(*loss_inputs)
+        losses = self.bbox_head.loss(
+            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
         return losses
 
     def simple_test(self, img, img_meta, rescale=False):
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
index 68f734f..3dca618 100644
--- a/mmdet/models/detectors/two_stage.py
+++ b/mmdet/models/detectors/two_stage.py
@@ -81,8 +81,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       img,
                       img_meta,
                       gt_bboxes,
-                      gt_bboxes_ignore,
                       gt_labels,
+                      gt_bboxes_ignore=None,
                       gt_masks=None,
                       proposals=None):
         x = self.extract_feat(img)
@@ -94,7 +94,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
             rpn_outs = self.rpn_head(x)
             rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                           self.train_cfg.rpn)
-            rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
+            rpn_losses = self.rpn_head.loss(
+                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
             losses.update(rpn_losses)
 
             proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
@@ -108,6 +109,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
             bbox_sampler = build_sampler(
                 self.train_cfg.rcnn.sampler, context=self)
             num_imgs = img.size(0)
+            if gt_bboxes_ignore is None:
+                gt_bboxes_ignore = [None for _ in range(num_imgs)]
             sampling_results = []
             for i in range(num_imgs):
                 assign_result = bbox_assigner.assign(
-- 
GitLab