Commit d04fa0f3 authored by Kai Chen

move _sync_param from forward to init

parent b7968de7
@@ -15,8 +15,8 @@ class MMDistributedDataParallel(nn.Module):
         self.dim = dim
         self.broadcast_buffers = broadcast_buffers
-        self.first_synced = False
         self.broadcast_bucket_size = 32 * 1024 * 1024
+        self._sync_params()

     def _dist_broadcast_coalesced(self, tensors, buffer_size):
         for tensors in _take_tensors(tensors, buffer_size):
@@ -26,7 +26,7 @@ class MMDistributedDataParallel(nn.Module):
                     tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                 tensor.copy_(synced)

-    def sync_params(self):
+    def _sync_params(self):
         module_states = list(self.module.state_dict().values())
         if len(module_states) > 0:
             self._dist_broadcast_coalesced(module_states,
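
For context, the two methods touched above roughly read as follows once this commit is applied. The imports and the dist.broadcast call are reconstructed from the surrounding file and should be read as a sketch, not as part of the diff itself.

import torch.distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)

    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        # Broadcast from rank 0 in buckets of at most `buffer_size` bytes.
        for tensors in _take_tensors(tensors, buffer_size):
            flat_tensors = _flatten_dense_tensors(tensors)
            dist.broadcast(flat_tensors, 0)
            for tensor, synced in zip(
                    tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                tensor.copy_(synced)

    def _sync_params(self):
        # Broadcast parameters and buffers from rank 0 so every process
        # starts from identical weights; now invoked once from __init__.
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)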
@@ -41,9 +41,6 @@ class MMDistributedDataParallel(nn.Module):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

     def forward(self, *inputs, **kwargs):
-        if not self.first_synced:
-            self.sync_params()
-            self.first_synced = True
         inputs, kwargs = self.scatter(inputs, kwargs,
                                       [torch.cuda.current_device()])
         return self.module(*inputs[0], **kwargs[0])
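
The net effect of the hunks above, shown as a condensed sketch of the class (the __init__ signature is assumed from the surrounding context, not quoted from the diff): parameter syncing moves from the first forward() call into construction time.

class MMDistributedDataParallel(nn.Module):

    def __init__(self, module, dim=0, broadcast_buffers=True):
        super(MMDistributedDataParallel, self).__init__()
        self.module = module
        self.dim = dim
        self.broadcast_buffers = broadcast_buffers
        self.broadcast_bucket_size = 32 * 1024 * 1024
        # Previously guarded by self.first_synced inside forward(); now the
        # broadcast happens exactly once, when the wrapper is built.
        self._sync_params()

    def forward(self, *inputs, **kwargs):
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])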
@@ -95,7 +95,7 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
     if dist:
-        model = MMDistributedDataParallel(model).cuda()
+        model = MMDistributedDataParallel(model.cuda())
     else:
         model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
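
The train-script change follows from moving the sync into __init__: the broadcast now runs while the wrapper is being constructed, and the NCCL backend can only broadcast CUDA tensors, so the model has to be on the GPU before it is wrapped. A minimal usage sketch, assuming a torchrun/torch.distributed.launch style launcher that sets LOCAL_RANK (the environment handling here is illustrative, not taken from this commit):

import os

import torch
import torch.distributed as dist

local_rank = int(os.environ.get('LOCAL_RANK', 0))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')

model = build_detector(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
# Move to GPU *before* wrapping: __init__ now broadcasts the state dict.
model = MMDistributedDataParallel(model.cuda())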