diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py
index d526b3b3beb37b3177c2bb25f3bfcaaa7fe887d4..f4899f9f7926ab5bc992b322872ff35a08d1bc69 100644
--- a/mmdet/models/anchor_heads/anchor_head.py
+++ b/mmdet/models/anchor_heads/anchor_head.py
@@ -19,7 +19,7 @@ class AnchorHead(nn.Module):
         num_classes (int): Number of categories including the background
             category.
         in_channels (int): Number of channels in the input feature map.
-        feat_channels (int): Number of channels of the feature map.
+        feat_channels (int): Number of hidden channels. Used in child classes.
         anchor_scales (Iterable): Anchor scales.
         anchor_ratios (Iterable): Anchor aspect ratios.
         anchor_strides (Iterable): Anchor strides.
@@ -47,7 +47,6 @@ class AnchorHead(nn.Module):
                  loss_bbox=dict(
                      type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)):
         super(AnchorHead, self).__init__()
-        # NOTE: in_channels is only used in child classes (e.g. RetinaHead)
         self.in_channels = in_channels
         self.num_classes = num_classes
         self.feat_channels = feat_channels
@@ -82,9 +81,9 @@ class AnchorHead(nn.Module):
         self._init_layers()
 
     def _init_layers(self):
-        self.conv_cls = nn.Conv2d(self.feat_channels,
+        self.conv_cls = nn.Conv2d(self.in_channels,
                                   self.num_anchors * self.cls_out_channels, 1)
-        self.conv_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
+        self.conv_reg = nn.Conv2d(self.in_channels, self.num_anchors * 4, 1)
 
     def init_weights(self):
         normal_init(self.conv_cls, std=0.01)
@@ -231,8 +230,7 @@ class AnchorHead(nn.Module):
 
         Example:
             >>> import mmcv
-            >>> self = AnchorHead(num_classes=9, in_channels=1,
-            >>>                   feat_channels=1)
+            >>> self = AnchorHead(num_classes=9, in_channels=1)
             >>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}]
             >>> cfg = mmcv.Config(dict(
             >>>     score_thr=0.00,
diff --git a/mmdet/models/anchor_heads/guided_anchor_head.py b/mmdet/models/anchor_heads/guided_anchor_head.py
index 8c71e85cf6b82f6aa99ce3407a741f6d80b830a8..9e9282509b6bd04a7061594ff8e2acf7f3bcca69 100644
--- a/mmdet/models/anchor_heads/guided_anchor_head.py
+++ b/mmdet/models/anchor_heads/guided_anchor_head.py
@@ -72,7 +72,7 @@ class GuidedAnchorHead(AnchorHead):
     Args:
         num_classes (int): Number of classes.
         in_channels (int): Number of channels in the input feature map.
-        feat_channels (int): Number of channels of the feature map.
+        feat_channels (int): Number of hidden channels.
         octave_base_scale (int): Base octave scale of each level of
             feature map.
         scales_per_octave (int): Number of octave scales in each level of
@@ -170,11 +170,10 @@ class GuidedAnchorHead(AnchorHead):
 
     def _init_layers(self):
         self.relu = nn.ReLU(inplace=True)
-        self.conv_loc = nn.Conv2d(self.feat_channels, 1, 1)
-        self.conv_shape = nn.Conv2d(self.feat_channels, self.num_anchors * 2,
-                                    1)
+        self.conv_loc = nn.Conv2d(self.in_channels, 1, 1)
+        self.conv_shape = nn.Conv2d(self.in_channels, self.num_anchors * 2, 1)
         self.feature_adaption = FeatureAdaption(
-            self.feat_channels,
+            self.in_channels,
             self.feat_channels,
             kernel_size=3,
             deformable_groups=self.deformable_groups)
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc399ff37c4a4ff89b1472191edac1b466480e6
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,172 @@
+from os.path import dirname, exists, join
+
+
+def _get_config_directory():
+    """ Find the predefined detector config directory """
+    try:
+        # Assume we are running in the source mmdetection repo
+        repo_dpath = dirname(dirname(__file__))
+    except NameError:
+        # For IPython development when this __file__ is not defined
+        import mmdet
+        repo_dpath = dirname(dirname(mmdet.__file__))
+    config_dpath = join(repo_dpath, 'configs')
+    if not exists(config_dpath):
+        raise Exception('Cannot find config path')
+    return config_dpath
+
+
+def test_config_build_detector():
+    """
+    Test that all detection models defined in the configs can be initialized.
+    """
+    from xdoctest.utils import import_module_from_path
+    from mmdet.models import build_detector
+
+    config_dpath = _get_config_directory()
+    print('Found config_dpath = {!r}'.format(config_dpath))
+
+    # import glob
+    # config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
+    # config_names = [relpath(p, config_dpath) for p in config_fpaths]
+
+    # Only tests a representative subset of configurations
+
+    config_names = [
+        # 'dcn/faster_rcnn_dconv_c3-c5_r50_fpn_1x.py',
+        # 'dcn/cascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x.py',
+        # 'dcn/faster_rcnn_dpool_r50_fpn_1x.py',
+        'dcn/mask_rcnn_dconv_c3-c5_r50_fpn_1x.py',
+        # 'dcn/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x.py',
+        # 'dcn/cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py',
+        # 'dcn/faster_rcnn_mdpool_r50_fpn_1x.py',
+        # 'dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py',
+        # 'dcn/faster_rcnn_mdconv_c3-c5_r50_fpn_1x.py',
+        # ---
+        # 'htc/htc_x101_32x4d_fpn_20e_16gpu.py',
+        'htc/htc_without_semantic_r50_fpn_1x.py',
+        # 'htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py',
+        # 'htc/htc_x101_64x4d_fpn_20e_16gpu.py',
+        # 'htc/htc_r50_fpn_1x.py',
+        # 'htc/htc_r101_fpn_20e.py',
+        # 'htc/htc_r50_fpn_20e.py',
+        # ---
+        'cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py',
+        # 'cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py',
+        # ---
+        # 'scratch/scratch_faster_rcnn_r50_fpn_gn_6x.py',
+        # 'scratch/scratch_mask_rcnn_r50_fpn_gn_6x.py',
+        # ---
+        # 'grid_rcnn/grid_rcnn_gn_head_x101_32x4d_fpn_2x.py',
+        'grid_rcnn/grid_rcnn_gn_head_r50_fpn_2x.py',
+        # ---
+        'double_heads/dh_faster_rcnn_r50_fpn_1x.py',
+        # ---
+        'empirical_attention/faster_rcnn_r50_fpn_attention_0010_dcn_1x.py',
+        # 'empirical_attention/faster_rcnn_r50_fpn_attention_1111_1x.py',
+        # 'empirical_attention/faster_rcnn_r50_fpn_attention_0010_1x.py',
+        # 'empirical_attention/faster_rcnn_r50_fpn_attention_1111_dcn_1x.py',
+        # ---
+        # 'ms_rcnn/ms_rcnn_r101_caffe_fpn_1x.py',
+        # 'ms_rcnn/ms_rcnn_x101_64x4d_fpn_1x.py',
+        # 'ms_rcnn/ms_rcnn_r50_caffe_fpn_1x.py',
+        # ---
+        # 'guided_anchoring/ga_faster_x101_32x4d_fpn_1x.py',
+        # 'guided_anchoring/ga_rpn_x101_32x4d_fpn_1x.py',
+        # 'guided_anchoring/ga_retinanet_r50_caffe_fpn_1x.py',
+        # 'guided_anchoring/ga_fast_r50_caffe_fpn_1x.py',
+        # 'guided_anchoring/ga_retinanet_x101_32x4d_fpn_1x.py',
+        # 'guided_anchoring/ga_rpn_r101_caffe_rpn_1x.py',
+        # 'guided_anchoring/ga_faster_r50_caffe_fpn_1x.py',
+        'guided_anchoring/ga_rpn_r50_caffe_fpn_1x.py',
+        # ---
+        'foveabox/fovea_r50_fpn_4gpu_1x.py',
+        # 'foveabox/fovea_align_gn_ms_r101_fpn_4gpu_2x.py',
+        # 'foveabox/fovea_align_gn_r50_fpn_4gpu_2x.py',
+        # 'foveabox/fovea_align_gn_r101_fpn_4gpu_2x.py',
+        'foveabox/fovea_align_gn_ms_r50_fpn_4gpu_2x.py',
+        # ---
+        # 'hrnet/cascade_rcnn_hrnetv2p_w32_20e.py',
+        # 'hrnet/mask_rcnn_hrnetv2p_w32_1x.py',
+        # 'hrnet/cascade_mask_rcnn_hrnetv2p_w32_20e.py',
+        # 'hrnet/htc_hrnetv2p_w32_20e.py',
+        # 'hrnet/faster_rcnn_hrnetv2p_w18_1x.py',
+        # 'hrnet/mask_rcnn_hrnetv2p_w18_1x.py',
+        # 'hrnet/faster_rcnn_hrnetv2p_w32_1x.py',
+        # 'hrnet/faster_rcnn_hrnetv2p_w40_1x.py',
+        'hrnet/fcos_hrnetv2p_w32_gn_1x_4gpu.py',
+        # ---
+        # 'gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py',
+        # 'gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py',
+        'gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py',
+        # 'gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py',
+        # ---
+        # 'wider_face/ssd300_wider_face.py',
+        # ---
+        'pascal_voc/ssd300_voc.py',
+        'pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py',
+        'pascal_voc/ssd512_voc.py',
+        # ---
+        # 'gcnet/mask_rcnn_r4_gcb_c3-c5_r50_fpn_syncbn_1x.py',
+        # 'gcnet/mask_rcnn_r16_gcb_c3-c5_r50_fpn_syncbn_1x.py',
+        # 'gcnet/mask_rcnn_r4_gcb_c3-c5_r50_fpn_1x.py',
+        # 'gcnet/mask_rcnn_r16_gcb_c3-c5_r50_fpn_1x.py',
+        'gcnet/mask_rcnn_r50_fpn_sbn_1x.py',
+        # ---
+        'gn/mask_rcnn_r50_fpn_gn_contrib_2x.py',
+        # 'gn/mask_rcnn_r50_fpn_gn_2x.py',
+        # 'gn/mask_rcnn_r101_fpn_gn_2x.py',
+        # ---
+        # 'reppoints/reppoints_moment_x101_dcn_fpn_2x.py',
+        'reppoints/reppoints_moment_r50_fpn_2x.py',
+        # 'reppoints/reppoints_moment_x101_dcn_fpn_2x_mt.py',
+        'reppoints/reppoints_partial_minmax_r50_fpn_1x.py',
+        'reppoints/bbox_r50_grid_center_fpn_1x.py',
+        # 'reppoints/reppoints_moment_r101_dcn_fpn_2x.py',
+        # 'reppoints/reppoints_moment_r101_fpn_2x_mt.py',
+        # 'reppoints/reppoints_moment_r50_fpn_2x_mt.py',
+        'reppoints/reppoints_minmax_r50_fpn_1x.py',
+        # 'reppoints/reppoints_moment_r50_fpn_1x.py',
+        # 'reppoints/reppoints_moment_r101_fpn_2x.py',
+        # 'reppoints/reppoints_moment_r101_dcn_fpn_2x_mt.py',
+        'reppoints/bbox_r50_grid_fpn_1x.py',
+        # ---
+        # 'fcos/fcos_mstrain_640_800_x101_64x4d_fpn_gn_2x.py',
+        # 'fcos/fcos_mstrain_640_800_r101_caffe_fpn_gn_2x_4gpu.py',
+        'fcos/fcos_r50_caffe_fpn_gn_1x_4gpu.py',
+        # ---
+        'albu_example/mask_rcnn_r50_fpn_1x.py',
+        # ---
+        'libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py',
+        # 'libra_rcnn/libra_retinanet_r50_fpn_1x.py',
+        # 'libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py',
+        # 'libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py',
+        # 'libra_rcnn/libra_fast_rcnn_r50_fpn_1x.py',
+        # ---
+        # 'ghm/retinanet_ghm_r50_fpn_1x.py',
+        # ---
+        # 'fp16/retinanet_r50_fpn_fp16_1x.py',
+        'fp16/mask_rcnn_r50_fpn_fp16_1x.py',
+        'fp16/faster_rcnn_r50_fpn_fp16_1x.py'
+    ]
+
+    print('Using {} config files'.format(len(config_names)))
+
+    for config_fname in config_names:
+        config_fpath = join(config_dpath, config_fname)
+        config_mod = import_module_from_path(config_fpath)
+
+        config_mod.model
+        config_mod.train_cfg
+        config_mod.test_cfg
+        print('Building detector, config_fpath = {!r}'.format(config_fpath))
+
+        # Remove pretrained keys to allow for testing in an offline environment
+        if 'pretrained' in config_mod.model:
+            config_mod.model['pretrained'] = None
+
+        detector = build_detector(
+            config_mod.model,
+            train_cfg=config_mod.train_cfg,
+            test_cfg=config_mod.test_cfg)
+        assert detector is not None