From d04fa0f3a2be11ac1f48f817198c53e7e5694b71 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Tue, 25 Sep 2018 01:51:38 +0800
Subject: [PATCH] move _sync_param from forward to init

---
 mmdet/nn/parallel/distributed.py | 7 ++-----
 tools/train.py                   | 2 +-
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/mmdet/nn/parallel/distributed.py b/mmdet/nn/parallel/distributed.py
index 1db0ea6..a2e1d55 100644
--- a/mmdet/nn/parallel/distributed.py
+++ b/mmdet/nn/parallel/distributed.py
@@ -15,8 +15,8 @@ class MMDistributedDataParallel(nn.Module):
         self.dim = dim
         self.broadcast_buffers = broadcast_buffers
 
-        self.first_synced = False
         self.broadcast_bucket_size = 32 * 1024 * 1024
+        self._sync_params()
 
     def _dist_broadcast_coalesced(self, tensors, buffer_size):
         for tensors in _take_tensors(tensors, buffer_size):
@@ -26,7 +26,7 @@ class MMDistributedDataParallel(nn.Module):
                     tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                 tensor.copy_(synced)
 
-    def sync_params(self):
+    def _sync_params(self):
         module_states = list(self.module.state_dict().values())
         if len(module_states) > 0:
             self._dist_broadcast_coalesced(module_states,
@@ -41,9 +41,6 @@ class MMDistributedDataParallel(nn.Module):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
     def forward(self, *inputs, **kwargs):
-        if not self.first_synced:
-            self.sync_params()
-            self.first_synced = True
         inputs, kwargs = self.scatter(inputs, kwargs,
                                       [torch.cuda.current_device()])
         return self.module(*inputs[0], **kwargs[0])
diff --git a/tools/train.py b/tools/train.py
index 8acb630..8fd4380 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -95,7 +95,7 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
     if dist:
-        model = MMDistributedDataParallel(model).cuda()
+        model = MMDistributedDataParallel(model.cuda())
     else:
         model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
--
GitLab
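
A minimal sketch of the caller-side contract after this change, assuming the
process group is already initialized (e.g. under torch.distributed.launch with
the usual env vars) and that MMDistributedDataParallel is imported from the
module path shown in the diff; the toy Conv2d stands in for build_detector and
is illustrative only, not part of the commit:

    # Hypothetical sketch; any name not present in the patch is an assumption.
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from mmdet.nn.parallel.distributed import MMDistributedDataParallel

    # Requires RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT in the environment,
    # e.g. when launched via torch.distributed.launch.
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = nn.Conv2d(3, 16, 3)  # stand-in for build_detector(cfg.model, ...)

    # After this patch the wrapper calls _sync_params() in __init__, which
    # broadcasts rank 0's state_dict to all ranks, so the module must already
    # be on its GPU when the wrapper is constructed; hence .cuda() now sits
    # inside the call, matching the tools/train.py hunk above.
    model = MMDistributedDataParallel(model.cuda())

    # forward() no longer performs a lazy first-time sync; it only scatters
    # the inputs to the current device and runs the wrapped module.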