From 7ef08d32c0e2f8585b07423c9e027338ca16486f Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Sun, 22 Dec 2019 21:15:16 +0800
Subject: [PATCH] use mmcv.init_dist (#1851)
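
Remove mmdet/apis/env.py and use mmcv.runner.init_dist instead of the local
helper. The mmcv version keeps the same launcher/backend signature as the
deleted function, so callers only need to change the import:

    from mmcv.runner import init_dist

set_random_seed and get_root_logger now live in mmdet/apis/train.py and are
still exported from mmdet.apis. The mmcv requirement is bumped to >=0.2.15
accordingly, requirements.txt is sorted alphabetically, and the pre-commit
yapf and pre-commit-hooks revisions are updated.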

---
 .pre-commit-config.yaml  |  4 +--
 mmdet/apis/__init__.py   |  9 +++---
 mmdet/apis/env.py        | 69 ----------------------------------------
 mmdet/apis/train.py      | 27 ++++++++++++++--
 requirements.txt         | 10 +++---
 tools/test.py            |  3 +-
 tools/test_robustness.py |  4 +--
 tools/train.py           |  4 +--
 8 files changed, 40 insertions(+), 90 deletions(-)
 delete mode 100644 mmdet/apis/env.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2fae06c..901104c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,11 +8,11 @@ repos:
   hooks:
       - id: isort
 - repo: https://github.com/pre-commit/mirrors-yapf
-  rev: 80b9cd2f0f3b1f3456a77eff3ddbaf08f18c08ae
+  rev: v0.29.0
   hooks:
     - id: yapf
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v2.3.0
+  rev: v2.4.0
   hooks:
     - id: flake8
     - id: trailing-whitespace
diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
index 4cdf847..914307a 100644
--- a/mmdet/apis/__init__.py
+++ b/mmdet/apis/__init__.py
@@ -1,10 +1,9 @@
-from .env import get_root_logger, init_dist, set_random_seed
 from .inference import (async_inference_detector, inference_detector,
                         init_detector, show_result, show_result_pyplot)
-from .train import train_detector
+from .train import get_root_logger, set_random_seed, train_detector
 
 __all__ = [
-    'async_inference_detector', 'init_dist', 'get_root_logger',
-    'set_random_seed', 'train_detector', 'init_detector', 'inference_detector',
-    'show_result', 'show_result_pyplot'
+    'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
+    'async_inference_detector', 'inference_detector', 'show_result',
+    'show_result_pyplot'
 ]
diff --git a/mmdet/apis/env.py b/mmdet/apis/env.py
deleted file mode 100644
index 19b0f86..0000000
--- a/mmdet/apis/env.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import logging
-import os
-import random
-import subprocess
-
-import numpy as np
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-from mmcv.runner import get_dist_info
-
-
-def init_dist(launcher, backend='nccl', **kwargs):
-    if mp.get_start_method(allow_none=True) is None:
-        mp.set_start_method('spawn')
-    if launcher == 'pytorch':
-        _init_dist_pytorch(backend, **kwargs)
-    elif launcher == 'mpi':
-        _init_dist_mpi(backend, **kwargs)
-    elif launcher == 'slurm':
-        _init_dist_slurm(backend, **kwargs)
-    else:
-        raise ValueError('Invalid launcher type: {}'.format(launcher))
-
-
-def _init_dist_pytorch(backend, **kwargs):
-    # TODO: use local_rank instead of rank % num_gpus
-    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    dist.init_process_group(backend=backend, **kwargs)
-
-
-def _init_dist_mpi(backend, **kwargs):
-    raise NotImplementedError
-
-
-def _init_dist_slurm(backend, port=29500, **kwargs):
-    proc_id = int(os.environ['SLURM_PROCID'])
-    ntasks = int(os.environ['SLURM_NTASKS'])
-    node_list = os.environ['SLURM_NODELIST']
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(proc_id % num_gpus)
-    addr = subprocess.getoutput(
-        'scontrol show hostname {} | head -n1'.format(node_list))
-    os.environ['MASTER_PORT'] = str(port)
-    os.environ['MASTER_ADDR'] = addr
-    os.environ['WORLD_SIZE'] = str(ntasks)
-    os.environ['RANK'] = str(proc_id)
-    dist.init_process_group(backend=backend)
-
-
-def set_random_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-
-
-def get_root_logger(log_level=logging.INFO):
-    logger = logging.getLogger()
-    if not logger.hasHandlers():
-        logging.basicConfig(
-            format='%(asctime)s - %(levelname)s - %(message)s',
-            level=log_level)
-    rank, _ = get_dist_info()
-    if rank != 0:
-        logger.setLevel('ERROR')
-    return logger
diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index cdac16d..47320c6 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -1,18 +1,39 @@
-from __future__ import division
+import logging
+import random
 import re
 from collections import OrderedDict
 
+import numpy as np
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import DistSamplerSeedHook, Runner, obj_from_dict
+from mmcv.runner import (DistSamplerSeedHook, Runner, get_dist_info,
+                         obj_from_dict)
 
 from mmdet import datasets
 from mmdet.core import (CocoDistEvalmAPHook, CocoDistEvalRecallHook,
                         DistEvalmAPHook, DistOptimizerHook, Fp16OptimizerHook)
 from mmdet.datasets import DATASETS, build_dataloader
 from mmdet.models import RPN
-from .env import get_root_logger
+
+
+def set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def get_root_logger(log_level=logging.INFO):
+    logger = logging.getLogger()
+    if not logger.hasHandlers():
+        logging.basicConfig(
+            format='%(asctime)s - %(levelname)s - %(message)s',
+            level=log_level)
+    rank, _ = get_dist_info()
+    if rank != 0:
+        logger.setLevel('ERROR')
+    return logger
 
 
 def parse_losses(losses):
diff --git a/requirements.txt b/requirements.txt
index 8a68f41..5cacde1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
-mmcv>=0.2.10
-numpy
+albumentations>=0.3.2
+imagecorruptions
 matplotlib
+mmcv>=0.2.15
+numpy
+pycocotools
 six
 terminaltables
-pycocotools
 torch>=1.1
 torchvision
-imagecorruptions
-albumentations>=0.3.2
\ No newline at end of file
diff --git a/tools/test.py b/tools/test.py
index 64dd733..b39cf13 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -9,9 +9,8 @@ import mmcv
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import get_dist_info, load_checkpoint
+from mmcv.runner import get_dist_info, init_dist, load_checkpoint
 
-from mmdet.apis import init_dist
 from mmdet.core import coco_eval, results2json, wrap_fp16_model
 from mmdet.datasets import build_dataloader, build_dataset
 from mmdet.models import build_detector
diff --git a/tools/test_robustness.py b/tools/test_robustness.py
index c0489f3..fb58deb 100644
--- a/tools/test_robustness.py
+++ b/tools/test_robustness.py
@@ -10,13 +10,13 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import get_dist_info, load_checkpoint
+from mmcv.runner import get_dist_info, init_dist, load_checkpoint
 from pycocotools.coco import COCO
 from pycocotools.cocoeval import COCOeval
 from robustness_eval import get_results
 
 from mmdet import datasets
-from mmdet.apis import init_dist, set_random_seed
+from mmdet.apis import set_random_seed
 from mmdet.core import (eval_map, fast_eval_recall, results2json,
                         wrap_fp16_model)
 from mmdet.datasets import build_dataloader, build_dataset
diff --git a/tools/train.py b/tools/train.py
index c939343..e3bbbde 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -4,10 +4,10 @@ import os
 
 import torch
 from mmcv import Config
+from mmcv.runner import init_dist
 
 from mmdet import __version__
-from mmdet.apis import (get_root_logger, init_dist, set_random_seed,
-                        train_detector)
+from mmdet.apis import get_root_logger, set_random_seed, train_detector
 from mmdet.datasets import build_dataset
 from mmdet.models import build_detector
 
-- 
GitLab