From bac11303e2776b57b730d4ffb5d41bc5ce34b079 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Wed, 17 Oct 2018 22:33:17 +0800
Subject: [PATCH] add BBoxAssigner and BBoxSampler

---
 configs/fast_mask_rcnn_r50_fpn_1x.py     |  22 +-
 configs/fast_rcnn_r50_fpn_1x.py          |  22 +-
 configs/faster_rcnn_r50_fpn_1x.py        |  43 +-
 configs/mask_rcnn_r50_fpn_1x.py          |  43 +-
 configs/rpn_r50_fpn_1x.py                |  21 +-
 mmdet/core/anchor/anchor_target.py       |  27 +-
 mmdet/core/bbox/__init__.py              |  15 +-
 mmdet/core/bbox/assignment.py            | 155 +++++++
 mmdet/core/bbox/bbox_target.py           |  50 +--
 mmdet/core/bbox/sampling.py              | 496 +++++++++--------------
 mmdet/datasets/coco.py                   |  12 +-
 mmdet/models/bbox_heads/bbox_head.py     |  12 +-
 mmdet/models/detectors/two_stage.py      |  65 ++-
 mmdet/models/mask_heads/fcn_mask_head.py |   7 +-
 14 files changed, 522 insertions(+), 468 deletions(-)
 create mode 100644 mmdet/core/bbox/assignment.py

diff --git a/configs/fast_mask_rcnn_r50_fpn_1x.py b/configs/fast_mask_rcnn_r50_fpn_1x.py
index af2070f..22215da 100644
--- a/configs/fast_mask_rcnn_r50_fpn_1x.py
+++ b/configs/fast_mask_rcnn_r50_fpn_1x.py
@@ -43,17 +43,19 @@ model = dict(
 # model training and testing settings
 train_cfg = dict(
     rcnn=dict(
+        assigner=dict(
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True,
+            pos_balance_sampling=False,
+            neg_balance_thr=0),
         mask_size=28,
-        pos_iou_thr=0.5,
-        neg_iou_thr=0.5,
-        crowd_thr=1.1,
-        roi_batch_size=512,
-        add_gt_as_proposals=True,
-        pos_fraction=0.25,
-        pos_balance_sampling=False,
-        neg_pos_ub=512,
-        neg_balance_thr=0,
-        min_pos_iou=0.5,
         pos_weight=-1,
         debug=False))
 test_cfg = dict(
diff --git a/configs/fast_rcnn_r50_fpn_1x.py b/configs/fast_rcnn_r50_fpn_1x.py
index 397ab43..27de2ff 100644
--- a/configs/fast_rcnn_r50_fpn_1x.py
+++ b/configs/fast_rcnn_r50_fpn_1x.py
@@ -32,16 +32,18 @@ model = dict(
 # model training and testing settings
 train_cfg = dict(
     rcnn=dict(
-        pos_iou_thr=0.5,
-        neg_iou_thr=0.5,
-        crowd_thr=1.1,
-        roi_batch_size=512,
-        add_gt_as_proposals=True,
-        pos_fraction=0.25,
-        pos_balance_sampling=False,
-        neg_pos_ub=512,
-        neg_balance_thr=0,
-        min_pos_iou=0.5,
+        assigner=dict(
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True,
+            pos_balance_sampling=False,
+            neg_balance_thr=0),
         pos_weight=-1,
         debug=False))
 test_cfg = dict(rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5))
diff --git a/configs/faster_rcnn_r50_fpn_1x.py b/configs/faster_rcnn_r50_fpn_1x.py
index 1c06c4c..4ab8b5a 100644
--- a/configs/faster_rcnn_r50_fpn_1x.py
+++ b/configs/faster_rcnn_r50_fpn_1x.py
@@ -42,30 +42,35 @@ model = dict(
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
-        pos_fraction=0.5,
-        pos_balance_sampling=False,
-        neg_pos_ub=256,
+        assigner=dict(
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False,
+            pos_balance_sampling=False,
+            neg_balance_thr=0),
         allowed_border=0,
-        crowd_thr=1.1,
-        anchor_batch_size=256,
-        pos_iou_thr=0.7,
-        neg_iou_thr=0.3,
-        neg_balance_thr=0,
-        min_pos_iou=0.3,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
         debug=False),
     rcnn=dict(
-        pos_iou_thr=0.5,
-        neg_iou_thr=0.5,
-        crowd_thr=1.1,
-        roi_batch_size=512,
-        add_gt_as_proposals=True,
-        pos_fraction=0.25,
-        pos_balance_sampling=False,
-        neg_pos_ub=512,
-        neg_balance_thr=0,
-        min_pos_iou=0.5,
+        assigner=dict(
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True,
+            pos_balance_sampling=False,
+            neg_balance_thr=0),
         pos_weight=-1,
         debug=False))
 test_cfg = dict(
diff --git a/configs/mask_rcnn_r50_fpn_1x.py b/configs/mask_rcnn_r50_fpn_1x.py
index 8868cf6..b190ad8 100644
--- a/configs/mask_rcnn_r50_fpn_1x.py
+++ b/configs/mask_rcnn_r50_fpn_1x.py
@@ -53,31 +53,36 @@ model = dict(
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
-        pos_fraction=0.5,
-        pos_balance_sampling=False,
-        neg_pos_ub=256,
+        assigner=dict(
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False,
+            pos_balance_sampling=False,
+            neg_balance_thr=0),
         allowed_border=0,
-        crowd_thr=1.1,
-        anchor_batch_size=256,
-        pos_iou_thr=0.7,
-        neg_iou_thr=0.3,
-        neg_balance_thr=0,
-        min_pos_iou=0.3,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
         debug=False),
     rcnn=dict(
+        assigner=dict(
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True,
+            pos_balance_sampling=False,
+            neg_balance_thr=0),
         mask_size=28,
-        pos_iou_thr=0.5,
-        neg_iou_thr=0.5,
-        crowd_thr=1.1,
-        roi_batch_size=512,
-        add_gt_as_proposals=True,
-        pos_fraction=0.25,
-        pos_balance_sampling=False,
-        neg_pos_ub=512,
-        neg_balance_thr=0,
-        min_pos_iou=0.5,
         pos_weight=-1,
         debug=False))
 test_cfg = dict(
diff --git a/configs/rpn_r50_fpn_1x.py b/configs/rpn_r50_fpn_1x.py
index 7f1b6d0..8e2b402 100644
--- a/configs/rpn_r50_fpn_1x.py
+++ b/configs/rpn_r50_fpn_1x.py
@@ -27,16 +27,19 @@ model = dict(
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
-        pos_fraction=0.5,
-        pos_balance_sampling=False,
-        neg_pos_ub=256,
+        assigner=dict(
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False,
+            pos_balance_sampling=False,
+            neg_balance_thr=0),
         allowed_border=0,
-        crowd_thr=1.1,
-        anchor_batch_size=256,
-        pos_iou_thr=0.7,
-        neg_iou_thr=0.3,
-        neg_balance_thr=0,
-        min_pos_iou=0.3,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
         debug=False))
diff --git a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py
index ad81e39..49047b9 100644
--- a/mmdet/core/anchor/anchor_target.py
+++ b/mmdet/core/anchor/anchor_target.py
@@ -1,6 +1,6 @@
 import torch
 
-from ..bbox import bbox_assign, bbox2delta, bbox_sampling
+from ..bbox import assign_and_sample, bbox2delta
 from ..utils import multi_apply
 
 
@@ -80,27 +80,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
         return (None, ) * 6
     # assign gt and sample anchors
     anchors = flat_anchors[inside_flags, :]
-    assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign(
-        anchors,
-        gt_bboxes,
-        pos_iou_thr=cfg.pos_iou_thr,
-        neg_iou_thr=cfg.neg_iou_thr,
-        min_pos_iou=cfg.min_pos_iou)
-    pos_inds, neg_inds = bbox_sampling(assigned_gt_inds, cfg.anchor_batch_size,
-                                       cfg.pos_fraction, cfg.neg_pos_ub,
-                                       cfg.pos_balance_sampling, max_overlaps,
-                                       cfg.neg_balance_thr)
+    _, sampling_result = assign_and_sample(anchors, gt_bboxes, None, None, cfg)
 
+    num_valid_anchors = anchors.shape[0]
     bbox_targets = torch.zeros_like(anchors)
     bbox_weights = torch.zeros_like(anchors)
-    labels = torch.zeros_like(assigned_gt_inds)
-    label_weights = torch.zeros_like(assigned_gt_inds, dtype=anchors.dtype)
+    labels = anchors.new_zeros((num_valid_anchors, ))
+    label_weights = anchors.new_zeros((num_valid_anchors, ))
 
+    pos_inds = sampling_result.pos_inds
+    neg_inds = sampling_result.neg_inds
     if len(pos_inds) > 0:
-        pos_anchors = anchors[pos_inds, :]
-        pos_gt_bbox = gt_bboxes[assigned_gt_inds[pos_inds] - 1, :]
-        pos_bbox_targets = bbox2delta(pos_anchors, pos_gt_bbox, target_means,
-                                      target_stds)
+        pos_bbox_targets = bbox2delta(sampling_result.pos_bboxes,
+                                      sampling_result.pos_gt_bboxes,
+                                      target_means, target_stds)
         bbox_targets[pos_inds, :] = pos_bbox_targets
         bbox_weights[pos_inds, :] = 1.0
         labels[pos_inds] = 1
diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py
index a5c21dc..2ed869f 100644
--- a/mmdet/core/bbox/__init__.py
+++ b/mmdet/core/bbox/__init__.py
@@ -1,15 +1,14 @@
 from .geometry import bbox_overlaps
-from .sampling import (random_choice, bbox_assign, bbox_assign_wrt_overlaps,
-                       bbox_sampling, bbox_sampling_pos, bbox_sampling_neg,
-                       sample_bboxes)
+from .assignment import BBoxAssigner, AssignResult
+from .sampling import (BBoxSampler, SamplingResult, assign_and_sample,
+                       random_choice)
 from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping,
                          bbox_mapping_back, bbox2roi, roi2bbox, bbox2result)
 from .bbox_target import bbox_target
 
 __all__ = [
-    'bbox_overlaps', 'random_choice', 'bbox_assign',
-    'bbox_assign_wrt_overlaps', 'bbox_sampling', 'bbox_sampling_pos',
-    'bbox_sampling_neg', 'sample_bboxes', 'bbox2delta', 'delta2bbox',
-    'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox',
-    'bbox2result', 'bbox_target'
+    'bbox_overlaps', 'BBoxAssigner', 'AssignResult', 'BBoxSampler',
+    'SamplingResult', 'assign_and_sample', 'random_choice', 'bbox2delta',
+    'delta2bbox', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi',
+    'roi2bbox', 'bbox2result', 'bbox_target'
 ]
diff --git a/mmdet/core/bbox/assignment.py b/mmdet/core/bbox/assignment.py
new file mode 100644
index 0000000..62233af
--- /dev/null
+++ b/mmdet/core/bbox/assignment.py
@@ -0,0 +1,155 @@
+import torch
+
+from .geometry import bbox_overlaps
+
+
+class BBoxAssigner(object):
+    """Assign a corresponding gt bbox or background to each bbox.
+
+    Each proposals will be assigned with `-1`, `0`, or a positive integer
+    indicating the ground truth index.
+
+    - -1: don't care
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        pos_iou_thr (float): IoU threshold for positive bboxes.
+        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+        min_pos_iou (float): Minimum iou for a bbox to be considered as a
+            positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
+            it is usually set as pos_iou_thr
+        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+            `gt_bboxes_ignore` is specified). Negative values mean not
+            ignoring any bboxes.
+    """
+
+    def __init__(self,
+                 pos_iou_thr,
+                 neg_iou_thr,
+                 min_pos_iou=.0,
+                 ignore_iof_thr=-1):
+        self.pos_iou_thr = pos_iou_thr
+        self.neg_iou_thr = neg_iou_thr
+        self.min_pos_iou = min_pos_iou
+        self.ignore_iof_thr = ignore_iof_thr
+
+    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+        """Assign gt to bboxes.
+
+        This method assign a gt bbox to every bbox (proposal/anchor), each bbox
+        will be assigned with -1, 0, or a positive number. -1 means don't care,
+        0 means negative sample, positive number is the index (1-based) of
+        assigned gt.
+        The assignment is done in following steps, the order matters.
+
+        1. assign every bbox to -1
+        2. assign proposals whose iou with all gts < neg_iou_thr to 0
+        3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
+           assign it to that bbox
+        4. for each gt bbox, assign its nearest proposals (may be more than
+           one) to itself
+
+        Args:
+            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+                labelled as `ignored`, e.g., crowd boxes in COCO.
+            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+        Returns:
+            :obj:`AssignResult`: The assign result.
+        """
+        if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
+            raise ValueError('No gt or bboxes')
+        bboxes = bboxes[:, :4]
+        overlaps = bbox_overlaps(bboxes, gt_bboxes)
+
+        if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and (
+                gt_bboxes_ignore.numel() > 0):
+            ignore_overlaps = bbox_overlaps(
+                bboxes, gt_bboxes_ignore, mode='iof')
+            ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+            ignore_bboxes_inds = torch.nonzero(
+                ignore_max_overlaps > self.ignore_iof_thr).squeeze()
+            if ignore_bboxes_inds.numel() > 0:
+                overlaps[ignore_bboxes_inds[:, 0], :] = -1
+
+        assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+        return assign_result
+
+    def assign_wrt_overlaps(self, overlaps, gt_labels=None):
+        """Assign w.r.t. the overlaps of bboxes with gts.
+
+        Args:
+            overlaps (Tensor): Overlaps between n bboxes and k gt_bboxes,
+                shape(n, k).
+            gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
+
+        Returns:
+            :obj:`AssignResult`: The assign result.
+        """
+        if overlaps.numel() == 0:
+            raise ValueError('No gt or proposals')
+
+        num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)
+
+        # 1. assign -1 by default
+        assigned_gt_inds = overlaps.new_full(
+            (num_bboxes, ), -1, dtype=torch.long)
+
+        assert overlaps.size() == (num_bboxes, num_gts)
+        # for each anchor, which gt best overlaps with it
+        # for each anchor, the max iou of all gts
+        max_overlaps, argmax_overlaps = overlaps.max(dim=1)
+        # for each gt, which anchor best overlaps with it
+        # for each gt, the max iou of all proposals
+        gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=0)
+
+        # 2. assign negative: below
+        if isinstance(self.neg_iou_thr, float):
+            assigned_gt_inds[(max_overlaps >= 0)
+                             & (max_overlaps < self.neg_iou_thr)] = 0
+        elif isinstance(self.neg_iou_thr, tuple):
+            assert len(self.neg_iou_thr) == 2
+            assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
+                             & (max_overlaps < self.neg_iou_thr[1])] = 0
+
+        # 3. assign positive: above positive IoU threshold
+        pos_inds = max_overlaps >= self.pos_iou_thr
+        assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
+
+        # 4. assign fg: for each gt, proposals with highest IoU
+        for i in range(num_gts):
+            if gt_max_overlaps[i] >= self.min_pos_iou:
+                assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1
+
+        if gt_labels is not None:
+            assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
+            pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze()
+            if pos_inds.numel() > 0:
+                assigned_labels[pos_inds] = gt_labels[
+                    assigned_gt_inds[pos_inds] - 1]
+        else:
+            assigned_labels = None
+
+        return AssignResult(
+            num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+
+class AssignResult(object):
+
+    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
+        self.num_gts = num_gts
+        self.gt_inds = gt_inds
+        self.max_overlaps = max_overlaps
+        self.labels = labels
+
+    def add_gt_(self, gt_labels):
+        self_inds = torch.arange(
+            1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
+        self.gt_inds = torch.cat([self_inds, self.gt_inds])
+        self.max_overlaps = torch.cat(
+            [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
+        if self.labels is not None:
+            self.labels = torch.cat([gt_labels, self.labels])
diff --git a/mmdet/core/bbox/bbox_target.py b/mmdet/core/bbox/bbox_target.py
index 2e205c3..4a0450d 100644
--- a/mmdet/core/bbox/bbox_target.py
+++ b/mmdet/core/bbox/bbox_target.py
@@ -4,23 +4,23 @@ from .transforms import bbox2delta
 from ..utils import multi_apply
 
 
-def bbox_target(pos_proposals_list,
-                neg_proposals_list,
+def bbox_target(pos_bboxes_list,
+                neg_bboxes_list,
                 pos_gt_bboxes_list,
                 pos_gt_labels_list,
                 cfg,
-                reg_num_classes=1,
+                reg_classes=1,
                 target_means=[.0, .0, .0, .0],
                 target_stds=[1.0, 1.0, 1.0, 1.0],
                 concat=True):
     labels, label_weights, bbox_targets, bbox_weights = multi_apply(
-        proposal_target_single,
-        pos_proposals_list,
-        neg_proposals_list,
+        bbox_target_single,
+        pos_bboxes_list,
+        neg_bboxes_list,
         pos_gt_bboxes_list,
         pos_gt_labels_list,
         cfg=cfg,
-        reg_num_classes=reg_num_classes,
+        reg_classes=reg_classes,
         target_means=target_means,
         target_stds=target_stds)
 
@@ -32,34 +32,34 @@ def bbox_target(pos_proposals_list,
     return labels, label_weights, bbox_targets, bbox_weights
 
 
-def proposal_target_single(pos_proposals,
-                           neg_proposals,
-                           pos_gt_bboxes,
-                           pos_gt_labels,
-                           cfg,
-                           reg_num_classes=1,
-                           target_means=[.0, .0, .0, .0],
-                           target_stds=[1.0, 1.0, 1.0, 1.0]):
-    num_pos = pos_proposals.size(0)
-    num_neg = neg_proposals.size(0)
+def bbox_target_single(pos_bboxes,
+                       neg_bboxes,
+                       pos_gt_bboxes,
+                       pos_gt_labels,
+                       cfg,
+                       reg_classes=1,
+                       target_means=[.0, .0, .0, .0],
+                       target_stds=[1.0, 1.0, 1.0, 1.0]):
+    num_pos = pos_bboxes.size(0)
+    num_neg = neg_bboxes.size(0)
     num_samples = num_pos + num_neg
-    labels = pos_proposals.new_zeros(num_samples, dtype=torch.long)
-    label_weights = pos_proposals.new_zeros(num_samples)
-    bbox_targets = pos_proposals.new_zeros(num_samples, 4)
-    bbox_weights = pos_proposals.new_zeros(num_samples, 4)
+    labels = pos_bboxes.new_zeros(num_samples, dtype=torch.long)
+    label_weights = pos_bboxes.new_zeros(num_samples)
+    bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
+    bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
     if num_pos > 0:
         labels[:num_pos] = pos_gt_labels
         pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
         label_weights[:num_pos] = pos_weight
-        pos_bbox_targets = bbox2delta(pos_proposals, pos_gt_bboxes,
-                                      target_means, target_stds)
+        pos_bbox_targets = bbox2delta(pos_bboxes, pos_gt_bboxes, target_means,
+                                      target_stds)
         bbox_targets[:num_pos, :] = pos_bbox_targets
         bbox_weights[:num_pos, :] = 1
     if num_neg > 0:
         label_weights[-num_neg:] = 1.0
-    if reg_num_classes > 1:
+    if reg_classes > 1:
         bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights,
-                                                   labels, reg_num_classes)
+                                                   labels, reg_classes)
 
     return labels, label_weights, bbox_targets, bbox_weights
 
diff --git a/mmdet/core/bbox/sampling.py b/mmdet/core/bbox/sampling.py
index 976cd95..63d2279 100644
--- a/mmdet/core/bbox/sampling.py
+++ b/mmdet/core/bbox/sampling.py
@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 
-from .geometry import bbox_overlaps
+from .assignment import BBoxAssigner
 
 
 def random_choice(gallery, num):
@@ -21,323 +21,207 @@ def random_choice(gallery, num):
     return gallery[rand_inds]
 
 
-def bbox_assign(proposals,
-                gt_bboxes,
-                gt_bboxes_ignore=None,
-                gt_labels=None,
-                pos_iou_thr=0.5,
-                neg_iou_thr=0.5,
-                min_pos_iou=.0,
-                crowd_thr=-1):
-    """Assign a corresponding gt bbox or background to each proposal/anchor.
+def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
+    bbox_assigner = BBoxAssigner(**cfg.assigner)
+    bbox_sampler = BBoxSampler(**cfg.sampler)
+    assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore,
+                                         gt_labels)
+    sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes,
+                                          gt_labels)
+    return assign_result, sampling_result
 
-    Each proposals will be assigned with `-1`, `0`, or a positive integer.
 
-    - -1: don't care
-    - 0: negative sample, no assigned gt
-    - positive integer: positive sample, index (1-based) of assigned gt
-
-    If `gt_bboxes_ignore` is specified, bboxes which have iof (intersection
-    over foreground) with `gt_bboxes_ignore` above `crowd_thr` will be ignored.
-
-    Args:
-        proposals (Tensor): Proposals or RPN anchors, shape (n, 4).
-        gt_bboxes (Tensor): Ground truth bboxes, shape (k, 4).
-        gt_bboxes_ignore (Tensor, optional): shape(m, 4).
-        gt_labels (Tensor, optional): shape (k, ).
-        pos_iou_thr (float): IoU threshold for positive bboxes.
-        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
-        min_pos_iou (float): Minimum iou for a bbox to be considered as a
-            positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
-            it is usually set as pos_iou_thr
-        crowd_thr (float): IoF threshold for ignoring bboxes. Negative value
-            for not ignoring any bboxes.
-
-    Returns:
-        tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, )
-    """
-
-    # calculate overlaps between the proposals and the gt boxes
-    overlaps = bbox_overlaps(proposals, gt_bboxes)
-    if overlaps.numel() == 0:
-        raise ValueError('No gt bbox or proposals')
-
-    # ignore proposals according to crowd bboxes
-    if (crowd_thr > 0) and (gt_bboxes_ignore is
-                            not None) and (gt_bboxes_ignore.numel() > 0):
-        crowd_overlaps = bbox_overlaps(proposals, gt_bboxes_ignore, mode='iof')
-        crowd_max_overlaps, _ = crowd_overlaps.max(dim=1)
-        crowd_bboxes_inds = torch.nonzero(
-            crowd_max_overlaps > crowd_thr).long()
-        if crowd_bboxes_inds.numel() > 0:
-            overlaps[crowd_bboxes_inds, :] = -1
-
-    return bbox_assign_wrt_overlaps(overlaps, gt_labels, pos_iou_thr,
-                                    neg_iou_thr, min_pos_iou)
-
-
-def bbox_assign_wrt_overlaps(overlaps,
-                             gt_labels=None,
-                             pos_iou_thr=0.5,
-                             neg_iou_thr=0.5,
-                             min_pos_iou=.0):
-    """Assign a corresponding gt bbox or background to each proposal/anchor.
-
-    This method assign a gt bbox to every proposal, each proposals will be
-    assigned with -1, 0, or a positive number. -1 means don't care, 0 means
-    negative sample, positive number is the index (1-based) of assigned gt.
-    The assignment is done in following steps, the order matters:
-
-    1. assign every anchor to -1
-    2. assign proposals whose iou with all gts < neg_iou_thr to 0
-    3. for each anchor, if the iou with its nearest gt >= pos_iou_thr,
-    assign it to that bbox
-    4. for each gt bbox, assign its nearest proposals(may be more than one)
-    to itself
-
-    Args:
-        overlaps (Tensor): Overlaps between n proposals and k gt_bboxes,
-            shape(n, k).
-        gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
-        pos_iou_thr (float): IoU threshold for positive bboxes.
-        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
-        min_pos_iou (float): Minimum IoU for a bbox to be considered as a
-            positive bbox. This argument only affects the 4th step.
-
-    Returns:
-        tuple: (assigned_gt_inds, [assigned_labels], argmax_overlaps,
-            max_overlaps), shape (n, )
-    """
-    num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)
-    # 1. assign -1 by default
-    assigned_gt_inds = overlaps.new(num_bboxes).long().fill_(-1)
-
-    if overlaps.numel() == 0:
-        raise ValueError('No gt bbox or proposals')
-
-    assert overlaps.size() == (num_bboxes, num_gts)
-    # for each anchor, which gt best overlaps with it
-    # for each anchor, the max iou of all gts
-    max_overlaps, argmax_overlaps = overlaps.max(dim=1)
-    # for each gt, which anchor best overlaps with it
-    # for each gt, the max iou of all proposals
-    gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=0)
-
-    # 2. assign negative: below
-    if isinstance(neg_iou_thr, float):
-        assigned_gt_inds[(max_overlaps >= 0)
-                         & (max_overlaps < neg_iou_thr)] = 0
-    elif isinstance(neg_iou_thr, tuple):
-        assert len(neg_iou_thr) == 2
-        assigned_gt_inds[(max_overlaps >= neg_iou_thr[0])
-                         & (max_overlaps < neg_iou_thr[1])] = 0
-
-    # 3. assign positive: above positive IoU threshold
-    pos_inds = max_overlaps >= pos_iou_thr
-    assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
-
-    # 4. assign fg: for each gt, proposals with highest IoU
-    for i in range(num_gts):
-        if gt_max_overlaps[i] >= min_pos_iou:
-            assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1
-
-    if gt_labels is None:
-        return assigned_gt_inds, argmax_overlaps, max_overlaps
-    else:
-        assigned_labels = assigned_gt_inds.new(num_bboxes).fill_(0)
-        pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze()
-        if pos_inds.numel() > 0:
-            assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
-                                                  1]
-        return assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps
-
-
-def bbox_sampling_pos(assigned_gt_inds, num_expected, balance_sampling=True):
-    """Balance sampling for positive bboxes/anchors.
-
-    1. calculate average positive num for each gt: num_per_gt
-    2. sample at most num_per_gt positives for each gt
-    3. random sampling from rest anchors if not enough fg
-    """
-    pos_inds = torch.nonzero(assigned_gt_inds > 0)
-    if pos_inds.numel() != 0:
-        pos_inds = pos_inds.squeeze(1)
-    if pos_inds.numel() <= num_expected:
-        return pos_inds
-    elif not balance_sampling:
-        return random_choice(pos_inds, num_expected)
-    else:
-        unique_gt_inds = torch.unique(assigned_gt_inds[pos_inds].cpu())
-        num_gts = len(unique_gt_inds)
-        num_per_gt = int(round(num_expected / float(num_gts)) + 1)
-        sampled_inds = []
-        for i in unique_gt_inds:
-            inds = torch.nonzero(assigned_gt_inds == i.item())
-            if inds.numel() != 0:
-                inds = inds.squeeze(1)
-            else:
-                continue
-            if len(inds) > num_per_gt:
-                inds = random_choice(inds, num_per_gt)
-            sampled_inds.append(inds)
-        sampled_inds = torch.cat(sampled_inds)
-        if len(sampled_inds) < num_expected:
-            num_extra = num_expected - len(sampled_inds)
-            extra_inds = np.array(
-                list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
-            if len(extra_inds) > num_extra:
-                extra_inds = random_choice(extra_inds, num_extra)
-            extra_inds = torch.from_numpy(extra_inds).to(
-                assigned_gt_inds.device).long()
-            sampled_inds = torch.cat([sampled_inds, extra_inds])
-        elif len(sampled_inds) > num_expected:
-            sampled_inds = random_choice(sampled_inds, num_expected)
-        return sampled_inds
-
-
-def bbox_sampling_neg(assigned_gt_inds,
-                      num_expected,
-                      max_overlaps=None,
-                      balance_thr=0,
-                      hard_fraction=0.5):
-    """Balance sampling for negative bboxes/anchors.
-
-    Negative samples are split into 2 set: hard (balance_thr <= iou <
-    neg_iou_thr) and easy(iou < balance_thr). The sampling ratio is controlled
-    by `hard_fraction`.
-    """
-    neg_inds = torch.nonzero(assigned_gt_inds == 0)
-    if neg_inds.numel() != 0:
-        neg_inds = neg_inds.squeeze(1)
-    if len(neg_inds) <= num_expected:
-        return neg_inds
-    elif balance_thr <= 0:
-        # uniform sampling among all negative samples
-        return random_choice(neg_inds, num_expected)
-    else:
-        assert max_overlaps is not None
-        max_overlaps = max_overlaps.cpu().numpy()
-        # balance sampling for negative samples
-        neg_set = set(neg_inds.cpu().numpy())
-        easy_set = set(
-            np.where(
-                np.logical_and(max_overlaps >= 0,
-                               max_overlaps < balance_thr))[0])
-        hard_set = set(np.where(max_overlaps >= balance_thr)[0])
-        easy_neg_inds = list(easy_set & neg_set)
-        hard_neg_inds = list(hard_set & neg_set)
-
-        num_expected_hard = int(num_expected * hard_fraction)
-        if len(hard_neg_inds) > num_expected_hard:
-            sampled_hard_inds = random_choice(hard_neg_inds, num_expected_hard)
-        else:
-            sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int)
-        num_expected_easy = num_expected - len(sampled_hard_inds)
-        if len(easy_neg_inds) > num_expected_easy:
-            sampled_easy_inds = random_choice(easy_neg_inds, num_expected_easy)
-        else:
-            sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int)
-        sampled_inds = np.concatenate((sampled_easy_inds, sampled_hard_inds))
-        if len(sampled_inds) < num_expected:
-            num_extra = num_expected - len(sampled_inds)
-            extra_inds = np.array(list(neg_set - set(sampled_inds)))
-            if len(extra_inds) > num_extra:
-                extra_inds = random_choice(extra_inds, num_extra)
-            sampled_inds = np.concatenate((sampled_inds, extra_inds))
-        sampled_inds = torch.from_numpy(sampled_inds).long().to(
-            assigned_gt_inds.device)
-        return sampled_inds
-
-
-def bbox_sampling(assigned_gt_inds,
-                  num_expected,
-                  pos_fraction,
-                  neg_pos_ub,
-                  pos_balance_sampling=True,
-                  max_overlaps=None,
-                  neg_balance_thr=0,
-                  neg_hard_fraction=0.5):
+class BBoxSampler(object):
     """Sample positive and negative bboxes given assigned results.
 
     Args:
-        assigned_gt_inds (Tensor): Assigned gt indices for each bbox.
-        num_expected (int): Expected total samples (pos and neg).
         pos_fraction (float): Positive sample fraction.
         neg_pos_ub (float): Negative/Positive upper bound.
-        pos_balance_sampling(bool): Whether to sample positive samples around
+        pos_balance_sampling (bool): Whether to sample positive samples around
             each gt bbox evenly.
-        max_overlaps (Tensor, optional): For each bbox, the max IoU of all gts.
-            Used for negative balance sampling only.
         neg_balance_thr (float, optional): IoU threshold for simple/hard
             negative balance sampling.
         neg_hard_fraction (float, optional): Fraction of hard negative samples
             for negative balance sampling.
-
-    Returns:
-        tuple[Tensor]: positive bbox indices, negative bbox indices.
-    """
-    num_expected_pos = int(num_expected * pos_fraction)
-    pos_inds = bbox_sampling_pos(assigned_gt_inds, num_expected_pos,
-                                 pos_balance_sampling)
-    # We found that sampled indices have duplicated items occasionally.
-    # (mab be a bug of PyTorch)
-    pos_inds = pos_inds.unique()
-    num_sampled_pos = pos_inds.numel()
-    num_neg_max = int(
-        neg_pos_ub *
-        num_sampled_pos) if num_sampled_pos > 0 else int(neg_pos_ub)
-    num_expected_neg = min(num_neg_max, num_expected - num_sampled_pos)
-    neg_inds = bbox_sampling_neg(assigned_gt_inds, num_expected_neg,
-                                 max_overlaps, neg_balance_thr,
-                                 neg_hard_fraction)
-    neg_inds = neg_inds.unique()
-    return pos_inds, neg_inds
-
-
-def sample_bboxes(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
-    """Sample positive and negative bboxes.
-
-    This is a simple implementation of bbox sampling given candidates and
-    ground truth bboxes, which includes 3 steps.
-
-    1. Assign gt to each bbox.
-    2. Add gt bboxes to the sampling pool (optional).
-    3. Perform positive and negative sampling.
-
-    Args:
-        bboxes (Tensor): Boxes to be sampled from.
-        gt_bboxes (Tensor): Ground truth bboxes.
-        gt_bboxes_ignore (Tensor): Ignored ground truth bboxes. In MS COCO,
-            `crowd` bboxes are considered as ignored.
-        gt_labels (Tensor): Class labels of ground truth bboxes.
-        cfg (dict): Sampling configs.
-
-    Returns:
-        tuple[Tensor]: pos_bboxes, neg_bboxes, pos_assigned_gt_inds,
-            pos_gt_bboxes, pos_gt_labels
     """
-    bboxes = bboxes[:, :4]
-    assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps = \
-        bbox_assign(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels,
-                    cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.min_pos_iou,
-                    cfg.crowd_thr)
-
-    if cfg.add_gt_as_proposals:
-        bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
-        gt_assign_self = torch.arange(
-            1, len(gt_labels) + 1, dtype=torch.long, device=bboxes.device)
-        assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
-        assigned_labels = torch.cat([gt_labels, assigned_labels])
 
-    pos_inds, neg_inds = bbox_sampling(
-        assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction, cfg.neg_pos_ub,
-        cfg.pos_balance_sampling, max_overlaps, cfg.neg_balance_thr)
-
-    pos_bboxes = bboxes[pos_inds]
-    neg_bboxes = bboxes[neg_inds]
-    pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
-    pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
-    pos_gt_labels = assigned_labels[pos_inds]
+    def __init__(self,
+                 num,
+                 pos_fraction,
+                 neg_pos_ub=-1,
+                 add_gt_as_proposals=True,
+                 pos_balance_sampling=False,
+                 neg_balance_thr=0,
+                 neg_hard_fraction=0.5):
+        self.num = num
+        self.pos_fraction = pos_fraction
+        self.neg_pos_ub = neg_pos_ub
+        self.add_gt_as_proposals = add_gt_as_proposals
+        self.pos_balance_sampling = pos_balance_sampling
+        self.neg_balance_thr = neg_balance_thr
+        self.neg_hard_fraction = neg_hard_fraction
+
+    def _sample_pos(self, assign_result, num_expected):
+        """Balance sampling for positive bboxes/anchors.
+
+        1. calculate average positive num for each gt: num_per_gt
+        2. sample at most num_per_gt positives for each gt
+        3. random sampling from rest anchors if not enough fg
+        """
+        pos_inds = torch.nonzero(assign_result.gt_inds > 0)
+        if pos_inds.numel() != 0:
+            pos_inds = pos_inds.squeeze(1)
+        if pos_inds.numel() <= num_expected:
+            return pos_inds
+        elif not self.pos_balance_sampling:
+            return random_choice(pos_inds, num_expected)
+        else:
+            unique_gt_inds = torch.unique(
+                assign_result.gt_inds[pos_inds].cpu())
+            num_gts = len(unique_gt_inds)
+            num_per_gt = int(round(num_expected / float(num_gts)) + 1)
+            sampled_inds = []
+            for i in unique_gt_inds:
+                inds = torch.nonzero(assign_result.gt_inds == i.item())
+                if inds.numel() != 0:
+                    inds = inds.squeeze(1)
+                else:
+                    continue
+                if len(inds) > num_per_gt:
+                    inds = random_choice(inds, num_per_gt)
+                sampled_inds.append(inds)
+            sampled_inds = torch.cat(sampled_inds)
+            if len(sampled_inds) < num_expected:
+                num_extra = num_expected - len(sampled_inds)
+                extra_inds = np.array(
+                    list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
+                if len(extra_inds) > num_extra:
+                    extra_inds = random_choice(extra_inds, num_extra)
+                extra_inds = torch.from_numpy(extra_inds).to(
+                    assign_result.gt_inds.device).long()
+                sampled_inds = torch.cat([sampled_inds, extra_inds])
+            elif len(sampled_inds) > num_expected:
+                sampled_inds = random_choice(sampled_inds, num_expected)
+            return sampled_inds
+
+    def _sample_neg(self, assign_result, num_expected):
+        """Balance sampling for negative bboxes/anchors.
+
+        Negative samples are split into 2 set: hard (balance_thr <= iou <
+        neg_iou_thr) and easy (iou < balance_thr). The sampling ratio is
+        controlled by `hard_fraction`.
+        """
+        neg_inds = torch.nonzero(assign_result.gt_inds == 0)
+        if neg_inds.numel() != 0:
+            neg_inds = neg_inds.squeeze(1)
+        if len(neg_inds) <= num_expected:
+            return neg_inds
+        elif self.neg_balance_thr <= 0:
+            # uniform sampling among all negative samples
+            return random_choice(neg_inds, num_expected)
+        else:
+            max_overlaps = assign_result.max_overlaps.cpu().numpy()
+            # balance sampling for negative samples
+            neg_set = set(neg_inds.cpu().numpy())
+            easy_set = set(
+                np.where(
+                    np.logical_and(max_overlaps >= 0,
+                                   max_overlaps < self.neg_balance_thr))[0])
+            hard_set = set(np.where(max_overlaps >= self.neg_balance_thr)[0])
+            easy_neg_inds = list(easy_set & neg_set)
+            hard_neg_inds = list(hard_set & neg_set)
+
+            num_expected_hard = int(num_expected * self.neg_hard_fraction)
+            if len(hard_neg_inds) > num_expected_hard:
+                sampled_hard_inds = random_choice(hard_neg_inds,
+                                                  num_expected_hard)
+            else:
+                sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int)
+            num_expected_easy = num_expected - len(sampled_hard_inds)
+            if len(easy_neg_inds) > num_expected_easy:
+                sampled_easy_inds = random_choice(easy_neg_inds,
+                                                  num_expected_easy)
+            else:
+                sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int)
+            sampled_inds = np.concatenate((sampled_easy_inds,
+                                           sampled_hard_inds))
+            if len(sampled_inds) < num_expected:
+                num_extra = num_expected - len(sampled_inds)
+                extra_inds = np.array(list(neg_set - set(sampled_inds)))
+                if len(extra_inds) > num_extra:
+                    extra_inds = random_choice(extra_inds, num_extra)
+                sampled_inds = np.concatenate((sampled_inds, extra_inds))
+            sampled_inds = torch.from_numpy(sampled_inds).long().to(
+                assign_result.gt_inds.device)
+            return sampled_inds
+
+    def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None):
+        """Sample positive and negative bboxes.
+
+        This is a simple implementation of bbox sampling given candidates,
+        assigning results and ground truth bboxes.
+
+        1. Assign gt to each bbox.
+        2. Add gt bboxes to the sampling pool (optional).
+        3. Perform positive and negative sampling.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Bbox assigning results.
+            bboxes (Tensor): Boxes to be sampled from.
+            gt_bboxes (Tensor): Ground truth bboxes.
+            gt_labels (Tensor, optional): Class labels of ground truth bboxes.
+
+        Returns:
+            :obj:`SamplingResult`: Sampling result.
+        """
+        bboxes = bboxes[:, :4]
+
+        gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
+        if self.add_gt_as_proposals:
+            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
+            assign_result.add_gt_(gt_labels)
+            gt_flags = torch.cat([
+                bboxes.new_ones((gt_bboxes.shape[0], ), dtype=torch.uint8),
+                gt_flags
+            ])
+
+        num_expected_pos = int(self.num * self.pos_fraction)
+        pos_inds = self._sample_pos(assign_result, num_expected_pos)
+        # We found that sampled indices have duplicated items occasionally.
+        # (mab be a bug of PyTorch)
+        pos_inds = pos_inds.unique()
+        num_sampled_pos = pos_inds.numel()
+        num_expected_neg = self.num - num_sampled_pos
+        if self.neg_pos_ub >= 0:
+            num_neg_max = int(self.neg_pos_ub *
+                              num_sampled_pos) if num_sampled_pos > 0 else int(
+                                  self.neg_pos_ub)
+            num_expected_neg = min(num_neg_max, num_expected_neg)
+        neg_inds = self._sample_neg(assign_result, num_expected_neg)
+        neg_inds = neg_inds.unique()
+
+        return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+                              assign_result, gt_flags)
+
+
+class SamplingResult(object):
+
+    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
+                 gt_flags):
+        self.pos_inds = pos_inds
+        self.neg_inds = neg_inds
+        self.pos_bboxes = bboxes[pos_inds]
+        self.neg_bboxes = bboxes[neg_inds]
+        self.pos_is_gt = gt_flags[pos_inds]
+
+        self.num_gts = gt_bboxes.shape[0]
+        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+        self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
+        if assign_result.labels is not None:
+            self.pos_gt_labels = assign_result.labels[pos_inds]
+        else:
+            self.pos_gt_labels = None
 
-    return (pos_bboxes, neg_bboxes, pos_assigned_gt_inds, pos_gt_bboxes,
-            pos_gt_labels)
+    @property
+    def bboxes(self):
+        return torch.cat([self.pos_bboxes, self.neg_bboxes])
diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py
index 0ee92f9..fbb14aa 100644
--- a/mmdet/datasets/coco.py
+++ b/mmdet/datasets/coco.py
@@ -215,7 +215,7 @@ class CocoDataset(Dataset):
                         'proposals should have shapes (n, 4) or (n, 5), '
                         'but found {}'.format(proposals.shape))
                 if proposals.shape[1] == 5:
-                    scores = proposals[:, 4]
+                    scores = proposals[:, 4, None]
                     proposals = proposals[:, :4]
                 else:
                     scores = None
@@ -237,8 +237,8 @@ class CocoDataset(Dataset):
             if self.proposals is not None:
                 proposals = self.bbox_transform(proposals, img_shape,
                                                 scale_factor, flip)
-                proposals = np.hstack([proposals, scores[:, None]
-                                       ]) if scores is not None else proposals
+                proposals = np.hstack(
+                    [proposals, scores]) if scores is not None else proposals
             gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor,
                                             flip)
             gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape,
@@ -295,14 +295,14 @@ class CocoDataset(Dataset):
                 flip=flip)
             if proposal is not None:
                 if proposal.shape[1] == 5:
-                    score = proposal[:, 4]
+                    score = proposal[:, 4, None]
                     proposal = proposal[:, :4]
                 else:
                     score = None
                 _proposal = self.bbox_transform(proposal, img_shape,
                                                 scale_factor, flip)
-                _proposal = np.hstack([_proposal, score[:, None]
-                                       ]) if score is not None else _proposal
+                _proposal = np.hstack(
+                    [_proposal, score]) if score is not None else _proposal
                 _proposal = to_tensor(_proposal)
             else:
                 _proposal = None
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index 67dba03..9b423bd 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -59,16 +59,20 @@ class BBoxHead(nn.Module):
         bbox_pred = self.fc_reg(x) if self.with_reg else None
         return cls_score, bbox_pred
 
-    def get_bbox_target(self, pos_proposals, neg_proposals, pos_gt_bboxes,
-                        pos_gt_labels, rcnn_train_cfg):
-        reg_num_classes = 1 if self.reg_class_agnostic else self.num_classes
+    def get_target(self, sampling_results, gt_bboxes, gt_labels,
+                   rcnn_train_cfg):
+        pos_proposals = [res.pos_bboxes for res in sampling_results]
+        neg_proposals = [res.neg_bboxes for res in sampling_results]
+        pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
+        pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
+        reg_classes = 1 if self.reg_class_agnostic else self.num_classes
         cls_reg_targets = bbox_target(
             pos_proposals,
             neg_proposals,
             pos_gt_bboxes,
             pos_gt_labels,
             rcnn_train_cfg,
-            reg_num_classes,
+            reg_classes,
             target_means=self.target_means,
             target_stds=self.target_stds)
         return cls_reg_targets
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
index 48a818d..064ee0e 100644
--- a/mmdet/models/detectors/two_stage.py
+++ b/mmdet/models/detectors/two_stage.py
@@ -4,7 +4,7 @@ import torch.nn as nn
 from .base import BaseDetector
 from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
 from .. import builder
-from mmdet.core import sample_bboxes, bbox2roi, bbox2result, multi_apply
+from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply)
 
 
 class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
@@ -80,10 +80,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       gt_labels,
                       gt_masks=None,
                       proposals=None):
-        losses = dict()
-
         x = self.extract_feat(img)
 
+        losses = dict()
+
+        # RPN forward and loss
         if self.with_rpn:
             rpn_outs = self.rpn_head(x)
             rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
@@ -96,44 +97,43 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         else:
             proposal_list = proposals
 
+        # assign gts and sample proposals
+        if self.with_bbox or self.with_mask:
+            assign_results, sampling_results = multi_apply(
+                assign_and_sample,
+                proposal_list,
+                gt_bboxes,
+                gt_bboxes_ignore,
+                gt_labels,
+                cfg=self.train_cfg.rcnn)
+
+        # bbox head forward and loss
         if self.with_bbox:
-            (pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes,
-             pos_gt_labels) = multi_apply(
-                 sample_bboxes,
-                 proposal_list,
-                 gt_bboxes,
-                 gt_bboxes_ignore,
-                 gt_labels,
-                 cfg=self.train_cfg.rcnn)
-            (labels, label_weights, bbox_targets,
-             bbox_weights) = self.bbox_head.get_bbox_target(
-                 pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
-                 self.train_cfg.rcnn)
-
-            rois = bbox2roi([
-                torch.cat([pos, neg], dim=0)
-                for pos, neg in zip(pos_proposals, neg_proposals)
-            ])
-            # TODO: a more flexible way to configurate feat maps
-            roi_feats = self.bbox_roi_extractor(
+            rois = bbox2roi([res.bboxes for res in sampling_results])
+            # TODO: a more flexible way to decide which feature maps to use
+            bbox_feats = self.bbox_roi_extractor(
                 x[:self.bbox_roi_extractor.num_inputs], rois)
-            cls_score, bbox_pred = self.bbox_head(roi_feats)
+            cls_score, bbox_pred = self.bbox_head(bbox_feats)
 
-            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels,
-                                            label_weights, bbox_targets,
-                                            bbox_weights)
+            bbox_targets = self.bbox_head.get_target(
+                sampling_results, gt_bboxes, gt_labels, self.train_cfg.rcnn)
+            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
+                                            *bbox_targets)
             losses.update(loss_bbox)
 
+        # mask head forward and loss
         if self.with_mask:
-            mask_targets = self.mask_head.get_mask_target(
-                pos_proposals, pos_assigned_gt_inds, gt_masks,
-                self.train_cfg.rcnn)
-            pos_rois = bbox2roi(pos_proposals)
+            pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
             mask_feats = self.mask_roi_extractor(
                 x[:self.mask_roi_extractor.num_inputs], pos_rois)
             mask_pred = self.mask_head(mask_feats)
+
+            mask_targets = self.mask_head.get_target(
+                sampling_results, gt_masks, self.train_cfg.rcnn)
+            pos_labels = torch.cat(
+                [res.pos_gt_labels for res in sampling_results])
             loss_mask = self.mask_head.loss(mask_pred, mask_targets,
-                                            torch.cat(pos_gt_labels))
+                                            pos_labels)
             losses.update(loss_mask)
 
         return losses
@@ -145,8 +145,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         x = self.extract_feat(img)
 
         proposal_list = self.simple_test_rpn(
-            x, img_meta,
-            self.test_cfg.rpn) if proposals is None else proposals
+            x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
 
         det_bboxes, det_labels = self.simple_test_bboxes(
             x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py
index ba46bea..2c90a03 100644
--- a/mmdet/models/mask_heads/fcn_mask_head.py
+++ b/mmdet/models/mask_heads/fcn_mask_head.py
@@ -86,8 +86,11 @@ class FCNMaskHead(nn.Module):
         mask_pred = self.conv_logits(x)
         return mask_pred
 
-    def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks,
-                        rcnn_train_cfg):
+    def get_target(self, sampling_results, gt_masks, rcnn_train_cfg):
+        pos_proposals = [res.pos_bboxes for res in sampling_results]
+        pos_assigned_gt_inds = [
+            res.pos_assigned_gt_inds for res in sampling_results
+        ]
         mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
                                    gt_masks, rcnn_train_cfg)
         return mask_targets
-- 
GitLab