From e3c1b8550c7b6607f68b7a89b84cc7aa4a777d85 Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Thu, 6 Dec 2018 21:45:13 +0800 Subject: [PATCH] refactoring for sampler and assigner --- configs/cascade_mask_rcnn_r101_fpn_1x.py | 24 +- configs/cascade_mask_rcnn_r50_fpn_1x.py | 24 +- configs/cascade_rcnn_r101_fpn_1x.py | 24 +- configs/cascade_rcnn_r50_fpn_1x.py | 24 +- configs/fast_mask_rcnn_r101_fpn_1x.py | 6 +- configs/fast_mask_rcnn_r50_fpn_1x.py | 6 +- configs/fast_rcnn_r101_fpn_1x.py | 6 +- configs/fast_rcnn_r50_fpn_1x.py | 6 +- configs/faster_rcnn_r101_fpn_1x.py | 12 +- configs/faster_rcnn_r50_fpn_1x.py | 12 +- configs/mask_rcnn_r101_fpn_1x.py | 12 +- configs/mask_rcnn_r50_fpn_1x.py | 12 +- configs/retinanet_r101_fpn_1x.py | 6 +- configs/retinanet_r50_fpn_1x.py | 6 +- configs/rpn_r101_fpn_1x.py | 6 +- configs/rpn_r50_fpn_1x.py | 6 +- mmdet/core/anchor/anchor_target.py | 14 +- mmdet/core/bbox/__init__.py | 18 +- mmdet/core/bbox/assign_sampling.py | 35 +++ mmdet/core/bbox/assigners/__init__.py | 5 + mmdet/core/bbox/assigners/assign_result.py | 19 ++ mmdet/core/bbox/assigners/base_assigner.py | 8 + .../max_iou_assigner.py} | 24 +- mmdet/core/bbox/samplers/__init__.py | 13 + mmdet/core/bbox/samplers/base_sampler.py | 64 +++++ mmdet/core/bbox/samplers/combined_sampler.py | 16 ++ .../samplers/instance_balanced_pos_sampler.py | 41 ++++ .../bbox/samplers/iou_balanced_neg_sampler.py | 62 +++++ mmdet/core/bbox/samplers/pseudo_sampler.py | 26 ++ mmdet/core/bbox/samplers/random_sampler.py | 55 +++++ mmdet/core/bbox/samplers/sampling_result.py | 24 ++ mmdet/core/bbox/sampling.py | 227 ------------------ 32 files changed, 488 insertions(+), 355 deletions(-) create mode 100644 mmdet/core/bbox/assign_sampling.py create mode 100644 mmdet/core/bbox/assigners/__init__.py create mode 100644 mmdet/core/bbox/assigners/assign_result.py create mode 100644 mmdet/core/bbox/assigners/base_assigner.py rename mmdet/core/bbox/{assignment.py => assigners/max_iou_assigner.py} (88%) create mode 100644 mmdet/core/bbox/samplers/__init__.py create mode 100644 mmdet/core/bbox/samplers/base_sampler.py create mode 100644 mmdet/core/bbox/samplers/combined_sampler.py create mode 100644 mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py create mode 100644 mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py create mode 100644 mmdet/core/bbox/samplers/pseudo_sampler.py create mode 100644 mmdet/core/bbox/samplers/random_sampler.py create mode 100644 mmdet/core/bbox/samplers/sampling_result.py delete mode 100644 mmdet/core/bbox/sampling.py diff --git a/configs/cascade_mask_rcnn_r101_fpn_1x.py b/configs/cascade_mask_rcnn_r101_fpn_1x.py index 2cb289e..4466613 100644 --- a/configs/cascade_mask_rcnn_r101_fpn_1x.py +++ b/configs/cascade_mask_rcnn_r101_fpn_1x.py @@ -77,17 +77,17 @@ model = dict( train_cfg = dict( rpn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, - add_gt_as_proposals=False, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, smoothl1_beta=1 / 9.0, @@ -95,49 +95,49 @@ train_cfg = dict( rcnn=[ dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False), dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.6, neg_iou_thr=0.6, min_pos_iou=0.6, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False), dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.7, min_pos_iou=0.7, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False) diff --git a/configs/cascade_mask_rcnn_r50_fpn_1x.py b/configs/cascade_mask_rcnn_r50_fpn_1x.py index 538b468..af39dc6 100644 --- a/configs/cascade_mask_rcnn_r50_fpn_1x.py +++ b/configs/cascade_mask_rcnn_r50_fpn_1x.py @@ -77,17 +77,17 @@ model = dict( train_cfg = dict( rpn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, - add_gt_as_proposals=False, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, smoothl1_beta=1 / 9.0, @@ -95,49 +95,49 @@ train_cfg = dict( rcnn=[ dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False), dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.6, neg_iou_thr=0.6, min_pos_iou=0.6, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False), dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.7, min_pos_iou=0.7, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False) diff --git a/configs/cascade_rcnn_r101_fpn_1x.py b/configs/cascade_rcnn_r101_fpn_1x.py index e30370b..ccacc01 100644 --- a/configs/cascade_rcnn_r101_fpn_1x.py +++ b/configs/cascade_rcnn_r101_fpn_1x.py @@ -66,17 +66,17 @@ model = dict( train_cfg = dict( rpn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, - add_gt_as_proposals=False, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, smoothl1_beta=1 / 9.0, @@ -84,47 +84,47 @@ train_cfg = dict( rcnn=[ dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), pos_weight=-1, debug=False), dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.6, neg_iou_thr=0.6, min_pos_iou=0.6, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), pos_weight=-1, debug=False), dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.7, min_pos_iou=0.7, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), pos_weight=-1, debug=False) ], diff --git a/configs/cascade_rcnn_r50_fpn_1x.py b/configs/cascade_rcnn_r50_fpn_1x.py index 69a2e52..75d9edd 100644 --- a/configs/cascade_rcnn_r50_fpn_1x.py +++ b/configs/cascade_rcnn_r50_fpn_1x.py @@ -66,17 +66,17 @@ model = dict( train_cfg = dict( rpn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, - add_gt_as_proposals=False, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, smoothl1_beta=1 / 9.0, @@ -84,47 +84,47 @@ train_cfg = dict( rcnn=[ dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), pos_weight=-1, debug=False), dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.6, neg_iou_thr=0.6, min_pos_iou=0.6, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), pos_weight=-1, debug=False), dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.7, min_pos_iou=0.7, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), pos_weight=-1, debug=False) ], diff --git a/configs/fast_mask_rcnn_r101_fpn_1x.py b/configs/fast_mask_rcnn_r101_fpn_1x.py index 342d775..fa64d6f 100644 --- a/configs/fast_mask_rcnn_r101_fpn_1x.py +++ b/configs/fast_mask_rcnn_r101_fpn_1x.py @@ -44,17 +44,17 @@ model = dict( train_cfg = dict( rcnn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False)) diff --git a/configs/fast_mask_rcnn_r50_fpn_1x.py b/configs/fast_mask_rcnn_r50_fpn_1x.py index 8863ba6..2005100 100644 --- a/configs/fast_mask_rcnn_r50_fpn_1x.py +++ b/configs/fast_mask_rcnn_r50_fpn_1x.py @@ -44,17 +44,17 @@ model = dict( train_cfg = dict( rcnn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False)) diff --git a/configs/fast_rcnn_r101_fpn_1x.py b/configs/fast_rcnn_r101_fpn_1x.py index 66b7e8c..c61b74f 100644 --- a/configs/fast_rcnn_r101_fpn_1x.py +++ b/configs/fast_rcnn_r101_fpn_1x.py @@ -33,17 +33,17 @@ model = dict( train_cfg = dict( rcnn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), pos_weight=-1, debug=False)) test_cfg = dict( diff --git a/configs/fast_rcnn_r50_fpn_1x.py b/configs/fast_rcnn_r50_fpn_1x.py index 57394bc..542e2dd 100644 --- a/configs/fast_rcnn_r50_fpn_1x.py +++ b/configs/fast_rcnn_r50_fpn_1x.py @@ -33,17 +33,17 @@ model = dict( train_cfg = dict( rcnn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), pos_weight=-1, debug=False)) test_cfg = dict( diff --git a/configs/faster_rcnn_r101_fpn_1x.py b/configs/faster_rcnn_r101_fpn_1x.py index 49813d1..2ff48c5 100644 --- a/configs/faster_rcnn_r101_fpn_1x.py +++ b/configs/faster_rcnn_r101_fpn_1x.py @@ -43,34 +43,34 @@ model = dict( train_cfg = dict( rpn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, - add_gt_as_proposals=False, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, smoothl1_beta=1 / 9.0, debug=False), rcnn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), pos_weight=-1, debug=False)) test_cfg = dict( diff --git a/configs/faster_rcnn_r50_fpn_1x.py b/configs/faster_rcnn_r50_fpn_1x.py index 97899af..e88e348 100644 --- a/configs/faster_rcnn_r50_fpn_1x.py +++ b/configs/faster_rcnn_r50_fpn_1x.py @@ -43,34 +43,34 @@ model = dict( train_cfg = dict( rpn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, - add_gt_as_proposals=False, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, smoothl1_beta=1 / 9.0, debug=False), rcnn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), pos_weight=-1, debug=False)) test_cfg = dict( diff --git a/configs/mask_rcnn_r101_fpn_1x.py b/configs/mask_rcnn_r101_fpn_1x.py index 65a93e2..3675a80 100644 --- a/configs/mask_rcnn_r101_fpn_1x.py +++ b/configs/mask_rcnn_r101_fpn_1x.py @@ -54,34 +54,34 @@ model = dict( train_cfg = dict( rpn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, - add_gt_as_proposals=False, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, smoothl1_beta=1 / 9.0, debug=False), rcnn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False)) diff --git a/configs/mask_rcnn_r50_fpn_1x.py b/configs/mask_rcnn_r50_fpn_1x.py index c2ef8fa..364944f 100644 --- a/configs/mask_rcnn_r50_fpn_1x.py +++ b/configs/mask_rcnn_r50_fpn_1x.py @@ -54,34 +54,34 @@ model = dict( train_cfg = dict( rpn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, - add_gt_as_proposals=False, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, smoothl1_beta=1 / 9.0, debug=False), rcnn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False)) diff --git a/configs/retinanet_r101_fpn_1x.py b/configs/retinanet_r101_fpn_1x.py index fd4dd9a..e07d98a 100644 --- a/configs/retinanet_r101_fpn_1x.py +++ b/configs/retinanet_r101_fpn_1x.py @@ -31,7 +31,11 @@ model = dict( # training and testing settings train_cfg = dict( assigner=dict( - pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0, ignore_iof_thr=-1), + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), smoothl1_beta=0.11, gamma=2.0, alpha=0.25, diff --git a/configs/retinanet_r50_fpn_1x.py b/configs/retinanet_r50_fpn_1x.py index 77f67de..2840c06 100644 --- a/configs/retinanet_r50_fpn_1x.py +++ b/configs/retinanet_r50_fpn_1x.py @@ -31,7 +31,11 @@ model = dict( # training and testing settings train_cfg = dict( assigner=dict( - pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0, ignore_iof_thr=-1), + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), smoothl1_beta=0.11, gamma=2.0, alpha=0.25, diff --git a/configs/rpn_r101_fpn_1x.py b/configs/rpn_r101_fpn_1x.py index 90f260d..450215e 100644 --- a/configs/rpn_r101_fpn_1x.py +++ b/configs/rpn_r101_fpn_1x.py @@ -28,17 +28,17 @@ model = dict( train_cfg = dict( rpn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, - add_gt_as_proposals=False, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, smoothl1_beta=1 / 9.0, diff --git a/configs/rpn_r50_fpn_1x.py b/configs/rpn_r50_fpn_1x.py index 8e2b402..3af2649 100644 --- a/configs/rpn_r50_fpn_1x.py +++ b/configs/rpn_r50_fpn_1x.py @@ -28,17 +28,17 @@ model = dict( train_cfg = dict( rpn=dict( assigner=dict( + type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, ignore_iof_thr=-1), sampler=dict( + type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, - add_gt_as_proposals=False, - pos_balance_sampling=False, - neg_balance_thr=0), + add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, smoothl1_beta=1 / 9.0, diff --git a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py index 41086cf..5fcdb83 100644 --- a/mmdet/core/anchor/anchor_target.py +++ b/mmdet/core/anchor/anchor_target.py @@ -1,6 +1,6 @@ import torch -from ..bbox import assign_and_sample, BBoxAssigner, SamplingResult, bbox2delta +from ..bbox import assign_and_sample, build_assigner, PseudoSampler, bbox2delta from ..utils import multi_apply @@ -107,16 +107,12 @@ def anchor_target_single(flat_anchors, assign_result, sampling_result = assign_and_sample( anchors, gt_bboxes, None, None, cfg) else: - bbox_assigner = BBoxAssigner(**cfg.assigner) + bbox_assigner = build_assigner(cfg.assigner) assign_result = bbox_assigner.assign(anchors, gt_bboxes, None, gt_labels) - pos_inds = torch.nonzero( - assign_result.gt_inds > 0).squeeze(-1).unique() - neg_inds = torch.nonzero( - assign_result.gt_inds == 0).squeeze(-1).unique() - gt_flags = anchors.new_zeros(anchors.shape[0], dtype=torch.uint8) - sampling_result = SamplingResult(pos_inds, neg_inds, anchors, - gt_bboxes, assign_result, gt_flags) + bbox_sampler = PseudoSampler() + sampling_result = bbox_sampler.sample(assign_result, anchors, + gt_bboxes) num_valid_anchors = anchors.shape[0] bbox_targets = torch.zeros_like(anchors) diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py index 2ed869f..496bd7a 100644 --- a/mmdet/core/bbox/__init__.py +++ b/mmdet/core/bbox/__init__.py @@ -1,14 +1,18 @@ from .geometry import bbox_overlaps -from .assignment import BBoxAssigner, AssignResult -from .sampling import (BBoxSampler, SamplingResult, assign_and_sample, - random_choice) +from .assigners import BaseAssigner, MaxIoUAssigner, AssignResult +from .samplers import (BaseSampler, PseudoSampler, RandomSampler, + InstanceBalancedPosSampler, IoUBalancedNegSampler, + CombinedSampler, SamplingResult) +from .assign_sampling import build_assigner, build_sampler, assign_and_sample from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping, bbox_mapping_back, bbox2roi, roi2bbox, bbox2result) from .bbox_target import bbox_target __all__ = [ - 'bbox_overlaps', 'BBoxAssigner', 'AssignResult', 'BBoxSampler', - 'SamplingResult', 'assign_and_sample', 'random_choice', 'bbox2delta', - 'delta2bbox', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', - 'roi2bbox', 'bbox2result', 'bbox_target' + 'bbox_overlaps', 'BaseAssigner', 'MaxIoUAssigner', 'AssignResult', + 'BaseSampler', 'PseudoSampler', 'RandomSampler', + 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', + 'SamplingResult', 'build_assigner', 'build_sampler', 'assign_and_sample', + 'bbox2delta', 'delta2bbox', 'bbox_flip', 'bbox_mapping', + 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result', 'bbox_target' ] diff --git a/mmdet/core/bbox/assign_sampling.py b/mmdet/core/bbox/assign_sampling.py new file mode 100644 index 0000000..b1b199a --- /dev/null +++ b/mmdet/core/bbox/assign_sampling.py @@ -0,0 +1,35 @@ +import mmcv + +from . import assigners, samplers + + +def build_assigner(cfg, default_args=None): + if isinstance(cfg, assigners.BaseAssigner): + return cfg + elif isinstance(cfg, dict): + return mmcv.runner.obj_from_dict( + cfg, assigners, default_args=default_args) + else: + raise TypeError('Invalid type {} for building a sampler'.format( + type(cfg))) + + +def build_sampler(cfg, default_args=None): + if isinstance(cfg, samplers.BaseSampler): + return cfg + elif isinstance(cfg, dict): + return mmcv.runner.obj_from_dict( + cfg, samplers, default_args=default_args) + else: + raise TypeError('Invalid type {} for building a sampler'.format( + type(cfg))) + + +def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg): + bbox_assigner = build_assigner(cfg.assigner) + bbox_sampler = build_sampler(cfg.sampler) + assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore, + gt_labels) + sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes, + gt_labels) + return assign_result, sampling_result diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py new file mode 100644 index 0000000..40a89e9 --- /dev/null +++ b/mmdet/core/bbox/assigners/__init__.py @@ -0,0 +1,5 @@ +from .base_assigner import BaseAssigner +from .max_iou_assigner import MaxIoUAssigner +from .assign_result import AssignResult + +__all__ = ['BaseAssigner', 'MaxIoUAssigner', 'AssignResult'] diff --git a/mmdet/core/bbox/assigners/assign_result.py b/mmdet/core/bbox/assigners/assign_result.py new file mode 100644 index 0000000..33c761d --- /dev/null +++ b/mmdet/core/bbox/assigners/assign_result.py @@ -0,0 +1,19 @@ +import torch + + +class AssignResult(object): + + def __init__(self, num_gts, gt_inds, max_overlaps, labels=None): + self.num_gts = num_gts + self.gt_inds = gt_inds + self.max_overlaps = max_overlaps + self.labels = labels + + def add_gt_(self, gt_labels): + self_inds = torch.arange( + 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device) + self.gt_inds = torch.cat([self_inds, self.gt_inds]) + self.max_overlaps = torch.cat( + [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps]) + if self.labels is not None: + self.labels = torch.cat([gt_labels, self.labels]) diff --git a/mmdet/core/bbox/assigners/base_assigner.py b/mmdet/core/bbox/assigners/base_assigner.py new file mode 100644 index 0000000..7bd02dc --- /dev/null +++ b/mmdet/core/bbox/assigners/base_assigner.py @@ -0,0 +1,8 @@ +from abc import ABCMeta, abstractmethod + + +class BaseAssigner(metaclass=ABCMeta): + + @abstractmethod + def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): + pass diff --git a/mmdet/core/bbox/assignment.py b/mmdet/core/bbox/assigners/max_iou_assigner.py similarity index 88% rename from mmdet/core/bbox/assignment.py rename to mmdet/core/bbox/assigners/max_iou_assigner.py index 62233af..c43db07 100644 --- a/mmdet/core/bbox/assignment.py +++ b/mmdet/core/bbox/assigners/max_iou_assigner.py @@ -1,9 +1,11 @@ import torch -from .geometry import bbox_overlaps +from .base_assigner import BaseAssigner +from .assign_result import AssignResult +from ..geometry import bbox_overlaps -class BBoxAssigner(object): +class MaxIoUAssigner(BaseAssigner): """Assign a corresponding gt bbox or background to each bbox. Each proposals will be assigned with `-1`, `0`, or a positive integer @@ -135,21 +137,3 @@ class BBoxAssigner(object): return AssignResult( num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels) - - -class AssignResult(object): - - def __init__(self, num_gts, gt_inds, max_overlaps, labels=None): - self.num_gts = num_gts - self.gt_inds = gt_inds - self.max_overlaps = max_overlaps - self.labels = labels - - def add_gt_(self, gt_labels): - self_inds = torch.arange( - 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device) - self.gt_inds = torch.cat([self_inds, self.gt_inds]) - self.max_overlaps = torch.cat( - [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps]) - if self.labels is not None: - self.labels = torch.cat([gt_labels, self.labels]) diff --git a/mmdet/core/bbox/samplers/__init__.py b/mmdet/core/bbox/samplers/__init__.py new file mode 100644 index 0000000..8a1e677 --- /dev/null +++ b/mmdet/core/bbox/samplers/__init__.py @@ -0,0 +1,13 @@ +from .base_sampler import BaseSampler +from .pseudo_sampler import PseudoSampler +from .random_sampler import RandomSampler +from .instance_balanced_pos_sampler import InstanceBalancedPosSampler +from .iou_balanced_neg_sampler import IoUBalancedNegSampler +from .combined_sampler import CombinedSampler +from .sampling_result import SamplingResult + +__all__ = [ + 'BaseSampler', 'PseudoSampler', 'RandomSampler', + 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', + 'SamplingResult' +] diff --git a/mmdet/core/bbox/samplers/base_sampler.py b/mmdet/core/bbox/samplers/base_sampler.py new file mode 100644 index 0000000..6ac300f --- /dev/null +++ b/mmdet/core/bbox/samplers/base_sampler.py @@ -0,0 +1,64 @@ +from abc import ABCMeta, abstractmethod + +import torch + +from .sampling_result import SamplingResult + + +class BaseSampler(metaclass=ABCMeta): + + def __init__(self): + self.pos_sampler = self + self.neg_sampler = self + + @abstractmethod + def _sample_pos(self, assign_result, num_expected): + pass + + @abstractmethod + def _sample_neg(self, assign_result, num_expected): + pass + + def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None): + """Sample positive and negative bboxes. + + This is a simple implementation of bbox sampling given candidates, + assigning results and ground truth bboxes. + + Args: + assign_result (:obj:`AssignResult`): Bbox assigning results. + bboxes (Tensor): Boxes to be sampled from. + gt_bboxes (Tensor): Ground truth bboxes. + gt_labels (Tensor, optional): Class labels of ground truth bboxes. + + Returns: + :obj:`SamplingResult`: Sampling result. + """ + bboxes = bboxes[:, :4] + + gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8) + if self.add_gt_as_proposals: + bboxes = torch.cat([gt_bboxes, bboxes], dim=0) + assign_result.add_gt_(gt_labels) + gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) + gt_flags = torch.cat([gt_ones, gt_flags]) + + num_expected_pos = int(self.num * self.pos_fraction) + pos_inds = self.pos_sampler._sample_pos(assign_result, + num_expected_pos) + # We found that sampled indices have duplicated items occasionally. + # (may be a bug of PyTorch) + pos_inds = pos_inds.unique() + num_sampled_pos = pos_inds.numel() + num_expected_neg = self.num - num_sampled_pos + if self.neg_pos_ub >= 0: + _pos = max(1, num_sampled_pos) + neg_upper_bound = int(self.neg_pos_ub * _pos) + if num_expected_neg > neg_upper_bound: + num_expected_neg = neg_upper_bound + neg_inds = self.neg_sampler._sample_neg(assign_result, + num_expected_neg) + neg_inds = neg_inds.unique() + + return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, + assign_result, gt_flags) diff --git a/mmdet/core/bbox/samplers/combined_sampler.py b/mmdet/core/bbox/samplers/combined_sampler.py new file mode 100644 index 0000000..578abe3 --- /dev/null +++ b/mmdet/core/bbox/samplers/combined_sampler.py @@ -0,0 +1,16 @@ +from mmcv.runner import obj_from_dict + +from .random_sampler import RandomSampler +from ..assign_sampling import build_sampler + + +class CombinedSampler(RandomSampler): + + def __init__(self, num, pos_fraction, pos_sampler, neg_sampler, **kwargs): + super(CombinedSampler, self).__init__(num, pos_fraction, **kwargs) + default_args = dict(num=num, pos_fraction=pos_fraction) + default_args.update(kwargs) + self.pos_sampler = build_sampler( + pos_sampler, default_args=default_args) + self.neg_sampler = build_sampler( + neg_sampler, default_args=default_args) diff --git a/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py b/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py new file mode 100644 index 0000000..1d65029 --- /dev/null +++ b/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py @@ -0,0 +1,41 @@ +import numpy as np +import torch + +from .random_sampler import RandomSampler + + +class InstanceBalancedPosSampler(RandomSampler): + + def _sample_pos(self, assign_result, num_expected): + pos_inds = torch.nonzero(assign_result.gt_inds > 0) + if pos_inds.numel() != 0: + pos_inds = pos_inds.squeeze(1) + if pos_inds.numel() <= num_expected: + return pos_inds + else: + unique_gt_inds = assign_result.gt_inds[pos_inds].unique() + num_gts = len(unique_gt_inds) + num_per_gt = int(round(num_expected / float(num_gts)) + 1) + sampled_inds = [] + for i in unique_gt_inds: + inds = torch.nonzero(assign_result.gt_inds == i.item()) + if inds.numel() != 0: + inds = inds.squeeze(1) + else: + continue + if len(inds) > num_per_gt: + inds = self.random_choice(inds, num_per_gt) + sampled_inds.append(inds) + sampled_inds = torch.cat(sampled_inds) + if len(sampled_inds) < num_expected: + num_extra = num_expected - len(sampled_inds) + extra_inds = np.array( + list(set(pos_inds.cpu()) - set(sampled_inds.cpu()))) + if len(extra_inds) > num_extra: + extra_inds = self.random_choice(extra_inds, num_extra) + extra_inds = torch.from_numpy(extra_inds).to( + assign_result.gt_inds.device).long() + sampled_inds = torch.cat([sampled_inds, extra_inds]) + elif len(sampled_inds) > num_expected: + sampled_inds = self.random_choice(sampled_inds, num_expected) + return sampled_inds diff --git a/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py new file mode 100644 index 0000000..8e593fa --- /dev/null +++ b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py @@ -0,0 +1,62 @@ +import numpy as np +import torch + +from .random_sampler import RandomSampler + + +class IoUBalancedNegSampler(RandomSampler): + + def __init__(self, + num, + pos_fraction, + hard_thr=0.1, + hard_fraction=0.5, + **kwargs): + super(IoUBalancedNegSampler, self).__init__(num, pos_fraction, + **kwargs) + assert hard_thr > 0 + assert 0 < hard_fraction < 1 + self.hard_thr = hard_thr + self.hard_fraction = hard_fraction + + def _sample_neg(self, assign_result, num_expected): + neg_inds = torch.nonzero(assign_result.gt_inds == 0) + if neg_inds.numel() != 0: + neg_inds = neg_inds.squeeze(1) + if len(neg_inds) <= num_expected: + return neg_inds + else: + max_overlaps = assign_result.max_overlaps.cpu().numpy() + # balance sampling for negative samples + neg_set = set(neg_inds.cpu().numpy()) + easy_set = set( + np.where( + np.logical_and(max_overlaps >= 0, + max_overlaps < self.hard_thr))[0]) + hard_set = set(np.where(max_overlaps >= self.hard_thr)[0]) + easy_neg_inds = list(easy_set & neg_set) + hard_neg_inds = list(hard_set & neg_set) + + num_expected_hard = int(num_expected * self.hard_fraction) + if len(hard_neg_inds) > num_expected_hard: + sampled_hard_inds = self.random_choice(hard_neg_inds, + num_expected_hard) + else: + sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int) + num_expected_easy = num_expected - len(sampled_hard_inds) + if len(easy_neg_inds) > num_expected_easy: + sampled_easy_inds = self.random_choice(easy_neg_inds, + num_expected_easy) + else: + sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int) + sampled_inds = np.concatenate((sampled_easy_inds, + sampled_hard_inds)) + if len(sampled_inds) < num_expected: + num_extra = num_expected - len(sampled_inds) + extra_inds = np.array(list(neg_set - set(sampled_inds))) + if len(extra_inds) > num_extra: + extra_inds = self.random_choice(extra_inds, num_extra) + sampled_inds = np.concatenate((sampled_inds, extra_inds)) + sampled_inds = torch.from_numpy(sampled_inds).long().to( + assign_result.gt_inds.device) + return sampled_inds diff --git a/mmdet/core/bbox/samplers/pseudo_sampler.py b/mmdet/core/bbox/samplers/pseudo_sampler.py new file mode 100644 index 0000000..6c7189c --- /dev/null +++ b/mmdet/core/bbox/samplers/pseudo_sampler.py @@ -0,0 +1,26 @@ +import torch + +from .base_sampler import BaseSampler +from .sampling_result import SamplingResult + + +class PseudoSampler(BaseSampler): + + def __init__(self): + pass + + def _sample_pos(self): + raise NotImplementedError + + def _sample_neg(self): + raise NotImplementedError + + def sample(self, assign_result, bboxes, gt_bboxes): + pos_inds = torch.nonzero( + assign_result.gt_inds > 0).squeeze(-1).unique() + neg_inds = torch.nonzero( + assign_result.gt_inds == 0).squeeze(-1).unique() + gt_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8) + sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, + assign_result, gt_flags) + return sampling_result diff --git a/mmdet/core/bbox/samplers/random_sampler.py b/mmdet/core/bbox/samplers/random_sampler.py new file mode 100644 index 0000000..bd84ca6 --- /dev/null +++ b/mmdet/core/bbox/samplers/random_sampler.py @@ -0,0 +1,55 @@ +import numpy as np +import torch + +from .base_sampler import BaseSampler + + +class RandomSampler(BaseSampler): + + def __init__(self, + num, + pos_fraction, + neg_pos_ub=-1, + add_gt_as_proposals=True): + super(RandomSampler, self).__init__() + self.num = num + self.pos_fraction = pos_fraction + self.neg_pos_ub = neg_pos_ub + self.add_gt_as_proposals = add_gt_as_proposals + + @staticmethod + def random_choice(gallery, num): + """Random select some elements from the gallery. + + It seems that Pytorch's implementation is slower than numpy so we use + numpy to randperm the indices. + """ + assert len(gallery) >= num + if isinstance(gallery, list): + gallery = np.array(gallery) + cands = np.arange(len(gallery)) + np.random.shuffle(cands) + rand_inds = cands[:num] + if not isinstance(gallery, np.ndarray): + rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device) + return gallery[rand_inds] + + def _sample_pos(self, assign_result, num_expected): + """Randomly sample some positive samples.""" + pos_inds = torch.nonzero(assign_result.gt_inds > 0) + if pos_inds.numel() != 0: + pos_inds = pos_inds.squeeze(1) + if pos_inds.numel() <= num_expected: + return pos_inds + else: + return self.random_choice(pos_inds, num_expected) + + def _sample_neg(self, assign_result, num_expected): + """Randomly sample some negative samples.""" + neg_inds = torch.nonzero(assign_result.gt_inds == 0) + if neg_inds.numel() != 0: + neg_inds = neg_inds.squeeze(1) + if len(neg_inds) <= num_expected: + return neg_inds + else: + return self.random_choice(neg_inds, num_expected) diff --git a/mmdet/core/bbox/samplers/sampling_result.py b/mmdet/core/bbox/samplers/sampling_result.py new file mode 100644 index 0000000..696e650 --- /dev/null +++ b/mmdet/core/bbox/samplers/sampling_result.py @@ -0,0 +1,24 @@ +import torch + + +class SamplingResult(object): + + def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, + gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_bboxes = bboxes[pos_inds] + self.neg_bboxes = bboxes[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_bboxes.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :] + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def bboxes(self): + return torch.cat([self.pos_bboxes, self.neg_bboxes]) diff --git a/mmdet/core/bbox/sampling.py b/mmdet/core/bbox/sampling.py deleted file mode 100644 index 63d2279..0000000 --- a/mmdet/core/bbox/sampling.py +++ /dev/null @@ -1,227 +0,0 @@ -import numpy as np -import torch - -from .assignment import BBoxAssigner - - -def random_choice(gallery, num): - """Random select some elements from the gallery. - - It seems that Pytorch's implementation is slower than numpy so we use numpy - to randperm the indices. - """ - assert len(gallery) >= num - if isinstance(gallery, list): - gallery = np.array(gallery) - cands = np.arange(len(gallery)) - np.random.shuffle(cands) - rand_inds = cands[:num] - if not isinstance(gallery, np.ndarray): - rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device) - return gallery[rand_inds] - - -def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg): - bbox_assigner = BBoxAssigner(**cfg.assigner) - bbox_sampler = BBoxSampler(**cfg.sampler) - assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore, - gt_labels) - sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes, - gt_labels) - return assign_result, sampling_result - - -class BBoxSampler(object): - """Sample positive and negative bboxes given assigned results. - - Args: - pos_fraction (float): Positive sample fraction. - neg_pos_ub (float): Negative/Positive upper bound. - pos_balance_sampling (bool): Whether to sample positive samples around - each gt bbox evenly. - neg_balance_thr (float, optional): IoU threshold for simple/hard - negative balance sampling. - neg_hard_fraction (float, optional): Fraction of hard negative samples - for negative balance sampling. - """ - - def __init__(self, - num, - pos_fraction, - neg_pos_ub=-1, - add_gt_as_proposals=True, - pos_balance_sampling=False, - neg_balance_thr=0, - neg_hard_fraction=0.5): - self.num = num - self.pos_fraction = pos_fraction - self.neg_pos_ub = neg_pos_ub - self.add_gt_as_proposals = add_gt_as_proposals - self.pos_balance_sampling = pos_balance_sampling - self.neg_balance_thr = neg_balance_thr - self.neg_hard_fraction = neg_hard_fraction - - def _sample_pos(self, assign_result, num_expected): - """Balance sampling for positive bboxes/anchors. - - 1. calculate average positive num for each gt: num_per_gt - 2. sample at most num_per_gt positives for each gt - 3. random sampling from rest anchors if not enough fg - """ - pos_inds = torch.nonzero(assign_result.gt_inds > 0) - if pos_inds.numel() != 0: - pos_inds = pos_inds.squeeze(1) - if pos_inds.numel() <= num_expected: - return pos_inds - elif not self.pos_balance_sampling: - return random_choice(pos_inds, num_expected) - else: - unique_gt_inds = torch.unique( - assign_result.gt_inds[pos_inds].cpu()) - num_gts = len(unique_gt_inds) - num_per_gt = int(round(num_expected / float(num_gts)) + 1) - sampled_inds = [] - for i in unique_gt_inds: - inds = torch.nonzero(assign_result.gt_inds == i.item()) - if inds.numel() != 0: - inds = inds.squeeze(1) - else: - continue - if len(inds) > num_per_gt: - inds = random_choice(inds, num_per_gt) - sampled_inds.append(inds) - sampled_inds = torch.cat(sampled_inds) - if len(sampled_inds) < num_expected: - num_extra = num_expected - len(sampled_inds) - extra_inds = np.array( - list(set(pos_inds.cpu()) - set(sampled_inds.cpu()))) - if len(extra_inds) > num_extra: - extra_inds = random_choice(extra_inds, num_extra) - extra_inds = torch.from_numpy(extra_inds).to( - assign_result.gt_inds.device).long() - sampled_inds = torch.cat([sampled_inds, extra_inds]) - elif len(sampled_inds) > num_expected: - sampled_inds = random_choice(sampled_inds, num_expected) - return sampled_inds - - def _sample_neg(self, assign_result, num_expected): - """Balance sampling for negative bboxes/anchors. - - Negative samples are split into 2 set: hard (balance_thr <= iou < - neg_iou_thr) and easy (iou < balance_thr). The sampling ratio is - controlled by `hard_fraction`. - """ - neg_inds = torch.nonzero(assign_result.gt_inds == 0) - if neg_inds.numel() != 0: - neg_inds = neg_inds.squeeze(1) - if len(neg_inds) <= num_expected: - return neg_inds - elif self.neg_balance_thr <= 0: - # uniform sampling among all negative samples - return random_choice(neg_inds, num_expected) - else: - max_overlaps = assign_result.max_overlaps.cpu().numpy() - # balance sampling for negative samples - neg_set = set(neg_inds.cpu().numpy()) - easy_set = set( - np.where( - np.logical_and(max_overlaps >= 0, - max_overlaps < self.neg_balance_thr))[0]) - hard_set = set(np.where(max_overlaps >= self.neg_balance_thr)[0]) - easy_neg_inds = list(easy_set & neg_set) - hard_neg_inds = list(hard_set & neg_set) - - num_expected_hard = int(num_expected * self.neg_hard_fraction) - if len(hard_neg_inds) > num_expected_hard: - sampled_hard_inds = random_choice(hard_neg_inds, - num_expected_hard) - else: - sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int) - num_expected_easy = num_expected - len(sampled_hard_inds) - if len(easy_neg_inds) > num_expected_easy: - sampled_easy_inds = random_choice(easy_neg_inds, - num_expected_easy) - else: - sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int) - sampled_inds = np.concatenate((sampled_easy_inds, - sampled_hard_inds)) - if len(sampled_inds) < num_expected: - num_extra = num_expected - len(sampled_inds) - extra_inds = np.array(list(neg_set - set(sampled_inds))) - if len(extra_inds) > num_extra: - extra_inds = random_choice(extra_inds, num_extra) - sampled_inds = np.concatenate((sampled_inds, extra_inds)) - sampled_inds = torch.from_numpy(sampled_inds).long().to( - assign_result.gt_inds.device) - return sampled_inds - - def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None): - """Sample positive and negative bboxes. - - This is a simple implementation of bbox sampling given candidates, - assigning results and ground truth bboxes. - - 1. Assign gt to each bbox. - 2. Add gt bboxes to the sampling pool (optional). - 3. Perform positive and negative sampling. - - Args: - assign_result (:obj:`AssignResult`): Bbox assigning results. - bboxes (Tensor): Boxes to be sampled from. - gt_bboxes (Tensor): Ground truth bboxes. - gt_labels (Tensor, optional): Class labels of ground truth bboxes. - - Returns: - :obj:`SamplingResult`: Sampling result. - """ - bboxes = bboxes[:, :4] - - gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8) - if self.add_gt_as_proposals: - bboxes = torch.cat([gt_bboxes, bboxes], dim=0) - assign_result.add_gt_(gt_labels) - gt_flags = torch.cat([ - bboxes.new_ones((gt_bboxes.shape[0], ), dtype=torch.uint8), - gt_flags - ]) - - num_expected_pos = int(self.num * self.pos_fraction) - pos_inds = self._sample_pos(assign_result, num_expected_pos) - # We found that sampled indices have duplicated items occasionally. - # (mab be a bug of PyTorch) - pos_inds = pos_inds.unique() - num_sampled_pos = pos_inds.numel() - num_expected_neg = self.num - num_sampled_pos - if self.neg_pos_ub >= 0: - num_neg_max = int(self.neg_pos_ub * - num_sampled_pos) if num_sampled_pos > 0 else int( - self.neg_pos_ub) - num_expected_neg = min(num_neg_max, num_expected_neg) - neg_inds = self._sample_neg(assign_result, num_expected_neg) - neg_inds = neg_inds.unique() - - return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, - assign_result, gt_flags) - - -class SamplingResult(object): - - def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, - gt_flags): - self.pos_inds = pos_inds - self.neg_inds = neg_inds - self.pos_bboxes = bboxes[pos_inds] - self.neg_bboxes = bboxes[neg_inds] - self.pos_is_gt = gt_flags[pos_inds] - - self.num_gts = gt_bboxes.shape[0] - self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 - self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :] - if assign_result.labels is not None: - self.pos_gt_labels = assign_result.labels[pos_inds] - else: - self.pos_gt_labels = None - - @property - def bboxes(self): - return torch.cat([self.pos_bboxes, self.neg_bboxes]) -- GitLab