From ddfb38efa733f52ced8d02b03c9fd913e5d7e044 Mon Sep 17 00:00:00 2001
From: Jerry XU <xvjiarui0826@gmail.com>
Date: Wed, 15 May 2019 13:08:55 +0800
Subject: [PATCH] add pytorch 1.1.0 SyncBN support (#577)

* add pytorch 1.1.0 SyncBN support

* change BatchNorm2d to _BatchNorm and call freeze after train

* add freeze back to init function

* fixed indentation typo in adding freeze

* use SyncBN protect member func to set ddp_gpu_num

* Update README.md

update pytorch version to 1.1
---
 README.md                        | 2 +-
 mmdet/models/backbones/resnet.py | 6 ++++--
 mmdet/models/utils/norm.py       | 4 +++-
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 04411b5..db83b3a 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
 
 ## Introduction
 
-The master branch works with **PyTorch 1.0** or higher. If you would like to use PyTorch 0.4.1,
+The master branch works with **PyTorch 1.1** or higher. If you would like to use PyTorch 0.4.1,
 please checkout to the [pytorch-0.4.1](https://github.com/open-mmlab/mmdetection/tree/pytorch-0.4.1) branch.
 
 mmdetection is an open source object detection toolbox based on PyTorch. It is
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index 8752228..564b565 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -2,6 +2,7 @@ import logging
 
 import torch.nn as nn
 import torch.utils.checkpoint as cp
+from torch.nn.modules.batchnorm import _BatchNorm
 
 from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
@@ -437,7 +438,7 @@ class ResNet(nn.Module):
             for m in self.modules():
                 if isinstance(m, nn.Conv2d):
                     kaiming_init(m)
-                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                     constant_init(m, 1)
 
             if self.dcn is not None:
@@ -470,8 +471,9 @@ class ResNet(nn.Module):
 
     def train(self, mode=True):
         super(ResNet, self).train(mode)
+        self._freeze_stages()
         if mode and self.norm_eval:
             for m in self.modules():
                 # trick: eval have effect on BatchNorm only
-                if isinstance(m, nn.BatchNorm2d):
+                if isinstance(m, _BatchNorm):
                     m.eval()
diff --git a/mmdet/models/utils/norm.py b/mmdet/models/utils/norm.py
index 8658f6b..a912a19 100644
--- a/mmdet/models/utils/norm.py
+++ b/mmdet/models/utils/norm.py
@@ -4,7 +4,7 @@ import torch.nn as nn
 norm_cfg = {
     # format: layer_type: (abbreviation, module)
     'BN': ('bn', nn.BatchNorm2d),
-    'SyncBN': ('bn', None),
+    'SyncBN': ('bn', nn.SyncBatchNorm),
     'GN': ('gn', nn.GroupNorm),
     # and potentially 'SN'
 }
@@ -44,6 +44,8 @@ def build_norm_layer(cfg, num_features, postfix=''):
     cfg_.setdefault('eps', 1e-5)
     if layer_type != 'GN':
         layer = norm_layer(num_features, **cfg_)
+        if layer_type == 'SyncBN':
+            layer._specify_ddp_gpu_num(1)
     else:
         assert 'num_groups' in cfg_
         layer = norm_layer(num_channels=num_features, **cfg_)
-- 
GitLab