Skip to content
Snippets Groups Projects
Unverified Commit 921e433e authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Support normalization layers in RetinaHead (#557)

* support GN in RetinaHead

* add conv_cfg argument for RetinaHead

* add model upgrading tool and update retinanet model urls

* minor fix for regex strings
parent 51904cbc
No related branches found
No related tags found
No related merge requests found
...@@ -109,15 +109,15 @@ We released RPN, Faster R-CNN and Mask R-CNN models in the first version. More m ...@@ -109,15 +109,15 @@ We released RPN, Faster R-CNN and Mask R-CNN models in the first version. More m
| Backbone | Style | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download | | Backbone | Style | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
|:--------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:| |:--------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
| R-50-FPN | caffe | 1x | 6.7 | 0.468 | 9.4 | 35.8 | - | | R-50-FPN | caffe | 1x | 6.7 | 0.468 | 9.4 | 35.8 | - |
| R-50-FPN | pytorch | 1x | 6.9 | 0.496 | 9.1 | 35.6 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r50_fpn_1x_20181125-3d3c2142.pth) | | R-50-FPN | pytorch | 1x | 6.9 | 0.496 | 9.1 | 35.6 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r50_fpn_1x_20181125-7b0c2548.pth) |
| R-50-FPN | pytorch | 2x | 6.9 | 0.496 | 9.1 | 36.5 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r50_fpn_2x_20181125-e0dbec97.pth) | | R-50-FPN | pytorch | 2x | 6.9 | 0.496 | 9.1 | 36.5 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r50_fpn_2x_20181125-8b724df2.pth) |
| R-101-FPN | caffe | 1x | 9.2 | 0.614 | 8.2 | 37.8 | - | | R-101-FPN | caffe | 1x | 9.2 | 0.614 | 8.2 | 37.8 | - |
| R-101-FPN | pytorch | 1x | 9.6 | 0.643 | 8.1 | 37.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r101_fpn_1x_20181129-f738a02f.pth) | | R-101-FPN | pytorch | 1x | 9.6 | 0.643 | 8.1 | 37.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r101_fpn_1x_20181129-f016f384.pth) |
| R-101-FPN | pytorch | 2x | 9.6 | 0.643 | 8.1 | 38.1 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r101_fpn_2x_20181129-f654534b.pth) | | R-101-FPN | pytorch | 2x | 9.6 | 0.643 | 8.1 | 38.1 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r101_fpn_2x_20181129-72c14526.pth) |
| X-101-32x4d-FPN | pytorch | 1x| 10.8 | 0.792 | 6.7 | 38.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_32x4d_fpn_1x_20181218-c140fb82.pth) | X-101-32x4d-FPN | pytorch | 1x| 10.8 | 0.792 | 6.7 | 38.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_32x4d_fpn_1x_20181218-c84d7dfc.pth)
| X-101-32x4d-FPN | pytorch | 2x| 10.8 | 0.792 | 6.7 | 39.3 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_32x4d_fpn_2x_20181218-605dcd0a.pth) | X-101-32x4d-FPN | pytorch | 2x| 10.8 | 0.792 | 6.7 | 39.3 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_32x4d_fpn_2x_20181218-8596452d.pth)
| X-101-64x4d-FPN | pytorch | 1x| 14.6 | 1.128 | 5.3 | 40.0 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_64x4d_fpn_1x_20181218-2f6f778b.pth) | X-101-64x4d-FPN | pytorch | 1x| 14.6 | 1.128 | 5.3 | 40.0 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_64x4d_fpn_1x_20181218-a0a22662.pth)
| X-101-64x4d-FPN | pytorch | 2x| 14.6 | 1.128 | 5.3 | 39.6 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_64x4d_fpn_2x_20181218-2f598dc5.pth) | X-101-64x4d-FPN | pytorch | 2x| 14.6 | 1.128 | 5.3 | 39.6 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_64x4d_fpn_2x_20181218-5e88d045.pth)
### Cascade R-CNN ### Cascade R-CNN
......
...@@ -4,7 +4,7 @@ from mmcv.cnn import normal_init ...@@ -4,7 +4,7 @@ from mmcv.cnn import normal_init
from .anchor_head import AnchorHead from .anchor_head import AnchorHead
from ..registry import HEADS from ..registry import HEADS
from ..utils import bias_init_with_prob from ..utils import bias_init_with_prob, ConvModule
@HEADS.register_module @HEADS.register_module
...@@ -16,10 +16,14 @@ class RetinaHead(AnchorHead): ...@@ -16,10 +16,14 @@ class RetinaHead(AnchorHead):
stacked_convs=4, stacked_convs=4,
octave_base_scale=4, octave_base_scale=4,
scales_per_octave=3, scales_per_octave=3,
conv_cfg=None,
normalize=None,
**kwargs): **kwargs):
self.stacked_convs = stacked_convs self.stacked_convs = stacked_convs
self.octave_base_scale = octave_base_scale self.octave_base_scale = octave_base_scale
self.scales_per_octave = scales_per_octave self.scales_per_octave = scales_per_octave
self.conv_cfg = conv_cfg
self.normalize = normalize
octave_scales = np.array( octave_scales = np.array(
[2**(i / scales_per_octave) for i in range(scales_per_octave)]) [2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale anchor_scales = octave_scales * octave_base_scale
...@@ -38,9 +42,25 @@ class RetinaHead(AnchorHead): ...@@ -38,9 +42,25 @@ class RetinaHead(AnchorHead):
for i in range(self.stacked_convs): for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append( self.cls_convs.append(
nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1)) ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
normalize=self.normalize,
bias=self.normalize is None))
self.reg_convs.append( self.reg_convs.append(
nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1)) ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
normalize=self.normalize,
bias=self.normalize is None))
self.retina_cls = nn.Conv2d( self.retina_cls = nn.Conv2d(
self.feat_channels, self.feat_channels,
self.num_anchors * self.cls_out_channels, self.num_anchors * self.cls_out_channels,
...@@ -51,9 +71,9 @@ class RetinaHead(AnchorHead): ...@@ -51,9 +71,9 @@ class RetinaHead(AnchorHead):
def init_weights(self): def init_weights(self):
for m in self.cls_convs: for m in self.cls_convs:
normal_init(m, std=0.01) normal_init(m.conv, std=0.01)
for m in self.reg_convs: for m in self.reg_convs:
normal_init(m, std=0.01) normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01) bias_cls = bias_init_with_prob(0.01)
normal_init(self.retina_cls, std=0.01, bias=bias_cls) normal_init(self.retina_cls, std=0.01, bias=bias_cls)
normal_init(self.retina_reg, std=0.01) normal_init(self.retina_reg, std=0.01)
...@@ -62,9 +82,9 @@ class RetinaHead(AnchorHead): ...@@ -62,9 +82,9 @@ class RetinaHead(AnchorHead):
cls_feat = x cls_feat = x
reg_feat = x reg_feat = x
for cls_conv in self.cls_convs: for cls_conv in self.cls_convs:
cls_feat = self.relu(cls_conv(cls_feat)) cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs: for reg_conv in self.reg_convs:
reg_feat = self.relu(reg_conv(reg_feat)) reg_feat = reg_conv(reg_feat)
cls_score = self.retina_cls(cls_feat) cls_score = self.retina_cls(cls_feat)
bbox_pred = self.retina_reg(reg_feat) bbox_pred = self.retina_reg(reg_feat)
return cls_score, bbox_pred return cls_score, bbox_pred
import argparse
import re
from collections import OrderedDict
import torch
def convert(in_file, out_file):
"""Convert keys in checkpoints.
There can be some breaking changes during the development of mmdetection,
and this tool is used for upgrading checkpoints trained with old versions
to the latest one.
"""
checkpoint = torch.load(in_file)
in_state_dict = checkpoint.pop('state_dict')
out_state_dict = OrderedDict()
for key, val in in_state_dict.items():
# Use ConvModule instead of nn.Conv2d in RetinaNet
# cls_convs.0.weight -> cls_convs.0.conv.weight
m = re.search(r'(cls_convs|reg_convs).\d.(weight|bias)', key)
if m is not None:
param = m.groups()[1]
new_key = key.replace(param, 'conv.{}'.format(param))
out_state_dict[new_key] = val
continue
out_state_dict[key] = val
checkpoint['state_dict'] = out_state_dict
torch.save(checkpoint, out_file)
def main():
parser = argparse.ArgumentParser(description='Upgrade model version')
parser.add_argument('in_file', help='input checkpoint file')
parser.add_argument('out_file', help='output checkpoint file')
args = parser.parse_args()
convert(args.in_file, args.out_file)
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment