Commit 82356fd9 authored by Kai Chen

support chunk when reducing grads

parent 3d2b79bd
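Previously, the coalesced all-reduce flattened every gradient of the same type into one buffer before calling dist.all_reduce. This commit adds an optional bucket_size_mb argument: when it is positive, the gradients are first chunked into size-limited buckets (via torch._utils._take_tensors) so that each all_reduce operates on a bounded flat buffer instead of one potentially very large allocation. The sketch below only illustrates that chunking idea; chunk_by_size is a hypothetical stand-in, not the torch utility used in the diff.

# Illustrative sketch only: a simplified, hypothetical equivalent of the
# size-limited bucketing that torch._utils._take_tensors provides.
import torch


def chunk_by_size(tensors, bucket_size_bytes):
    """Group tensors of the same type into buckets of bounded total size."""
    buckets = []
    current, current_bytes, current_type = [], 0, None
    for t in tensors:
        t_bytes = t.numel() * t.element_size()
        # start a new bucket when the dtype changes or the size cap is hit
        if current and (t.type() != current_type
                        or current_bytes + t_bytes > bucket_size_bytes):
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(t)
        current_bytes += t_bytes
        current_type = t.type()
    if current:
        buckets.append(current)
    return buckets


if __name__ == '__main__':
    grads = [torch.randn(1024, 1024) for _ in range(8)]  # ~4 MB each (fp32)
    for i, bucket in enumerate(chunk_by_size(grads, 10 * 1024 * 1024)):
        size_mb = sum(t.numel() * t.element_size() for t in bucket) / 1024**2
        print(f'bucket {i}: {len(bucket)} tensors, {size_mb:.1f} MB')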
-from .dist_utils import (init_dist, reduce_grads, DistOptimizerHook,
-                         DistSamplerSeedHook)
+from .dist_utils import init_dist, allreduce_grads, DistOptimizerHook
 from .misc import tensor2imgs, unmap, multi_apply

 __all__ = [
-    'init_dist', 'reduce_grads', 'DistOptimizerHook', 'DistSamplerSeedHook',
-    'tensor2imgs', 'unmap', 'multi_apply'
+    'init_dist', 'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs',
+    'unmap', 'multi_apply'
 ]
@@ -4,9 +4,9 @@ from collections import OrderedDict
 import torch
 import torch.multiprocessing as mp
 import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from torch.nn.utils import clip_grad
-from mmcv.runner import Hook, OptimizerHook
+from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
+                          _take_tensors)
+from mmcv.runner import OptimizerHook
@@ -38,59 +38,52 @@ def _init_dist_slurm(backend, **kwargs):
     raise NotImplementedError


 # modified from
 # https://github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py#L9
-def all_reduce_coalesced(tensors):
-    buckets = OrderedDict()
-    for tensor in tensors:
-        tp = tensor.type()
-        if tp not in buckets:
-            buckets[tp] = []
-        buckets[tp].append(tensor)
-
-    world_size = dist.get_world_size()
-    for tp in buckets:
-        bucket = buckets[tp]
-        coalesced = _flatten_dense_tensors(bucket)
-        dist.all_reduce(coalesced)
-        coalesced.div_(world_size)
-        for buf, synced in zip(bucket,
-                               _unflatten_dense_tensors(coalesced, bucket)):
-            buf.copy_(synced)
-
-
-def reduce_grads(model, coalesce=True):
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+    if bucket_size_mb > 0:
+        bucket_size_bytes = bucket_size_mb * 1024 * 1024
+        buckets = _take_tensors(tensors, bucket_size_bytes)
+    else:
+        buckets = OrderedDict()
+        for tensor in tensors:
+            tp = tensor.type()
+            if tp not in buckets:
+                buckets[tp] = []
+            buckets[tp].append(tensor)
+        buckets = buckets.values()
+
+    for bucket in buckets:
+        flat_tensors = _flatten_dense_tensors(bucket)
+        dist.all_reduce(flat_tensors)
+        flat_tensors.div_(world_size)
+        for tensor, synced in zip(
+                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+            tensor.copy_(synced)
+
+
+def allreduce_grads(model, coalesce=True, bucket_size_mb=-1):
     grads = [
         param.grad.data for param in model.parameters()
         if param.requires_grad and param.grad is not None
     ]
+    world_size = dist.get_world_size()
     if coalesce:
-        all_reduce_coalesced(grads)
+        _allreduce_coalesced(grads, world_size, bucket_size_mb)
     else:
-        world_size = dist.get_world_size()
         for tensor in grads:
            dist.all_reduce(tensor.div_(world_size))


 class DistOptimizerHook(OptimizerHook):

-    def __init__(self, grad_clip=None, coalesce=True):
+    def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1):
         self.grad_clip = grad_clip
         self.coalesce = coalesce
+        self.bucket_size_mb = bucket_size_mb

     def after_train_iter(self, runner):
         runner.optimizer.zero_grad()
         runner.outputs['loss'].backward()
-        reduce_grads(runner.model, self.coalesce)
+        allreduce_grads(runner.model, self.coalesce, self.bucket_size_mb)
         if self.grad_clip is not None:
-            clip_grad.clip_grad_norm_(
-                filter(lambda p: p.requires_grad, runner.model.parameters()),
-                **self.grad_clip)
+            self.clip_grads(runner.model.parameters())
         runner.optimizer.step()
-
-
-class DistSamplerSeedHook(Hook):
-
-    def before_epoch(self, runner):
-        runner.data_loader.sampler.set_epoch(runner.epoch)
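For context, a minimal usage sketch of the updated interface follows. The function and class names are taken from the diff above; the import path, the clipping settings, and the 512 MB bucket size are assumptions, and bucket_size_mb=-1 (the default) falls back to the previous per-type bucketing.

# Minimal sketch, assuming distributed training has already been set up via
# init_dist and that these helpers are re-exported from mmdet.core (assumption).
from mmdet.core import DistOptimizerHook, allreduce_grads


def train_step(model, optimizer, data_batch):
    optimizer.zero_grad()
    loss = model(**data_batch)  # placeholder forward pass returning a loss
    loss.backward()
    # average gradients across ranks, coalescing them into <=512 MB buckets
    allreduce_grads(model, coalesce=True, bucket_size_mb=512)
    optimizer.step()


# Hook-based equivalent: the runner runs the same steps in after_train_iter.
optimizer_hook = DistOptimizerHook(
    grad_clip=dict(max_norm=35, norm_type=2),  # example clipping config
    coalesce=True,
    bucket_size_mb=512)  # -1 (default) keeps the previous behaviour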