From 8bf38df0b47094528c352f0c27a84b2c6531d57a Mon Sep 17 00:00:00 2001
From: Qiang Zhang <zhangtemplar@users.noreply.github.com>
Date: Thu, 27 Jun 2019 09:37:48 -0700
Subject: [PATCH] Only import torch.distributed when needed (#882)

* Fix an import error for `get_world_size` and `get_rank`

* Only import torch.distributed when needed

torch.distributed is only used in DistributedGroupSampler

* Use `get_dist_info` to obtain the world size and rank

`get_dist_info` from `mmcv.runner.utils` handles the case where
`distributed_c10d` does not exist.
---
 mmdet/datasets/loader/sampler.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mmdet/datasets/loader/sampler.py b/mmdet/datasets/loader/sampler.py
index 1e454b7..f45ba13 100644
--- a/mmdet/datasets/loader/sampler.py
+++ b/mmdet/datasets/loader/sampler.py
@@ -4,7 +4,7 @@ import math
 import torch
 import numpy as np
-from torch.distributed import get_world_size, get_rank
+from mmcv.runner.utils import get_dist_info
 from torch.utils.data import Sampler
 from torch.utils.data import DistributedSampler as _DistributedSampler
 
@@ -95,10 +95,11 @@ class DistributedGroupSampler(Sampler):
                  samples_per_gpu=1,
                  num_replicas=None,
                  rank=None):
+        _rank, _num_replicas = get_dist_info()
         if num_replicas is None:
-            num_replicas = get_world_size()
+            num_replicas = _num_replicas
         if rank is None:
-            rank = get_rank()
+            rank = _rank
         self.dataset = dataset
         self.samples_per_gpu = samples_per_gpu
         self.num_replicas = num_replicas
-- 
GitLab
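For context on why the patch works, below is a minimal sketch of the behaviour the commit message attributes to `get_dist_info`: it reports `(rank, world_size)` without importing `get_rank`/`get_world_size` at module load time, so it degrades gracefully when `torch.distributed.distributed_c10d` is unavailable or no process group has been initialized. This is an illustration of the idea, not mmcv's actual implementation, and the fallback values shown are assumptions.

```python
import torch.distributed as dist


def get_dist_info():
    """Sketch of a safe (rank, world_size) query.

    Not mmcv's exact code; shown only to illustrate why calling this
    helper is safer than importing get_rank/get_world_size directly.
    """
    # is_available() is False when PyTorch was built without distributed
    # support (the missing `distributed_c10d` case); is_initialized() is
    # False when no process group has been set up yet.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    # Single-process fallback: rank 0 in a world of size 1.
    return 0, 1
```

With such a helper, `DistributedGroupSampler.__init__` can fill in `rank` and `num_replicas` lazily, as the hunk above does, instead of requiring `torch.distributed` symbols at import time.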