Commit c76a95ec authored by Kai Chen

reorganize the training api

parent 2a43cc7d
from .env import init_dist, get_root_logger, set_random_seed
from .train import train_detector
from .inference import inference_detector
__all__ = ['train_detector', 'inference_detector']
__all__ = [
    'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
    'inference_detector'
]
import logging
import os
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from mmcv.runner import get_dist_info
def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError('Invalid launcher type: {}'.format(launcher))


def _init_dist_pytorch(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend, **kwargs):
    raise NotImplementedError


def _init_dist_slurm(backend, **kwargs):
    raise NotImplementedError


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_root_logger(log_level=logging.INFO):
    logger = logging.getLogger()
    if not logger.hasHandlers():
        logging.basicConfig(
            format='%(asctime)s - %(levelname)s - %(message)s',
            level=log_level)
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    return logger
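
For orientation, a minimal usage sketch of the helpers above (not part of the commit; the launcher, backend and seed values are placeholders and mirror what tools/train.py does further down):

# --- usage sketch only, not part of the diff ---
import logging
from mmdet.api import init_dist, get_root_logger, set_random_seed

launcher = 'pytorch'                     # 'none' would skip the distributed setup
if launcher != 'none':
    init_dist(launcher, backend='nccl')  # sets the CUDA device, then init_process_group
logger = get_root_logger(logging.INFO)   # ranks other than 0 are silenced to ERROR
set_random_seed(42)                      # seeds python, numpy and torch (all GPUs)
# --- end of sketch ---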
from __future__ import division
import logging
import random
from collections import OrderedDict
@@ -9,11 +8,11 @@ import torch
from mmcv.runner import Runner, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmdet import __version__
from mmdet.core import (init_dist, DistOptimizerHook, CocoDistEvalRecallHook,
from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook,
                        CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader
from mmdet.models import RPN
from .env import get_root_logger
def parse_losses(losses):
@@ -46,13 +45,6 @@ def batch_processor(model, data, train_mode):
    return outputs
def get_logger(log_level):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
    logger = logging.getLogger()
    return logger
def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
@@ -60,58 +52,72 @@ def set_random_seed(seed):
    torch.cuda.manual_seed_all(seed)
def train_detector(model, dataset, cfg):
    # save mmdet version in checkpoint as meta data
    cfg.checkpoint_config.meta = dict(
        mmdet_version=__version__, config=cfg.text)
    logger = get_logger(cfg.log_level)
    # set random seed if specified
    if cfg.seed is not None:
        logger.info('Set random seed to {}'.format(cfg.seed))
        set_random_seed(cfg.seed)
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   logger=None):
    if logger is None:
        logger = get_root_logger(cfg.log_level)
    # init distributed environment if necessary
    if cfg.launcher == 'none':
        dist = False
        logger.info('Non-distributed training.')
    # start training
    if distributed:
        _dist_train(model, dataset, cfg, validate=validate)
    else:
        dist = True
        init_dist(cfg.launcher, **cfg.dist_params)
        if torch.distributed.get_rank() != 0:
            logger.setLevel('ERROR')
        logger.info('Distributed training.')
        _non_dist_train(model, dataset, cfg, validate=validate)
def _dist_train(model, dataset, cfg, validate=False):
    # prepare data loaders
    data_loaders = [
        build_dataloader(dataset, cfg.data.imgs_per_gpu,
                         cfg.data.workers_per_gpu, cfg.gpus, dist)
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            dist=True)
    ]
    # put model on gpus
    if dist:
        model = MMDistributedDataParallel(model.cuda())
    else:
        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
    model = MMDistributedDataParallel(model.cuda())
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    # register hooks
    optimizer_config = DistOptimizerHook(
        **cfg.optimizer_config) if dist else cfg.optimizer_config
    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    if dist:
        runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if cfg.validate:
        if isinstance(model.module, RPN):
            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
        elif cfg.data.val.type == 'CocoDataset':
            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
    runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        if isinstance(model.module, RPN):
            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
        elif cfg.data.val.type == 'CocoDataset':
            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
def _non_dist_train(model, dataset, cfg, validate=False):
    # prepare data loaders
    data_loaders = [
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False)
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
......
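
Taken together, a rough usage sketch of the reorganized entry point (not part of the diff; the setup mirrors what tools/train.py does below, and cfg is assumed to be an mmcv Config with the usual mmdetection fields):

# --- usage sketch only, not part of the diff ---
from mmcv import Config
from mmcv.runner import obj_from_dict
from mmdet import datasets
from mmdet.api import train_detector, get_root_logger
from mmdet.models import build_detector

cfg = Config.fromfile('configs/some_config.py')   # hypothetical config path
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
train_dataset = obj_from_dict(cfg.data.train, datasets)
train_detector(
    model,
    train_dataset,
    cfg,
    distributed=False,                 # True dispatches to _dist_train() instead
    validate=True,                     # registers the COCO eval hooks
    logger=get_root_logger(cfg.log_level))
# --- end of sketch ---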
from .dist_utils import init_dist, allreduce_grads, DistOptimizerHook
from .dist_utils import allreduce_grads, DistOptimizerHook
from .misc import tensor2imgs, unmap, multi_apply
__all__ = [
    'init_dist', 'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs',
    'unmap', 'multi_apply'
    'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs', 'unmap',
    'multi_apply'
]
import os
from collections import OrderedDict
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
                           _take_tensors)
from mmcv.runner import OptimizerHook
def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError('Invalid launcher type: {}'.format(launcher))


def _init_dist_pytorch(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend, **kwargs):
    raise NotImplementedError


def _init_dist_slurm(backend, **kwargs):
    raise NotImplementedError


def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
......
@@ -15,7 +15,7 @@ resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
def build_dataloader(dataset,
                     imgs_per_gpu,
                     workers_per_gpu,
                     num_gpus,
                     num_gpus=1,
                     dist=True,
                     **kwargs):
    if dist:
......
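
With num_gpus now optional, the two call styles used in train.py above look roughly like this (a sketch only; dataset and cfg are placeholders):

# --- usage sketch only, not part of the diff ---
# Distributed: one process per GPU, so num_gpus can stay at its default of 1.
loader = build_dataloader(
    dataset, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True)
# Non-distributed: a single process drives cfg.gpus GPUs via MMDataParallel.
loader = build_dataloader(
    dataset, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, cfg.gpus, dist=False)
# --- end of sketch ---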
@@ -4,8 +4,9 @@ import argparse
from mmcv import Config
from mmcv.runner import obj_from_dict
from mmdet import datasets
from mmdet.api import train_detector
from mmdet import datasets, __version__
from mmdet.api import (train_detector, init_dist, get_root_logger,
                       set_random_seed)
from mmdet.models import build_detector
@@ -16,10 +17,14 @@ def parse_args():
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to add a validate phase')
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--gpus', type=int, default=1, help='number of gpus to use')
    parser.add_argument('--seed', type=int, help='random seed')
        '--gpus',
        type=int,
        default=1,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
@@ -33,19 +38,43 @@ def parse_args():
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    cfg.validate = args.validate
    cfg.gpus = args.gpus
    cfg.seed = args.seed
    cfg.launcher = args.launcher
    cfg.local_rank = args.local_rank
    # build model
    if cfg.checkpoint_config is not None:
        # save mmdet version in checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__, config=cfg.text)
    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))
    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    train_dataset = obj_from_dict(cfg.data.train, datasets)
    train_detector(model, train_dataset, cfg)
    train_detector(
        model,
        train_dataset,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)
if __name__ == '__main__':
......
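
For reference, after this change a single-machine, non-distributed run would presumably be launched as "python tools/train.py <config> --gpus 2 --validate", while a distributed run would pass "--launcher pytorch" and go through the PyTorch launch utility (python -m torch.distributed.launch), so that init_dist() can read RANK from the environment. These commands are illustrative only and are not part of the commit.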