diff --git a/README.md b/README.md index 04411b58631e7901c9e3467d1b1b2d9499033e14..db83b3a680518bb0040bc843652ee7407e1ac440 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ ## Introduction -The master branch works with **PyTorch 1.0** or higher. If you would like to use PyTorch 0.4.1, +The master branch works with **PyTorch 1.1** or higher. If you would like to use PyTorch 0.4.1, please checkout to the [pytorch-0.4.1](https://github.com/open-mmlab/mmdetection/tree/pytorch-0.4.1) branch. mmdetection is an open source object detection toolbox based on PyTorch. It is diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py index 8752228a1f9cfe79abb5d1e3bbe44f9330108732..564b5657cb51810efc2241e6964b9f432c9af811 100644 --- a/mmdet/models/backbones/resnet.py +++ b/mmdet/models/backbones/resnet.py @@ -2,6 +2,7 @@ import logging import torch.nn as nn import torch.utils.checkpoint as cp +from torch.nn.modules.batchnorm import _BatchNorm from mmcv.cnn import constant_init, kaiming_init from mmcv.runner import load_checkpoint @@ -437,7 +438,7 @@ class ResNet(nn.Module): for m in self.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m) - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): constant_init(m, 1) if self.dcn is not None: @@ -470,8 +471,9 @@ class ResNet(nn.Module): def train(self, mode=True): super(ResNet, self).train(mode) + self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only - if isinstance(m, nn.BatchNorm2d): + if isinstance(m, _BatchNorm): m.eval() diff --git a/mmdet/models/utils/norm.py b/mmdet/models/utils/norm.py index 8658f6bafcd2caa148752a135dcc5309f3f3f2c5..a912a19b5a3a5de5090865f28cebf230862b6375 100644 --- a/mmdet/models/utils/norm.py +++ b/mmdet/models/utils/norm.py @@ -4,7 +4,7 @@ import torch.nn as nn norm_cfg = { # format: layer_type: (abbreviation, module) 'BN': ('bn', nn.BatchNorm2d), - 'SyncBN': ('bn', None), + 'SyncBN': ('bn', nn.SyncBatchNorm), 'GN': ('gn', nn.GroupNorm), # and potentially 'SN' } @@ -44,6 +44,8 @@ def build_norm_layer(cfg, num_features, postfix=''): cfg_.setdefault('eps', 1e-5) if layer_type != 'GN': layer = norm_layer(num_features, **cfg_) + if layer_type == 'SyncBN': + layer._specify_ddp_gpu_num(1) else: assert 'num_groups' in cfg_ layer = norm_layer(num_channels=num_features, **cfg_)