diff --git a/mmdet/core/bbox_ops/sampling.py b/mmdet/core/bbox_ops/sampling.py
index d751f8ede433b593c3b75193440a546fa3f8ce56..bcee761e10e5ef8537235cd629d187f874bb08fd 100644
--- a/mmdet/core/bbox_ops/sampling.py
+++ b/mmdet/core/bbox_ops/sampling.py
@@ -255,38 +255,3 @@ def bbox_sampling(assigned_gt_inds,
                                 neg_hard_fraction)
     neg_inds = neg_inds.unique()
     return pos_inds, neg_inds
-
-
-def sample_proposals(proposals_list, gt_bboxes_list, gt_crowds_list,
-                     gt_labels_list, cfg):
-    cfg_list = [cfg for _ in range(len(proposals_list))]
-    results = map(sample_proposals_single, proposals_list, gt_bboxes_list,
-                  gt_crowds_list, gt_labels_list, cfg_list)
-    # list of tuple to tuple of list
-    return tuple(map(list, zip(*results)))
-
-
-def sample_proposals_single(proposals, gt_bboxes, gt_crowds, gt_labels, cfg):
-    proposals = proposals[:, :4]
-    assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps = \
-        bbox_assign(
-            proposals, gt_bboxes, gt_crowds, gt_labels, cfg.pos_iou_thr,
-            cfg.neg_iou_thr, cfg.pos_iou_thr, cfg.crowd_thr)
-    if cfg.add_gt_as_proposals:
-        proposals = torch.cat([gt_bboxes, proposals], dim=0)
-        gt_assign_self = torch.arange(
-            1, len(gt_labels) + 1, dtype=torch.long, device=proposals.device)
-        assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
-        assigned_labels = torch.cat([gt_labels, assigned_labels])
-
-    pos_inds, neg_inds = bbox_sampling(
-        assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction, cfg.neg_pos_ub,
-        cfg.pos_balance_sampling, max_overlaps, cfg.neg_balance_thr)
-    pos_proposals = proposals[pos_inds]
-    neg_proposals = proposals[neg_inds]
-    pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
-    pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
-    pos_gt_labels = assigned_labels[pos_inds]
-
-    return (pos_inds, neg_inds, pos_proposals, neg_proposals,
-            pos_assigned_gt_inds, pos_gt_bboxes, pos_gt_labels)
diff --git a/mmdet/core/post_processing/merge_augs.py b/mmdet/core/post_processing/merge_augs.py
index 0472aaf80fd8f0c923c6afcdf3902df6037c3edb..2b8d861a6745b90dd33b77ae4bda65bfd825d9a7 100644
--- a/mmdet/core/post_processing/merge_augs.py
+++ b/mmdet/core/post_processing/merge_augs.py
@@ -54,9 +54,9 @@ def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg):
     """
     recovered_bboxes = []
     for bboxes, img_info in zip(aug_bboxes, img_metas):
-        img_shape = img_info['img_shape']
-        scale_factor = img_info['scale_factor']
-        flip = img_info['flip']
+        img_shape = img_info[0]['img_shape']
+        scale_factor = img_info[0]['scale_factor']
+        flip = img_info[0]['flip']
         bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip)
         recovered_bboxes.append(bboxes)
     bboxes = torch.stack(recovered_bboxes).mean(dim=0)
@@ -75,7 +75,7 @@ def merge_aug_scores(aug_scores):
         return np.mean(aug_scores, axis=0)
 
 
-def merge_aug_masks(aug_masks, bboxes, img_metas, rcnn_test_cfg, weights=None):
+def merge_aug_masks(aug_masks, img_metas, rcnn_test_cfg, weights=None):
     """Merge augmented mask prediction.
 
     Args:
@@ -87,7 +87,7 @@ def merge_aug_masks(aug_masks, bboxes, img_metas, rcnn_test_cfg, weights=None):
         tuple: (bboxes, scores)
     """
     recovered_masks = [
-        mask if not img_info['flip'][0] else mask[..., ::-1]
+        mask if not img_info[0]['flip'] else mask[..., ::-1]
         for mask, img_info in zip(aug_masks, img_metas)
     ]
     if weights is None:
diff --git a/mmdet/core/utils/__init__.py b/mmdet/core/utils/__init__.py
index 2b6e79d62e60b5e1efaac985e039b36840f86397..30c9c9e5c83797ddf7ea84d088ad29cb8e4b18cc 100644
--- a/mmdet/core/utils/__init__.py
+++ b/mmdet/core/utils/__init__.py
@@ -1,3 +1,12 @@
-from .dist_utils import *
-from .hooks import *
-from .misc import *
+from .dist_utils import (init_dist, reduce_grads, DistOptimizerHook,
+                         DistSamplerSeedHook)
+from .hooks import (EmptyCacheHook, DistEvalHook, DistEvalRecallHook,
+                    CocoDistEvalmAPHook)
+from .misc import tensor2imgs, unmap, results2json, multi_apply
+
+__all__ = [
+    'init_dist', 'reduce_grads', 'DistOptimizerHook', 'DistSamplerSeedHook',
+    'EmptyCacheHook', 'DistEvalHook', 'DistEvalRecallHook',
+    'CocoDistEvalmAPHook', 'tensor2imgs', 'unmap', 'results2json',
+    'multi_apply'
+]
diff --git a/mmdet/core/utils/dist_utils.py b/mmdet/core/utils/dist_utils.py
index 7ffa7a093481098e937cfa1b287b57980dd3d185..4bc986ca73fc6faacb12c6fdbc20f020d6bdb56f 100644
--- a/mmdet/core/utils/dist_utils.py
+++ b/mmdet/core/utils/dist_utils.py
@@ -8,10 +8,6 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.nn.utils import clip_grad
 from mmcv.torchpack import Hook, OptimizerHook
 
-__all__ = [
-    'init_dist', 'reduce_grads', 'DistOptimizerHook', 'DistSamplerSeedHook'
-]
-
 
 def init_dist(launcher, backend='nccl', **kwargs):
     if mp.get_start_method(allow_none=True) is None:
diff --git a/mmdet/core/utils/hooks.py b/mmdet/core/utils/hooks.py
index 8a52d11ba411f20d05a2f7ee5d532e5f50412edd..05441601ba792e44b47d89e5f405bb5092286d3f 100644
--- a/mmdet/core/utils/hooks.py
+++ b/mmdet/core/utils/hooks.py
@@ -13,11 +13,6 @@ from pycocotools.cocoeval import COCOeval
 
 from ..eval import eval_recalls
 
-__all__ = [
-    'EmptyCacheHook', 'DistEvalHook', 'DistEvalRecallHook',
-    'CocoDistEvalmAPHook'
-]
-
 
 class EmptyCacheHook(Hook):
 
diff --git a/mmdet/core/utils/misc.py b/mmdet/core/utils/misc.py
index 02d0b40c1e79d289c832d1e15243d862416cb51d..d34ff94302c2a5215681a3e81e38ca9aee8070df 100644
--- a/mmdet/core/utils/misc.py
+++ b/mmdet/core/utils/misc.py
@@ -4,9 +4,6 @@ import mmcv
 import numpy as np
 from six.moves import map, zip
 
-__all__ = ['tensor2imgs', 'multi_apply', 'unmap', 'results2json']
-
-
 def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
     num_imgs = tensor.size(0)
     mean = np.array(mean, dtype=np.float32)
@@ -48,6 +45,21 @@ def xyxy2xywh(bbox):
     ]
 
 
+def proposal2json(dataset, results):
+    json_results = []
+    for idx in range(len(dataset)):
+        img_id = dataset.img_ids[idx]
+        bboxes = results[idx]
+        for i in range(bboxes.shape[0]):
+            data = dict()
+            data['image_id'] = img_id
+            data['bbox'] = xyxy2xywh(bboxes[i])
+            data['score'] = float(bboxes[i][4])
+            data['category_id'] = 1
+            json_results.append(data)
+    return json_results
+
+
 def det2json(dataset, results):
     json_results = []
     for idx in range(len(dataset)):
@@ -85,21 +97,6 @@ def segm2json(dataset, results):
     return json_results
 
 
-def proposal2json(dataset, results):
-    json_results = []
-    for idx in range(len(dataset)):
-        img_id = dataset.img_ids[idx]
-        bboxes = results[idx]
-        for i in range(bboxes.shape[0]):
-            data = dict()
-            data['image_id'] = img_id
-            data['bbox'] = xyxy2xywh(bboxes[i])
-            data['score'] = float(bboxes[i][4])
-            data['category_id'] = 1
-            json_results.append(data)
-    return json_results
-
-
 def results2json(dataset, results, out_file):
     if isinstance(results[0], list):
         json_results = det2json(dataset, results)
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index da923ecf2d03f3abedee00b526b2e062866e62d0..941903aba544b2d5ee7f7f6685664f4ab6f27df4 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -109,7 +109,7 @@ class BBoxHead(nn.Module):
             # TODO: add clip here
 
         if rescale:
-            bboxes /= scale_factor.float()
+            bboxes /= scale_factor
 
         if nms_cfg is None:
             return bboxes, scores
diff --git a/mmdet/models/detectors/rpn.py b/mmdet/models/detectors/rpn.py
index 8d3dfd17c6c32f11726bba34d4ae438370ce10fe..29173cce7a7baa9620af3924d5ffb02b74dd0c6a 100644
--- a/mmdet/models/detectors/rpn.py
+++ b/mmdet/models/detectors/rpn.py
@@ -2,7 +2,7 @@ import mmcv
 
 from mmdet.core import tensor2imgs, bbox_mapping
 from .base import BaseDetector
-from .testing_mixins import RPNTestMixin
+from .test_mixins import RPNTestMixin
 from .. import builder
 
 
diff --git a/mmdet/models/detectors/testing_mixins.py b/mmdet/models/detectors/test_mixins.py
similarity index 82%
rename from mmdet/models/detectors/testing_mixins.py
rename to mmdet/models/detectors/test_mixins.py
index 364fd4e6d1ad3ff64cd0104abe1906aaa15d2619..2fd3b18d09361e2edc28b67e14b558eb3901fbb7 100644
--- a/mmdet/models/detectors/testing_mixins.py
+++ b/mmdet/models/detectors/test_mixins.py
@@ -50,7 +50,7 @@ class BBoxTestMixin(object):
             nms_cfg=rcnn_test_cfg)
         return det_bboxes, det_labels
 
-    def aug_test_bboxes(self, feats, img_metas, proposals, rcnn_test_cfg):
+    def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
         aug_bboxes = []
         aug_scores = []
         for x, img_meta in zip(feats, img_metas):
@@ -58,8 +58,9 @@ class BBoxTestMixin(object):
             img_shape = img_meta[0]['img_shape']
             scale_factor = img_meta[0]['scale_factor']
             flip = img_meta[0]['flip']
-            proposals = bbox_mapping(proposals[:, :4], img_shape, scale_factor,
-                                     flip)
+            # TODO more flexible
+            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+                                     scale_factor, flip)
             rois = bbox2roi([proposals])
             # recompute feature maps to save GPU memory
             roi_feats = self.bbox_roi_extractor(
@@ -70,16 +71,17 @@ class BBoxTestMixin(object):
                 cls_score,
                 bbox_pred,
                 img_shape,
+                scale_factor,
                 rescale=False,
                 nms_cfg=None)
             aug_bboxes.append(bboxes)
             aug_scores.append(scores)
         # after merging, bboxes will be rescaled to the original image size
         merged_bboxes, merged_scores = merge_aug_bboxes(
-            aug_bboxes, aug_scores, img_metas, self.rcnn_test_cfg)
+            aug_bboxes, aug_scores, img_metas, self.test_cfg.rcnn)
         det_bboxes, det_labels = multiclass_nms(
-            merged_bboxes, merged_scores, self.rcnn_test_cfg.score_thr,
-            self.rcnn_test_cfg.nms_thr, self.rcnn_test_cfg.max_per_img)
+            merged_bboxes, merged_scores, self.test_cfg.rcnn.score_thr,
+            self.test_cfg.rcnn.nms_thr, self.test_cfg.rcnn.max_per_img)
         return det_bboxes, det_labels
 
 
@@ -92,7 +94,7 @@ class MaskTestMixin(object):
                          det_labels,
                          rescale=False):
         # image shape of the first image in the batch (only one)
-        img_shape = img_meta[0]['img_shape']
+        ori_shape = img_meta[0]['ori_shape']
         scale_factor = img_meta[0]['scale_factor']
         if det_bboxes.shape[0] == 0:
             segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
@@ -106,21 +108,11 @@ class MaskTestMixin(object):
                 x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
             mask_pred = self.mask_head(mask_feats)
             segm_result = self.mask_head.get_seg_masks(
-                mask_pred, det_bboxes, det_labels, img_shape,
-                self.rcnn_test_cfg, rescale)
+                mask_pred, det_bboxes, det_labels, self.test_cfg.rcnn,
+                ori_shape)
         return segm_result
 
-    def aug_test_mask(self,
-                      feats,
-                      img_metas,
-                      det_bboxes,
-                      det_labels,
-                      rescale=False):
-        if rescale:
-            _det_bboxes = det_bboxes
-        else:
-            _det_bboxes = det_bboxes.clone()
-            _det_bboxes[:, :4] *= img_metas[0][0]['scale_factor']
+    def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
         if det_bboxes.shape[0] == 0:
             segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
         else:
@@ -139,8 +131,10 @@ class MaskTestMixin(object):
                 # convert to numpy array to save memory
                 aug_masks.append(mask_pred.sigmoid().cpu().numpy())
             merged_masks = merge_aug_masks(aug_masks, img_metas,
-                                           self.rcnn_test_cfg)
+                                           self.test_cfg.rcnn)
+
+            ori_shape = img_metas[0][0]['ori_shape']
             segm_result = self.mask_head.get_seg_masks(
-                merged_masks, _det_bboxes, det_labels,
-                img_metas[0]['shape_scale'][0], self.rcnn_test_cfg, rescale)
+                merged_masks, det_bboxes, det_labels, self.test_cfg.rcnn,
+                ori_shape)
         return segm_result
diff --git a/tools/configs/r50_fpn_frcnn_1x.py b/tools/configs/r50_fpn_frcnn_1x.py
index 4ce93e623e3d66aaf7c1520bc818e6e33d29e184..156b8b2aa4e2ad7f70d8e1a44682128076e9e34b 100644
--- a/tools/configs/r50_fpn_frcnn_1x.py
+++ b/tools/configs/r50_fpn_frcnn_1x.py
@@ -90,7 +90,11 @@ data = dict(
         img_scale=(1333, 800),
         img_norm_cfg=img_norm_cfg,
         size_divisor=32,
-        flip_ratio=0.5),
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True,
+        test_mode=False),
     test=dict(
         type=dataset_type,
         ann_file=data_root + 'annotations/instances_val2017.json',
@@ -98,7 +102,10 @@ data = dict(
         img_scale=(1333, 800),
         flip_ratio=0,
         img_norm_cfg=img_norm_cfg,
-        size_divisor=32))
+        size_divisor=32,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
 # optimizer
 optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
 optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
@@ -112,7 +119,7 @@ lr_config = dict(
 checkpoint_config = dict(interval=1)
 # yapf:disable
 log_config = dict(
-    interval=50,
+    interval=20,
     hooks=[
         dict(type='TextLoggerHook'),
         # dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log')
@@ -120,7 +127,8 @@ log_config = dict(
 # yapf:enable
 # runtime settings
 total_epochs = 12
-dist_params = dict(backend='nccl')
+device_ids = range(8)
+dist_params = dict(backend='nccl', port='29500')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_faster_rcnn_r50_1x'
 load_from = None
diff --git a/tools/configs/r50_fpn_maskrcnn_1x.py b/tools/configs/r50_fpn_maskrcnn_1x.py
index 931f051b356c4f01119d20d397ceecd4d3babcf5..5697bca4a58cf7343b348371d56e8f631bda7a5a 100644
--- a/tools/configs/r50_fpn_maskrcnn_1x.py
+++ b/tools/configs/r50_fpn_maskrcnn_1x.py
@@ -103,7 +103,11 @@ data = dict(
         img_scale=(1333, 800),
         img_norm_cfg=img_norm_cfg,
         size_divisor=32,
-        flip_ratio=0.5),
+        flip_ratio=0.5,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True,
+        test_mode=False),
     test=dict(
         type=dataset_type,
         ann_file=data_root + 'annotations/instances_val2017.json',
@@ -111,7 +115,10 @@ data = dict(
         img_scale=(1333, 800),
         flip_ratio=0,
         img_norm_cfg=img_norm_cfg,
-        size_divisor=32))
+        size_divisor=32,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
 # optimizer
 optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
 optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
@@ -120,12 +127,12 @@ lr_config = dict(
     policy='step',
     warmup='linear',
     warmup_iters=500,
-    warmup_ratio=0.333,
+    warmup_ratio=1.0 / 3,
     step=[8, 11])
 checkpoint_config = dict(interval=1)
 # yapf:disable
 log_config = dict(
-    interval=50,
+    interval=20,
     hooks=[
         dict(type='TextLoggerHook'),
         # ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')),
@@ -133,7 +140,8 @@ log_config = dict(
 # yapf:enable
 # runtime settings
 total_epochs = 12
-dist_params = dict(backend='nccl')
+device_ids = range(8)
+dist_params = dict(backend='nccl', port='29500')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_mask_rcnn_r50_1x'
 load_from = None