From 8bf38df0b47094528c352f0c27a84b2c6531d57a Mon Sep 17 00:00:00 2001
From: Qiang Zhang <zhangtemplar@users.noreply.github.com>
Date: Thu, 27 Jun 2019 09:37:48 -0700
Subject: [PATCH] Only import torch.distributed when needed (#882)

* Fix an import error for `get_world_size` and `get_rank`

* Only import torch.distributed when needed

torch.distributed is only used in DistributedGroupSampler

* use `get_dist_info` to obtain world size and rank

`get_dist_info` from `mmcv.runner.utils` handles the problem of `distributed_c10d` doesn't exist.
---
 mmdet/datasets/loader/sampler.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mmdet/datasets/loader/sampler.py b/mmdet/datasets/loader/sampler.py
index 1e454b7..f45ba13 100644
--- a/mmdet/datasets/loader/sampler.py
+++ b/mmdet/datasets/loader/sampler.py
@@ -4,7 +4,7 @@ import math
 import torch
 import numpy as np
 
-from torch.distributed import get_world_size, get_rank
+from mmcv.runner.utils import get_dist_info
 from torch.utils.data import Sampler
 from torch.utils.data import DistributedSampler as _DistributedSampler
 
@@ -95,10 +95,11 @@ class DistributedGroupSampler(Sampler):
                  samples_per_gpu=1,
                  num_replicas=None,
                  rank=None):
+        _rank, _num_replicas = get_dist_info()
         if num_replicas is None:
-            num_replicas = get_world_size()
+            num_replicas = _num_replicas
         if rank is None:
-            rank = get_rank()
+            rank = _rank
         self.dataset = dataset
         self.samples_per_gpu = samples_per_gpu
         self.num_replicas = num_replicas
-- 
GitLab