From d04fa0f3a2be11ac1f48f817198c53e7e5694b71 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Tue, 25 Sep 2018 01:51:38 +0800
Subject: [PATCH] move _sync_params from forward to __init__

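Broadcast parameters and buffers from rank 0 once at construction time,
inside __init__, instead of lazily on the first forward pass. This removes
the first_synced flag and the per-call check in forward(), and renames
sync_params() to _sync_params() since it is no longer called from outside
the wrapper. Because the broadcast now happens in __init__, the wrapped
module must already be on GPU when the wrapper is created, so tools/train.py
moves the model to GPU before wrapping it:

    model = MMDistributedDataParallel(model.cuda())
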
---
 mmdet/nn/parallel/distributed.py | 7 ++-----
 tools/train.py                   | 2 +-
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/mmdet/nn/parallel/distributed.py b/mmdet/nn/parallel/distributed.py
index 1db0ea6..a2e1d55 100644
--- a/mmdet/nn/parallel/distributed.py
+++ b/mmdet/nn/parallel/distributed.py
@@ -15,8 +15,8 @@ class MMDistributedDataParallel(nn.Module):
         self.dim = dim
         self.broadcast_buffers = broadcast_buffers
 
-        self.first_synced = False
         self.broadcast_bucket_size = 32 * 1024 * 1024
+        self._sync_params()
 
     def _dist_broadcast_coalesced(self, tensors, buffer_size):
         for tensors in _take_tensors(tensors, buffer_size):
@@ -26,7 +26,7 @@ class MMDistributedDataParallel(nn.Module):
                     tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                 tensor.copy_(synced)
 
-    def sync_params(self):
+    def _sync_params(self):
         module_states = list(self.module.state_dict().values())
         if len(module_states) > 0:
             self._dist_broadcast_coalesced(module_states,
@@ -41,9 +41,6 @@ class MMDistributedDataParallel(nn.Module):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
     def forward(self, *inputs, **kwargs):
-        if not self.first_synced:
-            self.sync_params()
-            self.first_synced = True
         inputs, kwargs = self.scatter(inputs, kwargs,
                                       [torch.cuda.current_device()])
         return self.module(*inputs[0], **kwargs[0])
diff --git a/tools/train.py b/tools/train.py
index 8acb630..8fd4380 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -95,7 +95,7 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
     if dist:
-        model = MMDistributedDataParallel(model).cuda()
+        model = MMDistributedDataParallel(model.cuda())
     else:
         model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
 
-- 
GitLab