From e3c1b8550c7b6607f68b7a89b84cc7aa4a777d85 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Thu, 6 Dec 2018 21:45:13 +0800
Subject: [PATCH] refactoring for sampler and assigner

---
 configs/cascade_mask_rcnn_r101_fpn_1x.py      |  24 +-
 configs/cascade_mask_rcnn_r50_fpn_1x.py       |  24 +-
 configs/cascade_rcnn_r101_fpn_1x.py           |  24 +-
 configs/cascade_rcnn_r50_fpn_1x.py            |  24 +-
 configs/fast_mask_rcnn_r101_fpn_1x.py         |   6 +-
 configs/fast_mask_rcnn_r50_fpn_1x.py          |   6 +-
 configs/fast_rcnn_r101_fpn_1x.py              |   6 +-
 configs/fast_rcnn_r50_fpn_1x.py               |   6 +-
 configs/faster_rcnn_r101_fpn_1x.py            |  12 +-
 configs/faster_rcnn_r50_fpn_1x.py             |  12 +-
 configs/mask_rcnn_r101_fpn_1x.py              |  12 +-
 configs/mask_rcnn_r50_fpn_1x.py               |  12 +-
 configs/retinanet_r101_fpn_1x.py              |   6 +-
 configs/retinanet_r50_fpn_1x.py               |   6 +-
 configs/rpn_r101_fpn_1x.py                    |   6 +-
 configs/rpn_r50_fpn_1x.py                     |   6 +-
 mmdet/core/anchor/anchor_target.py            |  14 +-
 mmdet/core/bbox/__init__.py                   |  18 +-
 mmdet/core/bbox/assign_sampling.py            |  35 +++
 mmdet/core/bbox/assigners/__init__.py         |   5 +
 mmdet/core/bbox/assigners/assign_result.py    |  19 ++
 mmdet/core/bbox/assigners/base_assigner.py    |   8 +
 .../max_iou_assigner.py}                      |  24 +-
 mmdet/core/bbox/samplers/__init__.py          |  13 +
 mmdet/core/bbox/samplers/base_sampler.py      |  64 +++++
 mmdet/core/bbox/samplers/combined_sampler.py  |  16 ++
 .../samplers/instance_balanced_pos_sampler.py |  41 ++++
 .../bbox/samplers/iou_balanced_neg_sampler.py |  62 +++++
 mmdet/core/bbox/samplers/pseudo_sampler.py    |  26 ++
 mmdet/core/bbox/samplers/random_sampler.py    |  55 +++++
 mmdet/core/bbox/samplers/sampling_result.py   |  24 ++
 mmdet/core/bbox/sampling.py                   | 227 ------------------
 32 files changed, 488 insertions(+), 355 deletions(-)
 create mode 100644 mmdet/core/bbox/assign_sampling.py
 create mode 100644 mmdet/core/bbox/assigners/__init__.py
 create mode 100644 mmdet/core/bbox/assigners/assign_result.py
 create mode 100644 mmdet/core/bbox/assigners/base_assigner.py
 rename mmdet/core/bbox/{assignment.py => assigners/max_iou_assigner.py} (88%)
 create mode 100644 mmdet/core/bbox/samplers/__init__.py
 create mode 100644 mmdet/core/bbox/samplers/base_sampler.py
 create mode 100644 mmdet/core/bbox/samplers/combined_sampler.py
 create mode 100644 mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
 create mode 100644 mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
 create mode 100644 mmdet/core/bbox/samplers/pseudo_sampler.py
 create mode 100644 mmdet/core/bbox/samplers/random_sampler.py
 create mode 100644 mmdet/core/bbox/samplers/sampling_result.py
 delete mode 100644 mmdet/core/bbox/sampling.py

diff --git a/configs/cascade_mask_rcnn_r101_fpn_1x.py b/configs/cascade_mask_rcnn_r101_fpn_1x.py
index 2cb289e..4466613 100644
--- a/configs/cascade_mask_rcnn_r101_fpn_1x.py
+++ b/configs/cascade_mask_rcnn_r101_fpn_1x.py
@@ -77,17 +77,17 @@ model = dict(
 train_cfg = dict(
     rpn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.7,
             neg_iou_thr=0.3,
             min_pos_iou=0.3,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=256,
             pos_fraction=0.5,
             neg_pos_ub=-1,
-            add_gt_as_proposals=False,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=False),
         allowed_border=0,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
@@ -95,49 +95,49 @@ train_cfg = dict(
     rcnn=[
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.5,
                 neg_iou_thr=0.5,
                 min_pos_iou=0.5,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             mask_size=28,
             pos_weight=-1,
             debug=False),
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.6,
                 neg_iou_thr=0.6,
                 min_pos_iou=0.6,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             mask_size=28,
             pos_weight=-1,
             debug=False),
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.7,
                 neg_iou_thr=0.7,
                 min_pos_iou=0.7,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             mask_size=28,
             pos_weight=-1,
             debug=False)
diff --git a/configs/cascade_mask_rcnn_r50_fpn_1x.py b/configs/cascade_mask_rcnn_r50_fpn_1x.py
index 538b468..af39dc6 100644
--- a/configs/cascade_mask_rcnn_r50_fpn_1x.py
+++ b/configs/cascade_mask_rcnn_r50_fpn_1x.py
@@ -77,17 +77,17 @@ model = dict(
 train_cfg = dict(
     rpn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.7,
             neg_iou_thr=0.3,
             min_pos_iou=0.3,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=256,
             pos_fraction=0.5,
             neg_pos_ub=-1,
-            add_gt_as_proposals=False,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=False),
         allowed_border=0,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
@@ -95,49 +95,49 @@ train_cfg = dict(
     rcnn=[
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.5,
                 neg_iou_thr=0.5,
                 min_pos_iou=0.5,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             mask_size=28,
             pos_weight=-1,
             debug=False),
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.6,
                 neg_iou_thr=0.6,
                 min_pos_iou=0.6,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             mask_size=28,
             pos_weight=-1,
             debug=False),
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.7,
                 neg_iou_thr=0.7,
                 min_pos_iou=0.7,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             mask_size=28,
             pos_weight=-1,
             debug=False)
diff --git a/configs/cascade_rcnn_r101_fpn_1x.py b/configs/cascade_rcnn_r101_fpn_1x.py
index e30370b..ccacc01 100644
--- a/configs/cascade_rcnn_r101_fpn_1x.py
+++ b/configs/cascade_rcnn_r101_fpn_1x.py
@@ -66,17 +66,17 @@ model = dict(
 train_cfg = dict(
     rpn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.7,
             neg_iou_thr=0.3,
             min_pos_iou=0.3,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=256,
             pos_fraction=0.5,
             neg_pos_ub=-1,
-            add_gt_as_proposals=False,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=False),
         allowed_border=0,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
@@ -84,47 +84,47 @@ train_cfg = dict(
     rcnn=[
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.5,
                 neg_iou_thr=0.5,
                 min_pos_iou=0.5,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             pos_weight=-1,
             debug=False),
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.6,
                 neg_iou_thr=0.6,
                 min_pos_iou=0.6,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             pos_weight=-1,
             debug=False),
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.7,
                 neg_iou_thr=0.7,
                 min_pos_iou=0.7,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             pos_weight=-1,
             debug=False)
     ],
diff --git a/configs/cascade_rcnn_r50_fpn_1x.py b/configs/cascade_rcnn_r50_fpn_1x.py
index 69a2e52..75d9edd 100644
--- a/configs/cascade_rcnn_r50_fpn_1x.py
+++ b/configs/cascade_rcnn_r50_fpn_1x.py
@@ -66,17 +66,17 @@ model = dict(
 train_cfg = dict(
     rpn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.7,
             neg_iou_thr=0.3,
             min_pos_iou=0.3,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=256,
             pos_fraction=0.5,
             neg_pos_ub=-1,
-            add_gt_as_proposals=False,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=False),
         allowed_border=0,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
@@ -84,47 +84,47 @@ train_cfg = dict(
     rcnn=[
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.5,
                 neg_iou_thr=0.5,
                 min_pos_iou=0.5,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             pos_weight=-1,
             debug=False),
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.6,
                 neg_iou_thr=0.6,
                 min_pos_iou=0.6,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             pos_weight=-1,
             debug=False),
         dict(
             assigner=dict(
+                type='MaxIoUAssigner',
                 pos_iou_thr=0.7,
                 neg_iou_thr=0.7,
                 min_pos_iou=0.7,
                 ignore_iof_thr=-1),
             sampler=dict(
+                type='RandomSampler',
                 num=512,
                 pos_fraction=0.25,
                 neg_pos_ub=-1,
-                add_gt_as_proposals=True,
-                pos_balance_sampling=False,
-                neg_balance_thr=0),
+                add_gt_as_proposals=True),
             pos_weight=-1,
             debug=False)
     ],
diff --git a/configs/fast_mask_rcnn_r101_fpn_1x.py b/configs/fast_mask_rcnn_r101_fpn_1x.py
index 342d775..fa64d6f 100644
--- a/configs/fast_mask_rcnn_r101_fpn_1x.py
+++ b/configs/fast_mask_rcnn_r101_fpn_1x.py
@@ -44,17 +44,17 @@ model = dict(
 train_cfg = dict(
     rcnn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.5,
             neg_iou_thr=0.5,
             min_pos_iou=0.5,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=512,
             pos_fraction=0.25,
             neg_pos_ub=-1,
-            add_gt_as_proposals=True,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=True),
         mask_size=28,
         pos_weight=-1,
         debug=False))
diff --git a/configs/fast_mask_rcnn_r50_fpn_1x.py b/configs/fast_mask_rcnn_r50_fpn_1x.py
index 8863ba6..2005100 100644
--- a/configs/fast_mask_rcnn_r50_fpn_1x.py
+++ b/configs/fast_mask_rcnn_r50_fpn_1x.py
@@ -44,17 +44,17 @@ model = dict(
 train_cfg = dict(
     rcnn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.5,
             neg_iou_thr=0.5,
             min_pos_iou=0.5,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=512,
             pos_fraction=0.25,
             neg_pos_ub=-1,
-            add_gt_as_proposals=True,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=True),
         mask_size=28,
         pos_weight=-1,
         debug=False))
diff --git a/configs/fast_rcnn_r101_fpn_1x.py b/configs/fast_rcnn_r101_fpn_1x.py
index 66b7e8c..c61b74f 100644
--- a/configs/fast_rcnn_r101_fpn_1x.py
+++ b/configs/fast_rcnn_r101_fpn_1x.py
@@ -33,17 +33,17 @@ model = dict(
 train_cfg = dict(
     rcnn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.5,
             neg_iou_thr=0.5,
             min_pos_iou=0.5,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=512,
             pos_fraction=0.25,
             neg_pos_ub=-1,
-            add_gt_as_proposals=True,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=True),
         pos_weight=-1,
         debug=False))
 test_cfg = dict(
diff --git a/configs/fast_rcnn_r50_fpn_1x.py b/configs/fast_rcnn_r50_fpn_1x.py
index 57394bc..542e2dd 100644
--- a/configs/fast_rcnn_r50_fpn_1x.py
+++ b/configs/fast_rcnn_r50_fpn_1x.py
@@ -33,17 +33,17 @@ model = dict(
 train_cfg = dict(
     rcnn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.5,
             neg_iou_thr=0.5,
             min_pos_iou=0.5,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=512,
             pos_fraction=0.25,
             neg_pos_ub=-1,
-            add_gt_as_proposals=True,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=True),
         pos_weight=-1,
         debug=False))
 test_cfg = dict(
diff --git a/configs/faster_rcnn_r101_fpn_1x.py b/configs/faster_rcnn_r101_fpn_1x.py
index 49813d1..2ff48c5 100644
--- a/configs/faster_rcnn_r101_fpn_1x.py
+++ b/configs/faster_rcnn_r101_fpn_1x.py
@@ -43,34 +43,34 @@ model = dict(
 train_cfg = dict(
     rpn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.7,
             neg_iou_thr=0.3,
             min_pos_iou=0.3,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=256,
             pos_fraction=0.5,
             neg_pos_ub=-1,
-            add_gt_as_proposals=False,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=False),
         allowed_border=0,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
         debug=False),
     rcnn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.5,
             neg_iou_thr=0.5,
             min_pos_iou=0.5,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=512,
             pos_fraction=0.25,
             neg_pos_ub=-1,
-            add_gt_as_proposals=True,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=True),
         pos_weight=-1,
         debug=False))
 test_cfg = dict(
diff --git a/configs/faster_rcnn_r50_fpn_1x.py b/configs/faster_rcnn_r50_fpn_1x.py
index 97899af..e88e348 100644
--- a/configs/faster_rcnn_r50_fpn_1x.py
+++ b/configs/faster_rcnn_r50_fpn_1x.py
@@ -43,34 +43,34 @@ model = dict(
 train_cfg = dict(
     rpn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.7,
             neg_iou_thr=0.3,
             min_pos_iou=0.3,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=256,
             pos_fraction=0.5,
             neg_pos_ub=-1,
-            add_gt_as_proposals=False,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=False),
         allowed_border=0,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
         debug=False),
     rcnn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.5,
             neg_iou_thr=0.5,
             min_pos_iou=0.5,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=512,
             pos_fraction=0.25,
             neg_pos_ub=-1,
-            add_gt_as_proposals=True,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=True),
         pos_weight=-1,
         debug=False))
 test_cfg = dict(
diff --git a/configs/mask_rcnn_r101_fpn_1x.py b/configs/mask_rcnn_r101_fpn_1x.py
index 65a93e2..3675a80 100644
--- a/configs/mask_rcnn_r101_fpn_1x.py
+++ b/configs/mask_rcnn_r101_fpn_1x.py
@@ -54,34 +54,34 @@ model = dict(
 train_cfg = dict(
     rpn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.7,
             neg_iou_thr=0.3,
             min_pos_iou=0.3,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=256,
             pos_fraction=0.5,
             neg_pos_ub=-1,
-            add_gt_as_proposals=False,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=False),
         allowed_border=0,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
         debug=False),
     rcnn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.5,
             neg_iou_thr=0.5,
             min_pos_iou=0.5,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=512,
             pos_fraction=0.25,
             neg_pos_ub=-1,
-            add_gt_as_proposals=True,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=True),
         mask_size=28,
         pos_weight=-1,
         debug=False))
diff --git a/configs/mask_rcnn_r50_fpn_1x.py b/configs/mask_rcnn_r50_fpn_1x.py
index c2ef8fa..364944f 100644
--- a/configs/mask_rcnn_r50_fpn_1x.py
+++ b/configs/mask_rcnn_r50_fpn_1x.py
@@ -54,34 +54,34 @@ model = dict(
 train_cfg = dict(
     rpn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.7,
             neg_iou_thr=0.3,
             min_pos_iou=0.3,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=256,
             pos_fraction=0.5,
             neg_pos_ub=-1,
-            add_gt_as_proposals=False,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=False),
         allowed_border=0,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
         debug=False),
     rcnn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.5,
             neg_iou_thr=0.5,
             min_pos_iou=0.5,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=512,
             pos_fraction=0.25,
             neg_pos_ub=-1,
-            add_gt_as_proposals=True,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=True),
         mask_size=28,
         pos_weight=-1,
         debug=False))
diff --git a/configs/retinanet_r101_fpn_1x.py b/configs/retinanet_r101_fpn_1x.py
index fd4dd9a..e07d98a 100644
--- a/configs/retinanet_r101_fpn_1x.py
+++ b/configs/retinanet_r101_fpn_1x.py
@@ -31,7 +31,11 @@ model = dict(
 # training and testing settings
 train_cfg = dict(
     assigner=dict(
-        pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0, ignore_iof_thr=-1),
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.4,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
     smoothl1_beta=0.11,
     gamma=2.0,
     alpha=0.25,
diff --git a/configs/retinanet_r50_fpn_1x.py b/configs/retinanet_r50_fpn_1x.py
index 77f67de..2840c06 100644
--- a/configs/retinanet_r50_fpn_1x.py
+++ b/configs/retinanet_r50_fpn_1x.py
@@ -31,7 +31,11 @@ model = dict(
 # training and testing settings
 train_cfg = dict(
     assigner=dict(
-        pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0, ignore_iof_thr=-1),
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.4,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
     smoothl1_beta=0.11,
     gamma=2.0,
     alpha=0.25,
diff --git a/configs/rpn_r101_fpn_1x.py b/configs/rpn_r101_fpn_1x.py
index 90f260d..450215e 100644
--- a/configs/rpn_r101_fpn_1x.py
+++ b/configs/rpn_r101_fpn_1x.py
@@ -28,17 +28,17 @@ model = dict(
 train_cfg = dict(
     rpn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.7,
             neg_iou_thr=0.3,
             min_pos_iou=0.3,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=256,
             pos_fraction=0.5,
             neg_pos_ub=-1,
-            add_gt_as_proposals=False,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=False),
         allowed_border=0,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
diff --git a/configs/rpn_r50_fpn_1x.py b/configs/rpn_r50_fpn_1x.py
index 8e2b402..3af2649 100644
--- a/configs/rpn_r50_fpn_1x.py
+++ b/configs/rpn_r50_fpn_1x.py
@@ -28,17 +28,17 @@ model = dict(
 train_cfg = dict(
     rpn=dict(
         assigner=dict(
+            type='MaxIoUAssigner',
             pos_iou_thr=0.7,
             neg_iou_thr=0.3,
             min_pos_iou=0.3,
             ignore_iof_thr=-1),
         sampler=dict(
+            type='RandomSampler',
             num=256,
             pos_fraction=0.5,
             neg_pos_ub=-1,
-            add_gt_as_proposals=False,
-            pos_balance_sampling=False,
-            neg_balance_thr=0),
+            add_gt_as_proposals=False),
         allowed_border=0,
         pos_weight=-1,
         smoothl1_beta=1 / 9.0,
diff --git a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py
index 41086cf..5fcdb83 100644
--- a/mmdet/core/anchor/anchor_target.py
+++ b/mmdet/core/anchor/anchor_target.py
@@ -1,6 +1,6 @@
 import torch
 
-from ..bbox import assign_and_sample, BBoxAssigner, SamplingResult, bbox2delta
+from ..bbox import assign_and_sample, build_assigner, PseudoSampler, bbox2delta
 from ..utils import multi_apply
 
 
@@ -107,16 +107,12 @@ def anchor_target_single(flat_anchors,
         assign_result, sampling_result = assign_and_sample(
             anchors, gt_bboxes, None, None, cfg)
     else:
-        bbox_assigner = BBoxAssigner(**cfg.assigner)
+        bbox_assigner = build_assigner(cfg.assigner)
         assign_result = bbox_assigner.assign(anchors, gt_bboxes, None,
                                              gt_labels)
-        pos_inds = torch.nonzero(
-            assign_result.gt_inds > 0).squeeze(-1).unique()
-        neg_inds = torch.nonzero(
-            assign_result.gt_inds == 0).squeeze(-1).unique()
-        gt_flags = anchors.new_zeros(anchors.shape[0], dtype=torch.uint8)
-        sampling_result = SamplingResult(pos_inds, neg_inds, anchors,
-                                         gt_bboxes, assign_result, gt_flags)
+        bbox_sampler = PseudoSampler()
+        sampling_result = bbox_sampler.sample(assign_result, anchors,
+                                              gt_bboxes)
 
     num_valid_anchors = anchors.shape[0]
     bbox_targets = torch.zeros_like(anchors)
diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py
index 2ed869f..496bd7a 100644
--- a/mmdet/core/bbox/__init__.py
+++ b/mmdet/core/bbox/__init__.py
@@ -1,14 +1,18 @@
 from .geometry import bbox_overlaps
-from .assignment import BBoxAssigner, AssignResult
-from .sampling import (BBoxSampler, SamplingResult, assign_and_sample,
-                       random_choice)
+from .assigners import BaseAssigner, MaxIoUAssigner, AssignResult
+from .samplers import (BaseSampler, PseudoSampler, RandomSampler,
+                       InstanceBalancedPosSampler, IoUBalancedNegSampler,
+                       CombinedSampler, SamplingResult)
+from .assign_sampling import build_assigner, build_sampler, assign_and_sample
 from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping,
                          bbox_mapping_back, bbox2roi, roi2bbox, bbox2result)
 from .bbox_target import bbox_target
 
 __all__ = [
-    'bbox_overlaps', 'BBoxAssigner', 'AssignResult', 'BBoxSampler',
-    'SamplingResult', 'assign_and_sample', 'random_choice', 'bbox2delta',
-    'delta2bbox', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi',
-    'roi2bbox', 'bbox2result', 'bbox_target'
+    'bbox_overlaps', 'BaseAssigner', 'MaxIoUAssigner', 'AssignResult',
+    'BaseSampler', 'PseudoSampler', 'RandomSampler',
+    'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
+    'SamplingResult', 'build_assigner', 'build_sampler', 'assign_and_sample',
+    'bbox2delta', 'delta2bbox', 'bbox_flip', 'bbox_mapping',
+    'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result', 'bbox_target'
 ]
diff --git a/mmdet/core/bbox/assign_sampling.py b/mmdet/core/bbox/assign_sampling.py
new file mode 100644
index 0000000..b1b199a
--- /dev/null
+++ b/mmdet/core/bbox/assign_sampling.py
@@ -0,0 +1,35 @@
+import mmcv
+
+from . import assigners, samplers
+
+
+def build_assigner(cfg, default_args=None):
+    if isinstance(cfg, assigners.BaseAssigner):
+        return cfg
+    elif isinstance(cfg, dict):
+        return mmcv.runner.obj_from_dict(
+            cfg, assigners, default_args=default_args)
+    else:
+        raise TypeError('Invalid type {} for building a sampler'.format(
+            type(cfg)))
+
+
+def build_sampler(cfg, default_args=None):
+    if isinstance(cfg, samplers.BaseSampler):
+        return cfg
+    elif isinstance(cfg, dict):
+        return mmcv.runner.obj_from_dict(
+            cfg, samplers, default_args=default_args)
+    else:
+        raise TypeError('Invalid type {} for building a sampler'.format(
+            type(cfg)))
+
+
+def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
+    bbox_assigner = build_assigner(cfg.assigner)
+    bbox_sampler = build_sampler(cfg.sampler)
+    assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore,
+                                         gt_labels)
+    sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes,
+                                          gt_labels)
+    return assign_result, sampling_result
diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py
new file mode 100644
index 0000000..40a89e9
--- /dev/null
+++ b/mmdet/core/bbox/assigners/__init__.py
@@ -0,0 +1,5 @@
+from .base_assigner import BaseAssigner
+from .max_iou_assigner import MaxIoUAssigner
+from .assign_result import AssignResult
+
+__all__ = ['BaseAssigner', 'MaxIoUAssigner', 'AssignResult']
diff --git a/mmdet/core/bbox/assigners/assign_result.py b/mmdet/core/bbox/assigners/assign_result.py
new file mode 100644
index 0000000..33c761d
--- /dev/null
+++ b/mmdet/core/bbox/assigners/assign_result.py
@@ -0,0 +1,19 @@
+import torch
+
+
+class AssignResult(object):
+
+    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
+        self.num_gts = num_gts
+        self.gt_inds = gt_inds
+        self.max_overlaps = max_overlaps
+        self.labels = labels
+
+    def add_gt_(self, gt_labels):
+        self_inds = torch.arange(
+            1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
+        self.gt_inds = torch.cat([self_inds, self.gt_inds])
+        self.max_overlaps = torch.cat(
+            [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
+        if self.labels is not None:
+            self.labels = torch.cat([gt_labels, self.labels])
diff --git a/mmdet/core/bbox/assigners/base_assigner.py b/mmdet/core/bbox/assigners/base_assigner.py
new file mode 100644
index 0000000..7bd02dc
--- /dev/null
+++ b/mmdet/core/bbox/assigners/base_assigner.py
@@ -0,0 +1,8 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BaseAssigner(metaclass=ABCMeta):
+
+    @abstractmethod
+    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+        pass
diff --git a/mmdet/core/bbox/assignment.py b/mmdet/core/bbox/assigners/max_iou_assigner.py
similarity index 88%
rename from mmdet/core/bbox/assignment.py
rename to mmdet/core/bbox/assigners/max_iou_assigner.py
index 62233af..c43db07 100644
--- a/mmdet/core/bbox/assignment.py
+++ b/mmdet/core/bbox/assigners/max_iou_assigner.py
@@ -1,9 +1,11 @@
 import torch
 
-from .geometry import bbox_overlaps
+from .base_assigner import BaseAssigner
+from .assign_result import AssignResult
+from ..geometry import bbox_overlaps
 
 
-class BBoxAssigner(object):
+class MaxIoUAssigner(BaseAssigner):
     """Assign a corresponding gt bbox or background to each bbox.
 
     Each proposals will be assigned with `-1`, `0`, or a positive integer
@@ -135,21 +137,3 @@ class BBoxAssigner(object):
 
         return AssignResult(
             num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
-
-
-class AssignResult(object):
-
-    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
-        self.num_gts = num_gts
-        self.gt_inds = gt_inds
-        self.max_overlaps = max_overlaps
-        self.labels = labels
-
-    def add_gt_(self, gt_labels):
-        self_inds = torch.arange(
-            1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
-        self.gt_inds = torch.cat([self_inds, self.gt_inds])
-        self.max_overlaps = torch.cat(
-            [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
-        if self.labels is not None:
-            self.labels = torch.cat([gt_labels, self.labels])
diff --git a/mmdet/core/bbox/samplers/__init__.py b/mmdet/core/bbox/samplers/__init__.py
new file mode 100644
index 0000000..8a1e677
--- /dev/null
+++ b/mmdet/core/bbox/samplers/__init__.py
@@ -0,0 +1,13 @@
+from .base_sampler import BaseSampler
+from .pseudo_sampler import PseudoSampler
+from .random_sampler import RandomSampler
+from .instance_balanced_pos_sampler import InstanceBalancedPosSampler
+from .iou_balanced_neg_sampler import IoUBalancedNegSampler
+from .combined_sampler import CombinedSampler
+from .sampling_result import SamplingResult
+
+__all__ = [
+    'BaseSampler', 'PseudoSampler', 'RandomSampler',
+    'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
+    'SamplingResult'
+]
diff --git a/mmdet/core/bbox/samplers/base_sampler.py b/mmdet/core/bbox/samplers/base_sampler.py
new file mode 100644
index 0000000..6ac300f
--- /dev/null
+++ b/mmdet/core/bbox/samplers/base_sampler.py
@@ -0,0 +1,64 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+
+from .sampling_result import SamplingResult
+
+
+class BaseSampler(metaclass=ABCMeta):
+
+    def __init__(self):
+        self.pos_sampler = self
+        self.neg_sampler = self
+
+    @abstractmethod
+    def _sample_pos(self, assign_result, num_expected):
+        pass
+
+    @abstractmethod
+    def _sample_neg(self, assign_result, num_expected):
+        pass
+
+    def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None):
+        """Sample positive and negative bboxes.
+
+        This is a simple implementation of bbox sampling given candidates,
+        assigning results and ground truth bboxes.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Bbox assigning results.
+            bboxes (Tensor): Boxes to be sampled from.
+            gt_bboxes (Tensor): Ground truth bboxes.
+            gt_labels (Tensor, optional): Class labels of ground truth bboxes.
+
+        Returns:
+            :obj:`SamplingResult`: Sampling result.
+        """
+        bboxes = bboxes[:, :4]
+
+        gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
+        if self.add_gt_as_proposals:
+            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
+            assign_result.add_gt_(gt_labels)
+            gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
+            gt_flags = torch.cat([gt_ones, gt_flags])
+
+        num_expected_pos = int(self.num * self.pos_fraction)
+        pos_inds = self.pos_sampler._sample_pos(assign_result,
+                                                num_expected_pos)
+        # We found that sampled indices have duplicated items occasionally.
+        # (may be a bug of PyTorch)
+        pos_inds = pos_inds.unique()
+        num_sampled_pos = pos_inds.numel()
+        num_expected_neg = self.num - num_sampled_pos
+        if self.neg_pos_ub >= 0:
+            _pos = max(1, num_sampled_pos)
+            neg_upper_bound = int(self.neg_pos_ub * _pos)
+            if num_expected_neg > neg_upper_bound:
+                num_expected_neg = neg_upper_bound
+        neg_inds = self.neg_sampler._sample_neg(assign_result,
+                                                num_expected_neg)
+        neg_inds = neg_inds.unique()
+
+        return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+                              assign_result, gt_flags)
diff --git a/mmdet/core/bbox/samplers/combined_sampler.py b/mmdet/core/bbox/samplers/combined_sampler.py
new file mode 100644
index 0000000..578abe3
--- /dev/null
+++ b/mmdet/core/bbox/samplers/combined_sampler.py
@@ -0,0 +1,16 @@
+from mmcv.runner import obj_from_dict
+
+from .random_sampler import RandomSampler
+from ..assign_sampling import build_sampler
+
+
+class CombinedSampler(RandomSampler):
+
+    def __init__(self, num, pos_fraction, pos_sampler, neg_sampler, **kwargs):
+        super(CombinedSampler, self).__init__(num, pos_fraction, **kwargs)
+        default_args = dict(num=num, pos_fraction=pos_fraction)
+        default_args.update(kwargs)
+        self.pos_sampler = build_sampler(
+            pos_sampler, default_args=default_args)
+        self.neg_sampler = build_sampler(
+            neg_sampler, default_args=default_args)
diff --git a/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py b/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
new file mode 100644
index 0000000..1d65029
--- /dev/null
+++ b/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
@@ -0,0 +1,41 @@
+import numpy as np
+import torch
+
+from .random_sampler import RandomSampler
+
+
+class InstanceBalancedPosSampler(RandomSampler):
+
+    def _sample_pos(self, assign_result, num_expected):
+        pos_inds = torch.nonzero(assign_result.gt_inds > 0)
+        if pos_inds.numel() != 0:
+            pos_inds = pos_inds.squeeze(1)
+        if pos_inds.numel() <= num_expected:
+            return pos_inds
+        else:
+            unique_gt_inds = assign_result.gt_inds[pos_inds].unique()
+            num_gts = len(unique_gt_inds)
+            num_per_gt = int(round(num_expected / float(num_gts)) + 1)
+            sampled_inds = []
+            for i in unique_gt_inds:
+                inds = torch.nonzero(assign_result.gt_inds == i.item())
+                if inds.numel() != 0:
+                    inds = inds.squeeze(1)
+                else:
+                    continue
+                if len(inds) > num_per_gt:
+                    inds = self.random_choice(inds, num_per_gt)
+                sampled_inds.append(inds)
+            sampled_inds = torch.cat(sampled_inds)
+            if len(sampled_inds) < num_expected:
+                num_extra = num_expected - len(sampled_inds)
+                extra_inds = np.array(
+                    list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
+                if len(extra_inds) > num_extra:
+                    extra_inds = self.random_choice(extra_inds, num_extra)
+                extra_inds = torch.from_numpy(extra_inds).to(
+                    assign_result.gt_inds.device).long()
+                sampled_inds = torch.cat([sampled_inds, extra_inds])
+            elif len(sampled_inds) > num_expected:
+                sampled_inds = self.random_choice(sampled_inds, num_expected)
+            return sampled_inds
diff --git a/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
new file mode 100644
index 0000000..8e593fa
--- /dev/null
+++ b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
@@ -0,0 +1,62 @@
+import numpy as np
+import torch
+
+from .random_sampler import RandomSampler
+
+
+class IoUBalancedNegSampler(RandomSampler):
+
+    def __init__(self,
+                 num,
+                 pos_fraction,
+                 hard_thr=0.1,
+                 hard_fraction=0.5,
+                 **kwargs):
+        super(IoUBalancedNegSampler, self).__init__(num, pos_fraction,
+                                                    **kwargs)
+        assert hard_thr > 0
+        assert 0 < hard_fraction < 1
+        self.hard_thr = hard_thr
+        self.hard_fraction = hard_fraction
+
+    def _sample_neg(self, assign_result, num_expected):
+        neg_inds = torch.nonzero(assign_result.gt_inds == 0)
+        if neg_inds.numel() != 0:
+            neg_inds = neg_inds.squeeze(1)
+        if len(neg_inds) <= num_expected:
+            return neg_inds
+        else:
+            max_overlaps = assign_result.max_overlaps.cpu().numpy()
+            # balance sampling for negative samples
+            neg_set = set(neg_inds.cpu().numpy())
+            easy_set = set(
+                np.where(
+                    np.logical_and(max_overlaps >= 0,
+                                   max_overlaps < self.hard_thr))[0])
+            hard_set = set(np.where(max_overlaps >= self.hard_thr)[0])
+            easy_neg_inds = list(easy_set & neg_set)
+            hard_neg_inds = list(hard_set & neg_set)
+
+            num_expected_hard = int(num_expected * self.hard_fraction)
+            if len(hard_neg_inds) > num_expected_hard:
+                sampled_hard_inds = self.random_choice(hard_neg_inds,
+                                                       num_expected_hard)
+            else:
+                sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int)
+            num_expected_easy = num_expected - len(sampled_hard_inds)
+            if len(easy_neg_inds) > num_expected_easy:
+                sampled_easy_inds = self.random_choice(easy_neg_inds,
+                                                       num_expected_easy)
+            else:
+                sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int)
+            sampled_inds = np.concatenate((sampled_easy_inds,
+                                           sampled_hard_inds))
+            if len(sampled_inds) < num_expected:
+                num_extra = num_expected - len(sampled_inds)
+                extra_inds = np.array(list(neg_set - set(sampled_inds)))
+                if len(extra_inds) > num_extra:
+                    extra_inds = self.random_choice(extra_inds, num_extra)
+                sampled_inds = np.concatenate((sampled_inds, extra_inds))
+            sampled_inds = torch.from_numpy(sampled_inds).long().to(
+                assign_result.gt_inds.device)
+            return sampled_inds
diff --git a/mmdet/core/bbox/samplers/pseudo_sampler.py b/mmdet/core/bbox/samplers/pseudo_sampler.py
new file mode 100644
index 0000000..6c7189c
--- /dev/null
+++ b/mmdet/core/bbox/samplers/pseudo_sampler.py
@@ -0,0 +1,26 @@
+import torch
+
+from .base_sampler import BaseSampler
+from .sampling_result import SamplingResult
+
+
+class PseudoSampler(BaseSampler):
+
+    def __init__(self):
+        pass
+
+    def _sample_pos(self):
+        raise NotImplementedError
+
+    def _sample_neg(self):
+        raise NotImplementedError
+
+    def sample(self, assign_result, bboxes, gt_bboxes):
+        pos_inds = torch.nonzero(
+            assign_result.gt_inds > 0).squeeze(-1).unique()
+        neg_inds = torch.nonzero(
+            assign_result.gt_inds == 0).squeeze(-1).unique()
+        gt_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8)
+        sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+                                         assign_result, gt_flags)
+        return sampling_result
diff --git a/mmdet/core/bbox/samplers/random_sampler.py b/mmdet/core/bbox/samplers/random_sampler.py
new file mode 100644
index 0000000..bd84ca6
--- /dev/null
+++ b/mmdet/core/bbox/samplers/random_sampler.py
@@ -0,0 +1,55 @@
+import numpy as np
+import torch
+
+from .base_sampler import BaseSampler
+
+
+class RandomSampler(BaseSampler):
+
+    def __init__(self,
+                 num,
+                 pos_fraction,
+                 neg_pos_ub=-1,
+                 add_gt_as_proposals=True):
+        super(RandomSampler, self).__init__()
+        self.num = num
+        self.pos_fraction = pos_fraction
+        self.neg_pos_ub = neg_pos_ub
+        self.add_gt_as_proposals = add_gt_as_proposals
+
+    @staticmethod
+    def random_choice(gallery, num):
+        """Random select some elements from the gallery.
+
+        It seems that Pytorch's implementation is slower than numpy so we use
+        numpy to randperm the indices.
+        """
+        assert len(gallery) >= num
+        if isinstance(gallery, list):
+            gallery = np.array(gallery)
+        cands = np.arange(len(gallery))
+        np.random.shuffle(cands)
+        rand_inds = cands[:num]
+        if not isinstance(gallery, np.ndarray):
+            rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
+        return gallery[rand_inds]
+
+    def _sample_pos(self, assign_result, num_expected):
+        """Randomly sample some positive samples."""
+        pos_inds = torch.nonzero(assign_result.gt_inds > 0)
+        if pos_inds.numel() != 0:
+            pos_inds = pos_inds.squeeze(1)
+        if pos_inds.numel() <= num_expected:
+            return pos_inds
+        else:
+            return self.random_choice(pos_inds, num_expected)
+
+    def _sample_neg(self, assign_result, num_expected):
+        """Randomly sample some negative samples."""
+        neg_inds = torch.nonzero(assign_result.gt_inds == 0)
+        if neg_inds.numel() != 0:
+            neg_inds = neg_inds.squeeze(1)
+        if len(neg_inds) <= num_expected:
+            return neg_inds
+        else:
+            return self.random_choice(neg_inds, num_expected)
diff --git a/mmdet/core/bbox/samplers/sampling_result.py b/mmdet/core/bbox/samplers/sampling_result.py
new file mode 100644
index 0000000..696e650
--- /dev/null
+++ b/mmdet/core/bbox/samplers/sampling_result.py
@@ -0,0 +1,24 @@
+import torch
+
+
+class SamplingResult(object):
+
+    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
+                 gt_flags):
+        self.pos_inds = pos_inds
+        self.neg_inds = neg_inds
+        self.pos_bboxes = bboxes[pos_inds]
+        self.neg_bboxes = bboxes[neg_inds]
+        self.pos_is_gt = gt_flags[pos_inds]
+
+        self.num_gts = gt_bboxes.shape[0]
+        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+        self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
+        if assign_result.labels is not None:
+            self.pos_gt_labels = assign_result.labels[pos_inds]
+        else:
+            self.pos_gt_labels = None
+
+    @property
+    def bboxes(self):
+        return torch.cat([self.pos_bboxes, self.neg_bboxes])
diff --git a/mmdet/core/bbox/sampling.py b/mmdet/core/bbox/sampling.py
deleted file mode 100644
index 63d2279..0000000
--- a/mmdet/core/bbox/sampling.py
+++ /dev/null
@@ -1,227 +0,0 @@
-import numpy as np
-import torch
-
-from .assignment import BBoxAssigner
-
-
-def random_choice(gallery, num):
-    """Random select some elements from the gallery.
-
-    It seems that Pytorch's implementation is slower than numpy so we use numpy
-    to randperm the indices.
-    """
-    assert len(gallery) >= num
-    if isinstance(gallery, list):
-        gallery = np.array(gallery)
-    cands = np.arange(len(gallery))
-    np.random.shuffle(cands)
-    rand_inds = cands[:num]
-    if not isinstance(gallery, np.ndarray):
-        rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
-    return gallery[rand_inds]
-
-
-def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
-    bbox_assigner = BBoxAssigner(**cfg.assigner)
-    bbox_sampler = BBoxSampler(**cfg.sampler)
-    assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore,
-                                         gt_labels)
-    sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes,
-                                          gt_labels)
-    return assign_result, sampling_result
-
-
-class BBoxSampler(object):
-    """Sample positive and negative bboxes given assigned results.
-
-    Args:
-        pos_fraction (float): Positive sample fraction.
-        neg_pos_ub (float): Negative/Positive upper bound.
-        pos_balance_sampling (bool): Whether to sample positive samples around
-            each gt bbox evenly.
-        neg_balance_thr (float, optional): IoU threshold for simple/hard
-            negative balance sampling.
-        neg_hard_fraction (float, optional): Fraction of hard negative samples
-            for negative balance sampling.
-    """
-
-    def __init__(self,
-                 num,
-                 pos_fraction,
-                 neg_pos_ub=-1,
-                 add_gt_as_proposals=True,
-                 pos_balance_sampling=False,
-                 neg_balance_thr=0,
-                 neg_hard_fraction=0.5):
-        self.num = num
-        self.pos_fraction = pos_fraction
-        self.neg_pos_ub = neg_pos_ub
-        self.add_gt_as_proposals = add_gt_as_proposals
-        self.pos_balance_sampling = pos_balance_sampling
-        self.neg_balance_thr = neg_balance_thr
-        self.neg_hard_fraction = neg_hard_fraction
-
-    def _sample_pos(self, assign_result, num_expected):
-        """Balance sampling for positive bboxes/anchors.
-
-        1. calculate average positive num for each gt: num_per_gt
-        2. sample at most num_per_gt positives for each gt
-        3. random sampling from rest anchors if not enough fg
-        """
-        pos_inds = torch.nonzero(assign_result.gt_inds > 0)
-        if pos_inds.numel() != 0:
-            pos_inds = pos_inds.squeeze(1)
-        if pos_inds.numel() <= num_expected:
-            return pos_inds
-        elif not self.pos_balance_sampling:
-            return random_choice(pos_inds, num_expected)
-        else:
-            unique_gt_inds = torch.unique(
-                assign_result.gt_inds[pos_inds].cpu())
-            num_gts = len(unique_gt_inds)
-            num_per_gt = int(round(num_expected / float(num_gts)) + 1)
-            sampled_inds = []
-            for i in unique_gt_inds:
-                inds = torch.nonzero(assign_result.gt_inds == i.item())
-                if inds.numel() != 0:
-                    inds = inds.squeeze(1)
-                else:
-                    continue
-                if len(inds) > num_per_gt:
-                    inds = random_choice(inds, num_per_gt)
-                sampled_inds.append(inds)
-            sampled_inds = torch.cat(sampled_inds)
-            if len(sampled_inds) < num_expected:
-                num_extra = num_expected - len(sampled_inds)
-                extra_inds = np.array(
-                    list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
-                if len(extra_inds) > num_extra:
-                    extra_inds = random_choice(extra_inds, num_extra)
-                extra_inds = torch.from_numpy(extra_inds).to(
-                    assign_result.gt_inds.device).long()
-                sampled_inds = torch.cat([sampled_inds, extra_inds])
-            elif len(sampled_inds) > num_expected:
-                sampled_inds = random_choice(sampled_inds, num_expected)
-            return sampled_inds
-
-    def _sample_neg(self, assign_result, num_expected):
-        """Balance sampling for negative bboxes/anchors.
-
-        Negative samples are split into 2 set: hard (balance_thr <= iou <
-        neg_iou_thr) and easy (iou < balance_thr). The sampling ratio is
-        controlled by `hard_fraction`.
-        """
-        neg_inds = torch.nonzero(assign_result.gt_inds == 0)
-        if neg_inds.numel() != 0:
-            neg_inds = neg_inds.squeeze(1)
-        if len(neg_inds) <= num_expected:
-            return neg_inds
-        elif self.neg_balance_thr <= 0:
-            # uniform sampling among all negative samples
-            return random_choice(neg_inds, num_expected)
-        else:
-            max_overlaps = assign_result.max_overlaps.cpu().numpy()
-            # balance sampling for negative samples
-            neg_set = set(neg_inds.cpu().numpy())
-            easy_set = set(
-                np.where(
-                    np.logical_and(max_overlaps >= 0,
-                                   max_overlaps < self.neg_balance_thr))[0])
-            hard_set = set(np.where(max_overlaps >= self.neg_balance_thr)[0])
-            easy_neg_inds = list(easy_set & neg_set)
-            hard_neg_inds = list(hard_set & neg_set)
-
-            num_expected_hard = int(num_expected * self.neg_hard_fraction)
-            if len(hard_neg_inds) > num_expected_hard:
-                sampled_hard_inds = random_choice(hard_neg_inds,
-                                                  num_expected_hard)
-            else:
-                sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int)
-            num_expected_easy = num_expected - len(sampled_hard_inds)
-            if len(easy_neg_inds) > num_expected_easy:
-                sampled_easy_inds = random_choice(easy_neg_inds,
-                                                  num_expected_easy)
-            else:
-                sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int)
-            sampled_inds = np.concatenate((sampled_easy_inds,
-                                           sampled_hard_inds))
-            if len(sampled_inds) < num_expected:
-                num_extra = num_expected - len(sampled_inds)
-                extra_inds = np.array(list(neg_set - set(sampled_inds)))
-                if len(extra_inds) > num_extra:
-                    extra_inds = random_choice(extra_inds, num_extra)
-                sampled_inds = np.concatenate((sampled_inds, extra_inds))
-            sampled_inds = torch.from_numpy(sampled_inds).long().to(
-                assign_result.gt_inds.device)
-            return sampled_inds
-
-    def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None):
-        """Sample positive and negative bboxes.
-
-        This is a simple implementation of bbox sampling given candidates,
-        assigning results and ground truth bboxes.
-
-        1. Assign gt to each bbox.
-        2. Add gt bboxes to the sampling pool (optional).
-        3. Perform positive and negative sampling.
-
-        Args:
-            assign_result (:obj:`AssignResult`): Bbox assigning results.
-            bboxes (Tensor): Boxes to be sampled from.
-            gt_bboxes (Tensor): Ground truth bboxes.
-            gt_labels (Tensor, optional): Class labels of ground truth bboxes.
-
-        Returns:
-            :obj:`SamplingResult`: Sampling result.
-        """
-        bboxes = bboxes[:, :4]
-
-        gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
-        if self.add_gt_as_proposals:
-            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
-            assign_result.add_gt_(gt_labels)
-            gt_flags = torch.cat([
-                bboxes.new_ones((gt_bboxes.shape[0], ), dtype=torch.uint8),
-                gt_flags
-            ])
-
-        num_expected_pos = int(self.num * self.pos_fraction)
-        pos_inds = self._sample_pos(assign_result, num_expected_pos)
-        # We found that sampled indices have duplicated items occasionally.
-        # (mab be a bug of PyTorch)
-        pos_inds = pos_inds.unique()
-        num_sampled_pos = pos_inds.numel()
-        num_expected_neg = self.num - num_sampled_pos
-        if self.neg_pos_ub >= 0:
-            num_neg_max = int(self.neg_pos_ub *
-                              num_sampled_pos) if num_sampled_pos > 0 else int(
-                                  self.neg_pos_ub)
-            num_expected_neg = min(num_neg_max, num_expected_neg)
-        neg_inds = self._sample_neg(assign_result, num_expected_neg)
-        neg_inds = neg_inds.unique()
-
-        return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
-                              assign_result, gt_flags)
-
-
-class SamplingResult(object):
-
-    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
-                 gt_flags):
-        self.pos_inds = pos_inds
-        self.neg_inds = neg_inds
-        self.pos_bboxes = bboxes[pos_inds]
-        self.neg_bboxes = bboxes[neg_inds]
-        self.pos_is_gt = gt_flags[pos_inds]
-
-        self.num_gts = gt_bboxes.shape[0]
-        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
-        self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
-        if assign_result.labels is not None:
-            self.pos_gt_labels = assign_result.labels[pos_inds]
-        else:
-            self.pos_gt_labels = None
-
-    @property
-    def bboxes(self):
-        return torch.cat([self.pos_bboxes, self.neg_bboxes])
-- 
GitLab