From 7ef08d32c0e2f8585b07423c9e027338ca16486f Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Sun, 22 Dec 2019 21:15:16 +0800
Subject: [PATCH] use mmcv.init_dist (#1851)

---
 .pre-commit-config.yaml  |  4 +--
 mmdet/apis/__init__.py   |  9 +++---
 mmdet/apis/env.py        | 69 ----------------------------------------
 mmdet/apis/train.py      | 27 ++++++++++++++--
 requirements.txt         | 10 +++---
 tools/test.py            |  3 +-
 tools/test_robustness.py |  4 +--
 tools/train.py           |  4 +--
 8 files changed, 40 insertions(+), 90 deletions(-)
 delete mode 100644 mmdet/apis/env.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2fae06c..901104c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,11 +8,11 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/pre-commit/mirrors-yapf
-    rev: 80b9cd2f0f3b1f3456a77eff3ddbaf08f18c08ae
+    rev: v0.29.0
     hooks:
       - id: yapf
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v2.3.0
+    rev: v2.4.0
     hooks:
       - id: flake8
       - id: trailing-whitespace
diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
index 4cdf847..914307a 100644
--- a/mmdet/apis/__init__.py
+++ b/mmdet/apis/__init__.py
@@ -1,10 +1,9 @@
-from .env import get_root_logger, init_dist, set_random_seed
 from .inference import (async_inference_detector, inference_detector,
                         init_detector, show_result, show_result_pyplot)
-from .train import train_detector
+from .train import get_root_logger, set_random_seed, train_detector
 
 __all__ = [
-    'async_inference_detector', 'init_dist', 'get_root_logger',
-    'set_random_seed', 'train_detector', 'init_detector', 'inference_detector',
-    'show_result', 'show_result_pyplot'
+    'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
+    'async_inference_detector', 'inference_detector', 'show_result',
+    'show_result_pyplot'
 ]
diff --git a/mmdet/apis/env.py b/mmdet/apis/env.py
deleted file mode 100644
index 19b0f86..0000000
--- a/mmdet/apis/env.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import logging
-import os
-import random
-import subprocess
-
-import numpy as np
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-from mmcv.runner import get_dist_info
-
-
-def init_dist(launcher, backend='nccl', **kwargs):
-    if mp.get_start_method(allow_none=True) is None:
-        mp.set_start_method('spawn')
-    if launcher == 'pytorch':
-        _init_dist_pytorch(backend, **kwargs)
-    elif launcher == 'mpi':
-        _init_dist_mpi(backend, **kwargs)
-    elif launcher == 'slurm':
-        _init_dist_slurm(backend, **kwargs)
-    else:
-        raise ValueError('Invalid launcher type: {}'.format(launcher))
-
-
-def _init_dist_pytorch(backend, **kwargs):
-    # TODO: use local_rank instead of rank % num_gpus
-    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    dist.init_process_group(backend=backend, **kwargs)
-
-
-def _init_dist_mpi(backend, **kwargs):
-    raise NotImplementedError
-
-
-def _init_dist_slurm(backend, port=29500, **kwargs):
-    proc_id = int(os.environ['SLURM_PROCID'])
-    ntasks = int(os.environ['SLURM_NTASKS'])
-    node_list = os.environ['SLURM_NODELIST']
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(proc_id % num_gpus)
-    addr = subprocess.getoutput(
-        'scontrol show hostname {} | head -n1'.format(node_list))
-    os.environ['MASTER_PORT'] = str(port)
-    os.environ['MASTER_ADDR'] = addr
-    os.environ['WORLD_SIZE'] = str(ntasks)
-    os.environ['RANK'] = str(proc_id)
-    dist.init_process_group(backend=backend)
-
-
-def set_random_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-
-
-def get_root_logger(log_level=logging.INFO):
-    logger = logging.getLogger()
-    if not logger.hasHandlers():
-        logging.basicConfig(
-            format='%(asctime)s - %(levelname)s - %(message)s',
-            level=log_level)
-    rank, _ = get_dist_info()
-    if rank != 0:
-        logger.setLevel('ERROR')
-    return logger
diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index cdac16d..47320c6 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -1,18 +1,39 @@
-from __future__ import division
+import logging
+import random
 import re
 from collections import OrderedDict
 
+import numpy as np
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import DistSamplerSeedHook, Runner, obj_from_dict
+from mmcv.runner import (DistSamplerSeedHook, Runner, get_dist_info,
+                         obj_from_dict)
 
 from mmdet import datasets
 from mmdet.core import (CocoDistEvalmAPHook, CocoDistEvalRecallHook,
                         DistEvalmAPHook, DistOptimizerHook, Fp16OptimizerHook)
 from mmdet.datasets import DATASETS, build_dataloader
 from mmdet.models import RPN
-from .env import get_root_logger
+
+
+def set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def get_root_logger(log_level=logging.INFO):
+    logger = logging.getLogger()
+    if not logger.hasHandlers():
+        logging.basicConfig(
+            format='%(asctime)s - %(levelname)s - %(message)s',
+            level=log_level)
+    rank, _ = get_dist_info()
+    if rank != 0:
+        logger.setLevel('ERROR')
+    return logger
 
 
 def parse_losses(losses):
diff --git a/requirements.txt b/requirements.txt
index 8a68f41..5cacde1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
-mmcv>=0.2.10
-numpy
+albumentations>=0.3.2
+imagecorruptions
 matplotlib
+mmcv>=0.2.15
+numpy
+pycocotools
 six
 terminaltables
-pycocotools
 torch>=1.1
 torchvision
-imagecorruptions
-albumentations>=0.3.2
\ No newline at end of file
diff --git a/tools/test.py b/tools/test.py
index 64dd733..b39cf13 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -9,9 +9,8 @@ import mmcv
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import get_dist_info, load_checkpoint
+from mmcv.runner import get_dist_info, init_dist, load_checkpoint
 
-from mmdet.apis import init_dist
 from mmdet.core import coco_eval, results2json, wrap_fp16_model
 from mmdet.datasets import build_dataloader, build_dataset
 from mmdet.models import build_detector
diff --git a/tools/test_robustness.py b/tools/test_robustness.py
index c0489f3..fb58deb 100644
--- a/tools/test_robustness.py
+++ b/tools/test_robustness.py
@@ -10,13 +10,13 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import get_dist_info, load_checkpoint
+from mmcv.runner import get_dist_info, init_dist, load_checkpoint
 from pycocotools.coco import COCO
 from pycocotools.cocoeval import COCOeval
 from robustness_eval import get_results
 
 from mmdet import datasets
-from mmdet.apis import init_dist, set_random_seed
+from mmdet.apis import set_random_seed
 from mmdet.core import (eval_map, fast_eval_recall, results2json,
                         wrap_fp16_model)
 from mmdet.datasets import build_dataloader, build_dataset
diff --git a/tools/train.py b/tools/train.py
index c939343..e3bbbde 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -4,10 +4,10 @@ import os
 
 import torch
 from mmcv import Config
+from mmcv.runner import init_dist
 
 from mmdet import __version__
-from mmdet.apis import (get_root_logger, init_dist, set_random_seed,
-                        train_detector)
+from mmdet.apis import get_root_logger, set_random_seed, train_detector
 from mmdet.datasets import build_dataset
 from mmdet.models import build_detector
 
-- 
GitLab
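
For reference, a minimal usage sketch of what a caller looks like after this patch (illustrative only, not part of the diff): distributed setup now comes from mmcv.runner, while get_root_logger and set_random_seed stay in mmdet.apis (moved into mmdet/apis/train.py). This assumes mmcv>=0.2.15 (as pinned in requirements.txt above) provides init_dist with the same signature as the helper removed from mmdet/apis/env.py.

# Illustrative sketch, not part of the patch: wiring a training script
# after this change. Assumes mmcv.runner.init_dist(launcher, backend='nccl',
# **kwargs) behaves like the removed mmdet/apis/env.py helper.
from mmcv.runner import init_dist

from mmdet.apis import get_root_logger, set_random_seed

# launcher is one of 'pytorch', 'slurm' or 'mpi'; the 'pytorch' launcher
# reads the RANK environment variable set by torch.distributed.launch.
init_dist('pytorch', backend='nccl')
set_random_seed(0)
logger = get_root_logger()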