From 5055cdf2dfba008fbb8e6041a2f525e502e0f277 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Fri, 5 Oct 2018 14:06:47 +0800
Subject: [PATCH] rename resnet style from fb/msra to pytorch/caffe

---
 mmdet/models/backbones/resnet.py     | 30 +++++++++++++++++-----------
 tools/configs/r50_fpn_frcnn_1x.py    |  2 +-
 tools/configs/r50_fpn_maskrcnn_1x.py |  2 +-
 tools/configs/r50_fpn_rpn_1x.py      |  2 +-
 4 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index 458de92..371f4f5 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -27,7 +27,7 @@ class BasicBlock(nn.Module):
                  stride=1,
                  dilation=1,
                  downsample=None,
-                 style='fb'):
+                 style='pytorch'):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
         self.bn1 = nn.BatchNorm2d(planes)
@@ -66,15 +66,16 @@ class Bottleneck(nn.Module):
                  stride=1,
                  dilation=1,
                  downsample=None,
-                 style='fb',
+                 style='pytorch',
                  with_cp=False):
-        """Bottleneck block
-        if style is "fb", the stride-two layer is the 3x3 conv layer,
-        if style is "msra", the stride-two layer is the first 1x1 conv layer
+        """Bottleneck block.
+
+        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
+        if it is "caffe", the stride-two layer is the first 1x1 conv layer.
         """
         super(Bottleneck, self).__init__()
-        assert style in ['fb', 'msra']
-        if style == 'fb':
+        assert style in ['pytorch', 'caffe']
+        if style == 'pytorch':
             conv1_stride = 1
             conv2_stride = stride
         else:
@@ -141,7 +142,7 @@ def make_res_layer(block,
                    blocks,
                    stride=1,
                    dilation=1,
-                   style='fb',
+                   style='pytorch',
                    with_cp=False):
     downsample = None
     if stride != 1 or inplanes != planes * block.expansion:
@@ -175,7 +176,12 @@ def make_res_layer(block,
 
 class ResHead(nn.Module):
 
-    def __init__(self, block, num_blocks, stride=2, dilation=1, style='fb'):
+    def __init__(self,
+                 block,
+                 num_blocks,
+                 stride=2,
+                 dilation=1,
+                 style='pytorch'):
         self.layer4 = make_res_layer(
             block,
             1024,
@@ -198,7 +204,7 @@ class ResNet(nn.Module):
                  dilations=(1, 1, 1, 1),
                  out_indices=(0, 1, 2, 3),
                  frozen_stages=-1,
-                 style='fb',
+                 style='pytorch',
                  sync_bn=False,
                  with_cp=False,
                  strict_frozen=False):
@@ -237,7 +243,7 @@ class ResNet(nn.Module):
                 style=self.style,
                 with_cp=with_cp)
             self.inplanes = planes * block.expansion
-            setattr(self, layer_name, res_layer)
+            self.add_module(layer_name, res_layer)
             self.res_layers.append(layer_name)
         self.feat_dim = block.expansion * 64 * 2**(len(layers) - 1)
         self.with_cp = with_cp
@@ -314,7 +320,7 @@ def resnet(depth,
            dilations=(1, 1, 1, 1),
            out_indices=(2, ),
            frozen_stages=-1,
-           style='fb',
+           style='pytorch',
            sync_bn=False,
            with_cp=False,
            strict_frozen=False):
diff --git a/tools/configs/r50_fpn_frcnn_1x.py b/tools/configs/r50_fpn_frcnn_1x.py
index e15cbdb..82082df 100644
--- a/tools/configs/r50_fpn_frcnn_1x.py
+++ b/tools/configs/r50_fpn_frcnn_1x.py
@@ -8,7 +8,7 @@ model = dict(
         num_stages=4,
         out_indices=(0, 1, 2, 3),
         frozen_stages=1,
-        style='fb'),
+        style='pytorch'),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
diff --git a/tools/configs/r50_fpn_maskrcnn_1x.py b/tools/configs/r50_fpn_maskrcnn_1x.py
index 5ecdaf4..ad61857 100644
--- a/tools/configs/r50_fpn_maskrcnn_1x.py
+++ b/tools/configs/r50_fpn_maskrcnn_1x.py
@@ -8,7 +8,7 @@ model = dict(
         num_stages=4,
         out_indices=(0, 1, 2, 3),
         frozen_stages=1,
-        style='fb'),
+        style='pytorch'),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
diff --git a/tools/configs/r50_fpn_rpn_1x.py b/tools/configs/r50_fpn_rpn_1x.py
index 91f5f08..dfed976 100644
--- a/tools/configs/r50_fpn_rpn_1x.py
+++ b/tools/configs/r50_fpn_rpn_1x.py
@@ -8,7 +8,7 @@ model = dict(
         num_stages=4,
         out_indices=(0, 1, 2, 3),
         frozen_stages=1,
-        style='fb'),
+        style='pytorch'),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
-- 
GitLab