Skip to content
Snippets Groups Projects
Commit 0e747be8 authored by Kai Chen
Browse files

update resnet backbone

parent e8397e43
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@ model = dict(
type='FasterRCNN',
pretrained='modelzoo://resnet50',
backbone=dict(
type='resnet',
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
......
......@@ -3,7 +3,7 @@ model = dict(
type='MaskRCNN',
pretrained='modelzoo://resnet50',
backbone=dict(
type='resnet',
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
......
......@@ -3,7 +3,7 @@ model = dict(
type='RPN',
pretrained='modelzoo://resnet50',
backbone=dict(
type='resnet',
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
......
from .resnet import resnet
from .resnet import ResNet
__all__ = ['resnet']
__all__ = ['ResNet']
import logging
import math
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
......@@ -27,7 +28,8 @@ class BasicBlock(nn.Module):
stride=1,
dilation=1,
downsample=None,
style='pytorch'):
style='pytorch',
with_cp=False):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.bn1 = nn.BatchNorm2d(planes)
......@@ -37,6 +39,7 @@ class BasicBlock(nn.Module):
self.downsample = downsample
self.stride = stride
self.dilation = dilation
assert not with_cp
def forward(self, x):
residual = x
......@@ -69,7 +72,6 @@ class Bottleneck(nn.Module):
style='pytorch',
with_cp=False):
"""Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
......@@ -174,64 +176,73 @@ def make_res_layer(block,
return nn.Sequential(*layers)
class ResHead(nn.Module):
    """Head module that wraps a single residual stage.

    The stage is built by ``make_res_layer`` and stored as ``layer4``;
    ``forward`` simply runs the input through it.

    Args:
        block: Residual block class forwarded to ``make_res_layer``.
        num_blocks (int): Number of blocks forwarded to ``make_res_layer``.
        stride (int): Stride forwarded to ``make_res_layer``.
        dilation (int): Dilation forwarded to ``make_res_layer``.
        style (str): Block style forwarded to ``make_res_layer``
            (``'pytorch'`` by default).
    """

    def __init__(self,
                 block,
                 num_blocks,
                 stride=2,
                 dilation=1,
                 style='pytorch'):
        # nn.Module has to be initialised before a submodule is assigned;
        # without this call the ``self.layer4 = ...`` assignment below is
        # rejected by PyTorch.
        super(ResHead, self).__init__()
        # NOTE(review): 1024 / 512 are presumably the input channels and the
        # base planes of the stage -- confirm against make_res_layer.
        self.layer4 = make_res_layer(
            block,
            1024,
            512,
            num_blocks,
            stride=stride,
            dilation=dilation,
            style=style)

    def forward(self, x):
        """Run ``x`` through the wrapped residual stage and return the result."""
        return self.layer4(x)
class ResNet(nn.Module):
"""ResNet backbone.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
class ResNet(nn.Module):
arch_settings = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self,
block,
layers,
depth,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
style='pytorch',
sync_bn=False,
with_cp=False,
strict_frozen=False):
frozen_stages=-1,
bn_eval=True,
bn_frozen=False,
with_cp=False):
super(ResNet, self).__init__()
if not len(layers) == len(strides) == len(dilations):
raise ValueError(
'The number of layers, strides and dilations must be equal, '
'but found have {} layers, {} strides and {} dilations'.format(
len(layers), len(strides), len(dilations)))
assert max(out_indices) < len(layers)
if depth not in self.arch_settings:
raise KeyError('invalid depth {} for resnet'.format(depth))
assert num_stages >= 1 and num_stages <= 4
block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages]
assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.style = style
self.sync_bn = sync_bn
self.frozen_stages = frozen_stages
self.bn_eval = bn_eval
self.bn_frozen = bn_frozen
self.with_cp = with_cp
self.inplanes = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.res_layers = []
for i, num_blocks in enumerate(layers):
self.res_layers = []
for i, num_blocks in enumerate(stage_blocks):
stride = strides[i]
dilation = dilations[i]
layer_name = 'layer{}'.format(i + 1)
planes = 64 * 2**i
res_layer = make_res_layer(
block,
......@@ -243,12 +254,11 @@ class ResNet(nn.Module):
style=self.style,
with_cp=with_cp)
self.inplanes = planes * block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(layers) - 1)
self.with_cp = with_cp
self.strict_frozen = strict_frozen
self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
......@@ -257,11 +267,9 @@ class ResNet(nn.Module):
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
......@@ -283,11 +291,11 @@ class ResNet(nn.Module):
def train(self, mode=True):
super(ResNet, self).train(mode)
if not self.sync_bn:
if self.bn_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if self.strict_frozen:
if self.bn_frozen:
for params in m.parameters():
params.requires_grad = False
if mode and self.frozen_stages >= 0:
......@@ -303,39 +311,3 @@ class ResNet(nn.Module):
mod.eval()
for param in mod.parameters():
param.requires_grad = False
# Lookup table keyed by network depth. Each value is a pair of
# (residual block class, tuple with the number of blocks per stage).
resnet_cfg = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def resnet(depth,
           num_stages=4,
           strides=(1, 2, 2, 2),
           dilations=(1, 1, 1, 1),
           out_indices=(2, ),
           frozen_stages=-1,
           style='pytorch',
           sync_bn=False,
           with_cp=False,
           strict_frozen=False):
    """Build a ResNet for one of the depths registered in ``resnet_cfg``.

    The block class and per-stage block counts are looked up by ``depth``;
    only the first ``num_stages`` block counts are kept. Every other
    argument is handed to ``ResNet`` unchanged, in positional order.

    Args:
        depth (int): Depth of the network; must be a key of ``resnet_cfg``.
        num_stages (int): How many stages of the configuration to keep.
        strides (list): Strides of the first block of each stage.
        dilations (list): Dilation of each stage.
        out_indices (list): Stages whose output is returned.
        frozen_stages: Passed through to ``ResNet``.
        style: Passed through to ``ResNet``.
        sync_bn: Passed through to ``ResNet``.
        with_cp: Passed through to ``ResNet``.
        strict_frozen: Passed through to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.

    Raises:
        KeyError: If ``depth`` is not a key of ``resnet_cfg``.
    """
    if depth in resnet_cfg:
        block_type, blocks_per_stage = resnet_cfg[depth]
        kept_stages = blocks_per_stage[:num_stages]
        return ResNet(block_type, kept_stages, strides, dilations,
                      out_indices, frozen_stages, style, sync_bn, with_cp,
                      strict_frozen)
    raise KeyError('invalid depth {} for resnet'.format(depth))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment