diff --git a/mmdet/nn/parallel/distributed.py b/mmdet/nn/parallel/distributed.py
index 1db0ea6d00f7a30eb9bf12e753530642756b7c21..a2e1d557b3edd5a306aa7abe687fd91bd42ab1e8 100644
--- a/mmdet/nn/parallel/distributed.py
+++ b/mmdet/nn/parallel/distributed.py
@@ -15,8 +15,8 @@ class MMDistributedDataParallel(nn.Module):
         self.dim = dim
         self.broadcast_buffers = broadcast_buffers
 
-        self.first_synced = False
         self.broadcast_bucket_size = 32 * 1024 * 1024
+        self._sync_params()
 
     def _dist_broadcast_coalesced(self, tensors, buffer_size):
         for tensors in _take_tensors(tensors, buffer_size):
@@ -26,7 +26,7 @@ class MMDistributedDataParallel(nn.Module):
                     tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                 tensor.copy_(synced)
 
-    def sync_params(self):
+    def _sync_params(self):
         module_states = list(self.module.state_dict().values())
         if len(module_states) > 0:
             self._dist_broadcast_coalesced(module_states,
@@ -41,9 +41,6 @@ class MMDistributedDataParallel(nn.Module):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
     def forward(self, *inputs, **kwargs):
-        if not self.first_synced:
-            self.sync_params()
-            self.first_synced = True
         inputs, kwargs = self.scatter(inputs, kwargs,
                                       [torch.cuda.current_device()])
         return self.module(*inputs[0], **kwargs[0])
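
Note on the hunks above: parameter synchronization moves out of the first forward() call (previously guarded by first_synced) and into the constructor, and the method becomes the private _sync_params(). Every rank therefore receives rank 0's parameters and buffers once, at construction time. Below is a minimal sketch of the coalesced-broadcast pattern this relies on, assuming torch.distributed has already been initialized for the process; the standalone name broadcast_coalesced is illustrative, not part of the module.

    # Sketch of the bucketed broadcast behind _sync_params(); assumes the
    # default process group is initialized and rank 0 holds the reference
    # weights.
    import torch.distributed as dist
    from torch._utils import (_flatten_dense_tensors, _take_tensors,
                              _unflatten_dense_tensors)

    def broadcast_coalesced(tensors, buffer_size=32 * 1024 * 1024):
        # Pack tensors into buckets of at most buffer_size bytes, broadcast
        # each bucket once from rank 0, then copy the synced values back
        # into the original tensors in place.
        for bucket in _take_tensors(tensors, buffer_size):
            flat = _flatten_dense_tensors(bucket)
            dist.broadcast(flat, 0)
            for tensor, synced in zip(
                    bucket, _unflatten_dense_tensors(flat, bucket)):
                tensor.copy_(synced)

Bucketing at broadcast_bucket_size (32 MiB) keeps the number of collective calls small without flattening the whole state dict into one oversized tensor.
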
diff --git a/tools/train.py b/tools/train.py
index 8acb63084968838753690f2891a388c861805152..8fd43807967fef6b17695158a4f67514b0a0ab5d 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -95,7 +95,7 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
     if dist:
-        model = MMDistributedDataParallel(model).cuda()
+        model = MMDistributedDataParallel(model.cuda())
     else:
         model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
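
Because _sync_params() now runs inside __init__, the wrapped module's tensors must already live on the GPU when the wrapper is constructed, which is why train.py moves .cuda() onto the model before wrapping instead of calling it on the wrapper afterwards. A minimal ordering sketch follows, assuming the process group is already initialized; the toy nn.Conv2d stands in for build_detector(...), and the import path is inferred from the file location shown in the diff.

    # Ordering sketch only, not the full train.py flow; assumes
    # dist.init_process_group(...) has already been called in this process.
    import torch.nn as nn
    from mmdet.nn.parallel.distributed import MMDistributedDataParallel

    model = nn.Conv2d(3, 16, 3)  # stand-in for build_detector(cfg.model, ...)
    # .cuda() first, then wrap: the constructor calls _sync_params(), which
    # broadcasts the (now CUDA) state dict from rank 0 to the other ranks.
    model = MMDistributedDataParallel(model.cuda())
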