Skip to content
Snippets Groups Projects
Commit 0e747be8 authored by Kai Chen's avatar Kai Chen
Browse files

update resnet backbone

parent e8397e43
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,7 @@ model = dict( ...@@ -3,7 +3,7 @@ model = dict(
type='FasterRCNN', type='FasterRCNN',
pretrained='modelzoo://resnet50', pretrained='modelzoo://resnet50',
backbone=dict( backbone=dict(
type='resnet', type='ResNet',
depth=50, depth=50,
num_stages=4, num_stages=4,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
......
...@@ -3,7 +3,7 @@ model = dict( ...@@ -3,7 +3,7 @@ model = dict(
type='MaskRCNN', type='MaskRCNN',
pretrained='modelzoo://resnet50', pretrained='modelzoo://resnet50',
backbone=dict( backbone=dict(
type='resnet', type='ResNet',
depth=50, depth=50,
num_stages=4, num_stages=4,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
......
...@@ -3,7 +3,7 @@ model = dict( ...@@ -3,7 +3,7 @@ model = dict(
type='RPN', type='RPN',
pretrained='modelzoo://resnet50', pretrained='modelzoo://resnet50',
backbone=dict( backbone=dict(
type='resnet', type='ResNet',
depth=50, depth=50,
num_stages=4, num_stages=4,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
......
from .resnet import resnet from .resnet import ResNet
__all__ = ['resnet'] __all__ = ['ResNet']
import logging import logging
import math
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint from mmcv.runner import load_checkpoint
...@@ -27,7 +28,8 @@ class BasicBlock(nn.Module): ...@@ -27,7 +28,8 @@ class BasicBlock(nn.Module):
stride=1, stride=1,
dilation=1, dilation=1,
downsample=None, downsample=None,
style='pytorch'): style='pytorch',
with_cp=False):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride, dilation) self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.bn1 = nn.BatchNorm2d(planes) self.bn1 = nn.BatchNorm2d(planes)
...@@ -37,6 +39,7 @@ class BasicBlock(nn.Module): ...@@ -37,6 +39,7 @@ class BasicBlock(nn.Module):
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
self.dilation = dilation self.dilation = dilation
assert not with_cp
def forward(self, x): def forward(self, x):
residual = x residual = x
...@@ -69,7 +72,6 @@ class Bottleneck(nn.Module): ...@@ -69,7 +72,6 @@ class Bottleneck(nn.Module):
style='pytorch', style='pytorch',
with_cp=False): with_cp=False):
"""Bottleneck block. """Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer. if it is "caffe", the stride-two layer is the first 1x1 conv layer.
""" """
...@@ -174,64 +176,73 @@ def make_res_layer(block, ...@@ -174,64 +176,73 @@ def make_res_layer(block,
return nn.Sequential(*layers) return nn.Sequential(*layers)
class ResHead(nn.Module): class ResNet(nn.Module):
"""ResNet backbone.
def __init__(self,
block,
num_blocks,
stride=2,
dilation=1,
style='pytorch'):
self.layer4 = make_res_layer(
block,
1024,
512,
num_blocks,
stride=stride,
dilation=dilation,
style=style)
def forward(self, x):
return self.layer4(x)
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
class ResNet(nn.Module): arch_settings = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self, def __init__(self,
block, depth,
layers, num_stages=4,
strides=(1, 2, 2, 2), strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1), dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
frozen_stages=-1,
style='pytorch', style='pytorch',
sync_bn=False, frozen_stages=-1,
with_cp=False, bn_eval=True,
strict_frozen=False): bn_frozen=False,
with_cp=False):
super(ResNet, self).__init__() super(ResNet, self).__init__()
if not len(layers) == len(strides) == len(dilations): if depth not in self.arch_settings:
raise ValueError( raise KeyError('invalid depth {} for resnet'.format(depth))
'The number of layers, strides and dilations must be equal, ' assert num_stages >= 1 and num_stages <= 4
'but found have {} layers, {} strides and {} dilations'.format( block, stage_blocks = self.arch_settings[depth]
len(layers), len(strides), len(dilations))) stage_blocks = stage_blocks[:num_stages]
assert max(out_indices) < len(layers) assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages
self.out_indices = out_indices self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.style = style self.style = style
self.sync_bn = sync_bn self.frozen_stages = frozen_stages
self.bn_eval = bn_eval
self.bn_frozen = bn_frozen
self.with_cp = with_cp
self.inplanes = 64 self.inplanes = 64
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False) 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.res_layers = []
for i, num_blocks in enumerate(layers):
self.res_layers = []
for i, num_blocks in enumerate(stage_blocks):
stride = strides[i] stride = strides[i]
dilation = dilations[i] dilation = dilations[i]
layer_name = 'layer{}'.format(i + 1)
planes = 64 * 2**i planes = 64 * 2**i
res_layer = make_res_layer( res_layer = make_res_layer(
block, block,
...@@ -243,12 +254,11 @@ class ResNet(nn.Module): ...@@ -243,12 +254,11 @@ class ResNet(nn.Module):
style=self.style, style=self.style,
with_cp=with_cp) with_cp=with_cp)
self.inplanes = planes * block.expansion self.inplanes = planes * block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer) self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name) self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(layers) - 1)
self.with_cp = with_cp
self.strict_frozen = strict_frozen self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
def init_weights(self, pretrained=None): def init_weights(self, pretrained=None):
if isinstance(pretrained, str): if isinstance(pretrained, str):
...@@ -257,11 +267,9 @@ class ResNet(nn.Module): ...@@ -257,11 +267,9 @@ class ResNet(nn.Module):
elif pretrained is None: elif pretrained is None:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels kaiming_init(m)
nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1) constant_init(m, 1)
nn.init.constant_(m.bias, 0)
else: else:
raise TypeError('pretrained must be a str or None') raise TypeError('pretrained must be a str or None')
...@@ -283,11 +291,11 @@ class ResNet(nn.Module): ...@@ -283,11 +291,11 @@ class ResNet(nn.Module):
def train(self, mode=True): def train(self, mode=True):
super(ResNet, self).train(mode) super(ResNet, self).train(mode)
if not self.sync_bn: if self.bn_eval:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.BatchNorm2d): if isinstance(m, nn.BatchNorm2d):
m.eval() m.eval()
if self.strict_frozen: if self.bn_frozen:
for params in m.parameters(): for params in m.parameters():
params.requires_grad = False params.requires_grad = False
if mode and self.frozen_stages >= 0: if mode and self.frozen_stages >= 0:
...@@ -303,39 +311,3 @@ class ResNet(nn.Module): ...@@ -303,39 +311,3 @@ class ResNet(nn.Module):
mod.eval() mod.eval()
for param in mod.parameters(): for param in mod.parameters():
param.requires_grad = False param.requires_grad = False
resnet_cfg = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def resnet(depth,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(2, ),
frozen_stages=-1,
style='pytorch',
sync_bn=False,
with_cp=False,
strict_frozen=False):
"""Constructs a ResNet model.
Args:
depth (int): depth of resnet, from {18, 34, 50, 101, 152}
num_stages (int): num of resnet stages, normally 4
strides (list): strides of the first block of each stage
dilations (list): dilation of each stage
out_indices (list): output from which stages
"""
if depth not in resnet_cfg:
raise KeyError('invalid depth {} for resnet'.format(depth))
block, layers = resnet_cfg[depth]
model = ResNet(block, layers[:num_stages], strides, dilations, out_indices,
frozen_stages, style, sync_bn, with_cp, strict_frozen)
return model
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment