diff --git a/mmdet/models/__init__.py b/mmdet/models/__init__.py
index 2209550509f71a71a66b2582440986eebcf3926c..07930688e533e6c65a4fce93209c495eeb17e756 100644
--- a/mmdet/models/__init__.py
+++ b/mmdet/models/__init__.py
@@ -1 +1,2 @@
-from .detectors import Detector
+from .detectors import *
+from .builder import *
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index f8203accd4b335886b7ebffd59517bdc8568769e..51bacc49970ae2bd698736f8d2f2bee21a5c0dde 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -1,7 +1,7 @@
 import math
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from torchpack import load_checkpoint
+from mmcv.torchpack import load_checkpoint
 
 
 def conv3x3(in_planes, out_planes, stride=1, dilation=1):
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index 5f6e1136eed45abe85a710170e76e04cba0e91cf..da923ecf2d03f3abedee00b526b2e062866e62d0 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -60,7 +60,7 @@ class BBoxHead(nn.Module):
         return cls_score, bbox_pred
 
     def get_bbox_target(self, pos_proposals, neg_proposals, pos_gt_bboxes,
-                    pos_gt_labels, rcnn_train_cfg):
+                        pos_gt_labels, rcnn_train_cfg):
         reg_num_classes = 1 if self.reg_class_agnostic else self.num_classes
         cls_reg_targets = bbox_target(
             pos_proposals,
@@ -85,7 +85,7 @@ class BBoxHead(nn.Module):
                 bbox_pred,
                 bbox_targets,
                 bbox_weights,
-                ave_factor=bbox_targets.size(0))
+                avg_factor=bbox_targets.size(0))
         return losses
 
     def get_det_bboxes(self,
diff --git a/mmdet/models/builder.py b/mmdet/models/builder.py
index c3b058507fcdc461a9d3b0271858522e4ba0f1ce..4bbc94aa41bb8f3ac823a07a00855a6c65c0c1e0 100644
--- a/mmdet/models/builder.py
+++ b/mmdet/models/builder.py
@@ -1,27 +1,26 @@
-import mmcv
-from mmcv import torchpack
+from mmcv import torchpack as tp
 from torch import nn
 
 from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
-               mask_heads)
+               mask_heads, detectors)
 
 __all__ = [
     'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
-    'build_bbox_head', 'build_mask_head'
+    'build_bbox_head', 'build_mask_head', 'build_detector'
 ]
 
 
-def _build_module(cfg, parrent=None):
-    return cfg if isinstance(cfg, nn.Module) else torchpack.obj_from_dict(
-        cfg, parrent)
+def _build_module(cfg, parent=None, default_args=None):
+    return cfg if isinstance(cfg, nn.Module) else tp.obj_from_dict(
+        cfg, parent, default_args)
 
 
-def build(cfg, parrent=None):
+def build(cfg, parent=None, default_args=None):
     if isinstance(cfg, list):
-        modules = [_build_module(cfg_, parrent) for cfg_ in cfg]
+        modules = [_build_module(cfg_, parent, default_args) for cfg_ in cfg]
         return nn.Sequential(*modules)
     else:
-        return _build_module(cfg, parrent)
+        return _build_module(cfg, parent, default_args)
 
 
 def build_backbone(cfg):
@@ -46,3 +45,7 @@ def build_bbox_head(cfg):
 
 def build_mask_head(cfg):
     return build(cfg, mask_heads)
+
+
+def build_detector(cfg, train_cfg=None, test_cfg=None):
+    return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index 5b690f8d77d6d8eae1adc4bf8b04d3dd3db3462a..fe3fc62a8191d57b673885a567afb6ecf44c42f9 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -1 +1,4 @@
-from .detector import Detector
+from .base import BaseDetector
+from .rpn import RPN
+
+__all__ = ['BaseDetector', 'RPN']
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..494f62208b10554b3e14bb84c2f3be536f324426
--- /dev/null
+++ b/mmdet/models/detectors/base.py
@@ -0,0 +1,71 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+import torch.nn as nn
+
+
+class BaseDetector(nn.Module, metaclass=ABCMeta):
+    """Base class for detectors."""
+
+    def __init__(self):
+        super(BaseDetector, self).__init__()
+
+    @abstractmethod
+    def init_weights(self):
+        pass
+
+    @abstractmethod
+    def extract_feat(self, imgs):
+        pass
+
+    def extract_feats(self, imgs):
+        # note that this method is a generator since the body contains
+        # `yield`, so a bare `return self.extract_feat(imgs)` would discard
+        # the result; the features are therefore yielded in both branches
+        if isinstance(imgs, torch.Tensor):
+            yield self.extract_feat(imgs)
+        elif isinstance(imgs, list):
+            for img in imgs:
+                yield self.extract_feat(img)
+
+    @abstractmethod
+    def forward_train(self, imgs, img_metas, **kwargs):
+        pass
+
+    @abstractmethod
+    def simple_test(self, img, img_meta, **kwargs):
+        pass
+
+    @abstractmethod
+    def aug_test(self, imgs, img_metas, **kwargs):
+        pass
+
+    def forward_test(self, imgs, img_metas, **kwargs):
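+        # imgs: one tensor per test-time augmentation; img_metas: the
+        # matching lists of meta dicts, one list per augmentation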
+        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+            if not isinstance(var, list):
+                raise TypeError('{} must be a list, but got {}'.format(
+                    name, type(var)))
+
+        num_augs = len(imgs)
+        if num_augs != len(img_metas):
+            raise ValueError(
+                'num of augmentations ({}) != num of image metas ({})'.format(
+                    len(imgs), len(img_metas)))
+        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
+        imgs_per_gpu = imgs[0].size(0)
+        assert imgs_per_gpu == 1
+
+        if num_augs == 1:
+            return self.simple_test(imgs[0], img_metas[0], **kwargs)
+        else:
+            return self.aug_test(imgs, img_metas, **kwargs)
+
+    def forward(self, img, img_meta, return_loss=True, **kwargs):
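+        # single entry point for both phases: compute losses during training
+        # and detection results during testing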
+        if return_loss:
+            return self.forward_train(img, img_meta, **kwargs)
+        else:
+            return self.forward_test(img, img_meta, **kwargs)
diff --git a/mmdet/models/detectors/detector.py b/mmdet/models/detectors/detector.py
index 80b7d4438cb59612dbff8a2bf71930eb6383a144..363131e8ecec790a6dac7682a32b181be48cfbd4 100644
--- a/mmdet/models/detectors/detector.py
+++ b/mmdet/models/detectors/detector.py
@@ -8,6 +8,7 @@ from mmdet.core import (bbox2roi, bbox_mapping, split_combined_gt_polys,
 
 
 class Detector(nn.Module):
+
     def __init__(self,
                  backbone,
                  neck=None,
diff --git a/mmdet/models/detectors/faster_rcnn.py b/mmdet/models/detectors/faster_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mmdet/models/detectors/mask_rcnn.py b/mmdet/models/detectors/mask_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mmdet/models/detectors/rpn.py b/mmdet/models/detectors/rpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d3dfd17c6c32f11726bba34d4ae438370ce10fe
--- /dev/null
+++ b/mmdet/models/detectors/rpn.py
@@ -0,0 +1,93 @@
+import mmcv
+
+from mmdet.core import tensor2imgs, bbox_mapping
+from .base import BaseDetector
+from .testing_mixins import RPNTestMixin
+from .. import builder
+
+
+class RPN(BaseDetector, RPNTestMixin):
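+    """Train and test the region proposal network as a standalone model."""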
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 rpn_head,
+                 train_cfg,
+                 test_cfg,
+                 pretrained=None):
+        super(RPN, self).__init__()
+        self.backbone = builder.build_backbone(backbone)
+        self.neck = builder.build_neck(neck) if neck is not None else None
+        self.rpn_head = builder.build_rpn_head(rpn_head)
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+        self.init_weights(pretrained=pretrained)
+
+    def init_weights(self, pretrained=None):
+        if pretrained is not None:
+            print('load model from: {}'.format(pretrained))
+        self.backbone.init_weights(pretrained=pretrained)
+        if self.neck is not None:
+            self.neck.init_weights()
+        self.rpn_head.init_weights()
+
+    def extract_feat(self, img):
+        x = self.backbone(img)
+        if self.neck is not None:
+            x = self.neck(x)
+        return x
+
+    def forward_train(self, img, img_meta, gt_bboxes=None):
+        if self.train_cfg.rpn.get('debug', False):
+            self.rpn_head.debug_imgs = tensor2imgs(img)
+
+        x = self.extract_feat(img)
+        rpn_outs = self.rpn_head(x)
+
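+        # the loss inputs are the head outputs plus the gt bboxes, the image
+        # meta information and the RPN train config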
+        rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)
+        losses = self.rpn_head.loss(*rpn_loss_inputs)
+        return losses
+
+    def simple_test(self, img, img_meta, rescale=False):
+        x = self.extract_feat(img)
+        proposal_list = self.simple_test_rpn(x, img_meta, self.test_cfg.rpn)
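+        # proposals are at the scale of the resized input image; map them
+        # back to the original image when rescale is set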
+        if rescale:
+            for proposals, meta in zip(proposal_list, img_meta):
+                proposals[:, :4] /= meta['scale_factor']
+        # TODO: remove this restriction
+        return proposal_list[0].cpu().numpy()
+
+    def aug_test(self, imgs, img_metas, rescale=False):
+        proposal_list = self.aug_test_rpn(
+            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
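+        # merged proposals are in original-image coordinates; map them to
+        # the scale of imgs[0] unless rescale is set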
+        if not rescale:
+            for proposals, img_meta in zip(proposal_list, img_metas[0]):
+                img_shape = img_meta['img_shape']
+                scale_factor = img_meta['scale_factor']
+                flip = img_meta['flip']
+                proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
+                                                scale_factor, flip)
+        # TODO: remove this restriction
+        return proposal_list[0].cpu().numpy()
+
+    def show_result(self, data, result, img_norm_cfg):
+        """Show RPN proposals on the image.
+
+        Although the batch size is assumed to be 1 elsewhere, this method
+        supports an arbitrary batch size.
+        """
+        img_tensor = data['img'][0]
+        img_metas = data['img_meta'][0].data[0]
+        imgs = tensor2imgs(img_tensor, **img_norm_cfg)
+        assert len(imgs) == len(img_metas)
+        for img, img_meta in zip(imgs, img_metas):
+            h, w, _ = img_meta['img_shape']
+            img_show = img[:h, :w, :]
+            mmcv.imshow_bboxes(img_show, result, top_k=20)
diff --git a/mmdet/models/detectors/testing_mixins.py b/mmdet/models/detectors/testing_mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..364fd4e6d1ad3ff64cd0104abe1906aaa15d2619
--- /dev/null
+++ b/mmdet/models/detectors/testing_mixins.py
@@ -0,0 +1,152 @@
+from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_proposals,
+                        merge_aug_bboxes, merge_aug_masks, multiclass_nms)
+
+
+class RPNTestMixin(object):
+
+    def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
+        rpn_outs = self.rpn_head(x)
+        proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
+        proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+        return proposal_list
+
+    def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
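+        # feats yields one feature tuple per augmented view; collect each
+        # image's proposals across all views before merging them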
+        imgs_per_gpu = len(img_metas[0])
+        aug_proposals = [[] for _ in range(imgs_per_gpu)]
+        for x, img_meta in zip(feats, img_metas):
+            proposal_list = self.simple_test_rpn(x, img_meta, rpn_test_cfg)
+            for i, proposals in enumerate(proposal_list):
+                aug_proposals[i].append(proposals)
+        # after merging, proposals will be rescaled to the original image size
+        merged_proposals = [
+            merge_aug_proposals(proposals, img_meta, rpn_test_cfg)
+            for proposals, img_meta in zip(aug_proposals, img_metas)
+        ]
+        return merged_proposals
+
+
+class BBoxTestMixin(object):
+
+    def simple_test_bboxes(self,
+                           x,
+                           img_meta,
+                           proposals,
+                           rcnn_test_cfg,
+                           rescale=False):
+        """Test only det bboxes without augmentation."""
+        rois = bbox2roi(proposals)
+        roi_feats = self.bbox_roi_extractor(
+            x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
+        cls_score, bbox_pred = self.bbox_head(roi_feats)
+        img_shape = img_meta[0]['img_shape']
+        scale_factor = img_meta[0]['scale_factor']
+        det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
+            rois,
+            cls_score,
+            bbox_pred,
+            img_shape,
+            scale_factor,
+            rescale=rescale,
+            nms_cfg=rcnn_test_cfg)
+        return det_bboxes, det_labels
+
+    def aug_test_bboxes(self, feats, img_metas, proposals, rcnn_test_cfg):
+        aug_bboxes = []
+        aug_scores = []
+        for x, img_meta in zip(feats, img_metas):
+            # only one image in the batch
+            img_shape = img_meta[0]['img_shape']
+            scale_factor = img_meta[0]['scale_factor']
+            flip = img_meta[0]['flip']
+            # map proposals (kept at the original image scale) to this
+            # augmented view without overwriting them for later iterations
+            _proposals = bbox_mapping(proposals[:, :4], img_shape,
+                                      scale_factor, flip)
+            rois = bbox2roi([_proposals])
+            # recompute feature maps to save GPU memory
+            roi_feats = self.bbox_roi_extractor(
+                x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
+            cls_score, bbox_pred = self.bbox_head(roi_feats)
+            bboxes, scores = self.bbox_head.get_det_bboxes(
+                rois,
+                cls_score,
+                bbox_pred,
+                img_shape,
+                rescale=False,
+                nms_cfg=None)
+            aug_bboxes.append(bboxes)
+            aug_scores.append(scores)
+        # after merging, bboxes will be rescaled to the original image size
+        merged_bboxes, merged_scores = merge_aug_bboxes(
+            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+        det_bboxes, det_labels = multiclass_nms(
+            merged_bboxes, merged_scores, rcnn_test_cfg.score_thr,
+            rcnn_test_cfg.nms_thr, rcnn_test_cfg.max_per_img)
+        return det_bboxes, det_labels
+
+
+class MaskTestMixin(object):
+
+    def simple_test_mask(self,
+                         x,
+                         img_meta,
+                         det_bboxes,
+                         det_labels,
+                         rescale=False):
+        # image shape of the first image in the batch (only one)
+        img_shape = img_meta[0]['img_shape']
+        scale_factor = img_meta[0]['scale_factor']
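+        # masks are predicted only on the detected boxes, so an empty
+        # detection set short-circuits to an empty segmentation result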
+        if det_bboxes.shape[0] == 0:
+            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
+        else:
+            # if det_bboxes is rescaled to the original image size, we need to
+            # rescale it back to the testing scale to obtain RoIs.
+            _bboxes = (det_bboxes[:, :4] * scale_factor
+                       if rescale else det_bboxes)
+            mask_rois = bbox2roi([_bboxes])
+            mask_feats = self.mask_roi_extractor(
+                x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
+            mask_pred = self.mask_head(mask_feats)
+            segm_result = self.mask_head.get_seg_masks(
+                mask_pred, det_bboxes, det_labels, img_shape,
+                self.rcnn_test_cfg, rescale)
+        return segm_result
+
+    def aug_test_mask(self,
+                      feats,
+                      img_metas,
+                      det_bboxes,
+                      det_labels,
+                      rescale=False):
+        if rescale:
+            _det_bboxes = det_bboxes
+        else:
+            _det_bboxes = det_bboxes.clone()
+            _det_bboxes[:, :4] *= img_metas[0][0]['scale_factor']
+        if det_bboxes.shape[0] == 0:
+            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
+        else:
+            aug_masks = []
+            for x, img_meta in zip(feats, img_metas):
+                img_shape = img_meta[0]['img_shape']
+                scale_factor = img_meta[0]['scale_factor']
+                flip = img_meta[0]['flip']
+                _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+                                       scale_factor, flip)
+                mask_rois = bbox2roi([_bboxes])
+                mask_feats = self.mask_roi_extractor(
+                    x[:len(self.mask_roi_extractor.featmap_strides)],
+                    mask_rois)
+                mask_pred = self.mask_head(mask_feats)
+                # convert to numpy array to save memory
+                aug_masks.append(mask_pred.sigmoid().cpu().numpy())
+            merged_masks = merge_aug_masks(aug_masks, img_metas,
+                                           self.rcnn_test_cfg)
+            segm_result = self.mask_head.get_seg_masks(
+                merged_masks, _det_bboxes, det_labels,
+                img_metas[0][0]['img_shape'], self.rcnn_test_cfg, rescale)
+        return segm_result
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..db497fd698663a1a653f5087e2922ddddd9ed2da
--- /dev/null
+++ b/mmdet/models/detectors/two_stage.py
@@ -0,0 +1,159 @@
+import torch
+import torch.nn as nn
+
+from .base import BaseDetector
+from .testing_mixins import RPNTestMixin, BBoxTestMixin
+from .. import builder
+from mmdet.core import bbox2roi, bbox2result, sample_proposals
+
+
+class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin):
+
+    def __init__(self,
+                 backbone,
+                 neck=None,
+                 rpn_head=None,
+                 bbox_roi_extractor=None,
+                 bbox_head=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super(TwoStageDetector, self).__init__()
+        self.backbone = builder.build_backbone(backbone)
+
+        self.with_neck = True if neck is not None else False
+        if self.with_neck:
+            self.neck = builder.build_neck(neck)
+
+        self.with_rpn = True if rpn_head is not None else False
+        if self.with_rpn:
+            self.rpn_head = builder.build_rpn_head(rpn_head)
+
+        self.with_bbox = True if bbox_head is not None else False
+        if self.with_bbox:
+            self.bbox_roi_extractor = builder.build_roi_extractor(
+                bbox_roi_extractor)
+            self.bbox_head = builder.build_bbox_head(bbox_head)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        self.init_weights(pretrained=pretrained)
+
+    def init_weights(self, pretrained=None):
+        if pretrained is not None:
+            print('load model from: {}'.format(pretrained))
+        self.backbone.init_weights(pretrained=pretrained)
+        if self.with_neck:
+            if isinstance(self.neck, nn.Sequential):
+                for m in self.neck:
+                    m.init_weights()
+            else:
+                self.neck.init_weights()
+        if self.with_rpn:
+            self.rpn_head.init_weights()
+        if self.with_bbox:
+            self.bbox_roi_extractor.init_weights()
+            self.bbox_head.init_weights()
+
+    def extract_feat(self, img):
+        x = self.backbone(img)
+        if self.with_neck:
+            x = self.neck(x)
+        return x
+
+    def forward_train(self,
+                      img,
+                      img_meta,
+                      gt_bboxes,
+                      gt_bboxes_ignore,
+                      gt_labels,
+                      proposals=None):
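+        # standard two-stage pipeline: extract features, get proposals from
+        # the RPN (unless provided), sample them, then train the bbox head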
+        losses = dict()
+
+        x = self.extract_feat(img)
+
+        if self.with_rpn:
+            rpn_outs = self.rpn_head(x)
+            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
+                                          self.train_cfg.rpn)
+            rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
+            losses.update(rpn_losses)
+
+            proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
+            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+        else:
+            proposal_list = proposals
+
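+        # subsample the proposals into positives and negatives for training
+        # the second stage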
+        (pos_inds, neg_inds, pos_proposals, neg_proposals,
+         pos_assigned_gt_inds,
+         pos_gt_bboxes, pos_gt_labels) = sample_proposals(
+             proposal_list, gt_bboxes, gt_bboxes_ignore, gt_labels,
+             self.train_cfg.rcnn)
+
+        labels, label_weights, bbox_targets, bbox_weights = \
+            self.bbox_head.get_bbox_target(
+                pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
+                self.train_cfg.rcnn)
+
+        rois = bbox2roi([
+            torch.cat([pos, neg], dim=0)
+            for pos, neg in zip(pos_proposals, neg_proposals)
+        ])
+        # TODO: a more flexible way to configure feat maps
+        roi_feats = self.bbox_roi_extractor(
+            x[:self.bbox_roi_extractor.num_inputs], rois)
+        cls_score, bbox_pred = self.bbox_head(roi_feats)
+
+        loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels,
+                                        label_weights, bbox_targets,
+                                        bbox_weights)
+        losses.update(loss_bbox)
+
+        return losses
+
+    def simple_test(self, img, img_meta, proposals=None, rescale=False):
+        """Test without augmentation."""
+        x = self.extract_feat(img)
+        if proposals is None:
+            proposal_list = self.simple_test_rpn(
+                x, img_meta, self.test_cfg.rpn)
+        else:
+            proposal_list = [proposals]
+        if self.with_bbox:
+            det_bboxes, det_labels = self.simple_test_bboxes(
+                x, img_meta, proposal_list, self.test_cfg.rcnn,
+                rescale=rescale)
+            bbox_result = bbox2result(det_bboxes, det_labels,
+                                      self.bbox_head.num_classes)
+            return bbox_result
+        else:
+            proposals = proposal_list[0]
+            proposals[:, :4] /= img_meta[0]['scale_factor']
+            return proposals.cpu().numpy()
+
+    def aug_test(self, imgs, img_metas, rescale=False):
+        """Test with augmentations.
+
+        If rescale is False, then returned bboxes and masks will fit the scale
+        of imgs[0].
+        """
+        proposal_list = self.aug_test_rpn(
+            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
+        # imgs_per_gpu is asserted to be 1 in forward_test, so only the
+        # proposals of the single image are passed on
+        det_bboxes, det_labels = self.aug_test_bboxes(
+            self.extract_feats(imgs), img_metas, proposal_list[0],
+            self.test_cfg.rcnn)
+        if rescale:
+            _det_bboxes = det_bboxes
+        else:
+            _det_bboxes = det_bboxes.clone()
+            _det_bboxes[:, :4] *= img_metas[0][0]['scale_factor']
+        bbox_result = bbox2result(_det_bboxes, det_labels,
+                                  self.bbox_head.num_classes)
+        return bbox_result
diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py
index 8b5b49826bad94ce00379e60bbafc905b0cba9af..b4e21864bff4d6a8f6bd25d46c5ff81aa3068965 100644
--- a/mmdet/models/necks/fpn.py
+++ b/mmdet/models/necks/fpn.py
@@ -101,7 +101,7 @@ class FPN(nn.Module):
         # build top-down path
         used_backbone_levels = len(laterals)
         for i in range(used_backbone_levels - 1, 0, -1):
-            laterals[i - 1] += F.upsample(
+            laterals[i - 1] += F.interpolate(
                 laterals[i], scale_factor=2, mode='nearest')
 
         # build outputs
diff --git a/mmdet/models/rpn_heads/rpn_head.py b/mmdet/models/rpn_heads/rpn_head.py
index 7ffd441f694b5d6c37d3042bb25088f27b002ea9..e81f19310e8e7e23b5e23be04888e511a7bd897d 100644
--- a/mmdet/models/rpn_heads/rpn_head.py
+++ b/mmdet/models/rpn_heads/rpn_head.py
@@ -9,8 +9,7 @@ from mmdet.core import (AnchorGenerator, anchor_target, bbox_transform_inv,
                         weighted_cross_entropy, weighted_smoothl1,
                         weighted_binary_cross_entropy)
 from mmdet.ops import nms
-from ..utils import multi_apply
-from ..utils import normal_init
+from ..utils import multi_apply, normal_init
 
 
 class RPNHead(nn.Module):
@@ -66,14 +65,14 @@ class RPNHead(nn.Module):
     def forward(self, feats):
         return multi_apply(self.forward_single, feats)
 
-    def get_anchors(self, featmap_sizes, img_shapes):
+    def get_anchors(self, featmap_sizes, img_metas):
         """Get anchors given a list of feature map sizes, and get valid flags
         at the same time. (Extra padding regions should be marked as invalid)
         """
         # calculate actual image shapes
         padded_img_shapes = []
-        for img_shape in img_shapes:
-            h, w = img_shape[:2]
+        for img_meta in img_metas:
+            h, w = img_meta['img_shape'][:2]
             padded_h = int(
                 np.ceil(h / self.coarsest_stride) * self.coarsest_stride)
             padded_w = int(
@@ -83,7 +82,7 @@ class RPNHead(nn.Module):
         # len = feature levels
         anchor_list = []
         # len = imgs per gpu
-        valid_flag_list = [[] for _ in range(len(img_shapes))]
+        valid_flag_list = [[] for _ in range(len(img_metas))]
         for i in range(len(featmap_sizes)):
             anchor_stride = self.anchor_strides[i]
             anchors = self.anchor_generators[i].grid_anchors(
@@ -103,26 +102,22 @@ class RPNHead(nn.Module):
 
     def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
                     bbox_targets, bbox_weights, num_total_samples, cfg):
+        # classification loss
         labels = labels.contiguous().view(-1)
         label_weights = label_weights.contiguous().view(-1)
-        bbox_targets = bbox_targets.contiguous().view(-1, 4)
-        bbox_weights = bbox_weights.contiguous().view(-1, 4)
         if self.use_sigmoid_cls:
             rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                   1).contiguous().view(-1)
-            loss_cls = weighted_binary_cross_entropy(
-                rpn_cls_score,
-                labels,
-                label_weights,
-                ave_factor=num_total_samples)
+            criterion = weighted_binary_cross_entropy
         else:
             rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                   1).contiguous().view(-1, 2)
-            loss_cls = weighted_cross_entropy(
-                rpn_cls_score,
-                labels,
-                label_weights,
-                ave_factor=num_total_samples)
+            criterion = weighted_cross_entropy
+        loss_cls = criterion(
+            rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
+        # regression loss
+        bbox_targets = bbox_targets.contiguous().view(-1, 4)
+        bbox_weights = bbox_weights.contiguous().view(-1, 4)
         rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
             -1, 4)
         loss_reg = weighted_smoothl1(
@@ -130,7 +125,7 @@ class RPNHead(nn.Module):
             bbox_targets,
             bbox_weights,
             beta=cfg.smoothl1_beta,
-            ave_factor=num_total_samples)
+            avg_factor=num_total_samples)
         return loss_cls, loss_reg
 
     def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_shapes, cfg):
@@ -158,8 +153,8 @@ class RPNHead(nn.Module):
             cfg=cfg)
         return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)
 
-    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_shapes, cfg):
-        img_per_gpu = len(img_shapes)
+    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
+        num_imgs = len(img_meta)
         featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
         mlvl_anchors = [
             self.anchor_generators[idx].grid_anchors(featmap_sizes[idx],
@@ -167,7 +162,7 @@ class RPNHead(nn.Module):
             for idx in range(len(featmap_sizes))
         ]
         proposal_list = []
-        for img_id in range(img_per_gpu):
+        for img_id in range(num_imgs):
             rpn_cls_score_list = [
                 rpn_cls_scores[idx][img_id].detach()
                 for idx in range(len(rpn_cls_scores))
@@ -177,10 +172,9 @@ class RPNHead(nn.Module):
                 for idx in range(len(rpn_bbox_preds))
             ]
             assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
-            img_shape = img_shapes[img_id]
             proposals = self._get_proposals_single(
                 rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
-                img_shape, cfg)
+                img_meta[img_id]['img_shape'], cfg)
             proposal_list.append(proposals)
         return proposal_list
 
@@ -195,7 +189,7 @@ class RPNHead(nn.Module):
             if self.use_sigmoid_cls:
                 rpn_cls_score = rpn_cls_score.permute(1, 2,
                                                       0).contiguous().view(-1)
-                rpn_cls_prob = F.sigmoid(rpn_cls_score)
+                rpn_cls_prob = rpn_cls_score.sigmoid()
                 scores = rpn_cls_prob
             else:
                 rpn_cls_score = rpn_cls_score.permute(1, 2,