diff --git a/mmdet/nn/parallel/distributed.py b/mmdet/nn/parallel/distributed.py
index 1db0ea6d00f7a30eb9bf12e753530642756b7c21..a2e1d557b3edd5a306aa7abe687fd91bd42ab1e8 100644
--- a/mmdet/nn/parallel/distributed.py
+++ b/mmdet/nn/parallel/distributed.py
@@ -15,8 +15,8 @@ class MMDistributedDataParallel(nn.Module):
         self.dim = dim
         self.broadcast_buffers = broadcast_buffers
 
-        self.first_synced = False
         self.broadcast_bucket_size = 32 * 1024 * 1024
+        self._sync_params()
 
     def _dist_broadcast_coalesced(self, tensors, buffer_size):
         for tensors in _take_tensors(tensors, buffer_size):
@@ -26,7 +26,7 @@ class MMDistributedDataParallel(nn.Module):
                     tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                 tensor.copy_(synced)
 
-    def sync_params(self):
+    def _sync_params(self):
         module_states = list(self.module.state_dict().values())
         if len(module_states) > 0:
             self._dist_broadcast_coalesced(module_states,
@@ -41,9 +41,6 @@ class MMDistributedDataParallel(nn.Module):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
     def forward(self, *inputs, **kwargs):
-        if not self.first_synced:
-            self.sync_params()
-            self.first_synced = True
         inputs, kwargs = self.scatter(inputs, kwargs,
                                       [torch.cuda.current_device()])
         return self.module(*inputs[0], **kwargs[0])
diff --git a/tools/train.py b/tools/train.py
index 8acb63084968838753690f2891a388c861805152..8fd43807967fef6b17695158a4f67514b0a0ab5d 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -95,7 +95,7 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
     if dist:
-        model = MMDistributedDataParallel(model).cuda()
+        model = MMDistributedDataParallel(model.cuda())
     else:
         model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()