From ddfb38efa733f52ced8d02b03c9fd913e5d7e044 Mon Sep 17 00:00:00 2001 From: Jerry XU <xvjiarui0826@gmail.com> Date: Wed, 15 May 2019 13:08:55 +0800 Subject: [PATCH] add pytorch 1.1.0 SyncBN support (#577) * add pytorch 1.1.0 SyncBN support * change BatchNorm2d to _BatchNorm and call freeze after train * add freeze back to init function * fixed indentation typo in adding freeze * use SyncBN protect member func to set ddp_gpu_num * Update README.md update pytorch version to 1.1 --- README.md | 2 +- mmdet/models/backbones/resnet.py | 6 ++++-- mmdet/models/utils/norm.py | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 04411b5..db83b3a 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ ## Introduction -The master branch works with **PyTorch 1.0** or higher. If you would like to use PyTorch 0.4.1, +The master branch works with **PyTorch 1.1** or higher. If you would like to use PyTorch 0.4.1, please checkout to the [pytorch-0.4.1](https://github.com/open-mmlab/mmdetection/tree/pytorch-0.4.1) branch. mmdetection is an open source object detection toolbox based on PyTorch. It is diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py index 8752228..564b565 100644 --- a/mmdet/models/backbones/resnet.py +++ b/mmdet/models/backbones/resnet.py @@ -2,6 +2,7 @@ import logging import torch.nn as nn import torch.utils.checkpoint as cp +from torch.nn.modules.batchnorm import _BatchNorm from mmcv.cnn import constant_init, kaiming_init from mmcv.runner import load_checkpoint @@ -437,7 +438,7 @@ class ResNet(nn.Module): for m in self.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m) - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): constant_init(m, 1) if self.dcn is not None: @@ -470,8 +471,9 @@ class ResNet(nn.Module): def train(self, mode=True): super(ResNet, self).train(mode) + self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only - if isinstance(m, nn.BatchNorm2d): + if isinstance(m, _BatchNorm): m.eval() diff --git a/mmdet/models/utils/norm.py b/mmdet/models/utils/norm.py index 8658f6b..a912a19 100644 --- a/mmdet/models/utils/norm.py +++ b/mmdet/models/utils/norm.py @@ -4,7 +4,7 @@ import torch.nn as nn norm_cfg = { # format: layer_type: (abbreviation, module) 'BN': ('bn', nn.BatchNorm2d), - 'SyncBN': ('bn', None), + 'SyncBN': ('bn', nn.SyncBatchNorm), 'GN': ('gn', nn.GroupNorm), # and potentially 'SN' } @@ -44,6 +44,8 @@ def build_norm_layer(cfg, num_features, postfix=''): cfg_.setdefault('eps', 1e-5) if layer_type != 'GN': layer = norm_layer(num_features, **cfg_) + if layer_type == 'SyncBN': + layer._specify_ddp_gpu_num(1) else: assert 'num_groups' in cfg_ layer = norm_layer(num_channels=num_features, **cfg_) -- GitLab