From 4c1da63619fa5cc9126187767255d1c1f974e522 Mon Sep 17 00:00:00 2001
From: myownskyW7 <727032989@qq.com>
Date: Thu, 11 Oct 2018 23:30:00 +0800
Subject: [PATCH] add high level api

---
 mmdet/api/__init__.py  |   4 ++
 mmdet/api/inference.py |  54 ++++++++++++++++++
 mmdet/api/train.py     | 120 +++++++++++++++++++++++++++++++++++++++
 tools/train.py         | 125 ++++-------------------------------------
 4 files changed, 188 insertions(+), 115 deletions(-)
 create mode 100644 mmdet/api/__init__.py
 create mode 100644 mmdet/api/inference.py
 create mode 100644 mmdet/api/train.py

diff --git a/mmdet/api/__init__.py b/mmdet/api/__init__.py
new file mode 100644
index 0000000..970492f
--- /dev/null
+++ b/mmdet/api/__init__.py
@@ -0,0 +1,4 @@
+from .train import train_detector
+from .inference import inference_detector
+
+__all__ = ['train_detector', 'inference_detector']
diff --git a/mmdet/api/inference.py b/mmdet/api/inference.py
new file mode 100644
index 0000000..47b7de3
--- /dev/null
+++ b/mmdet/api/inference.py
@@ -0,0 +1,54 @@
+import mmcv
+import numpy as np
+import torch
+
+from mmdet.datasets import to_tensor
+from mmdet.datasets.transforms import ImageTransform
+from mmdet.core import get_classes
+
+
+def _prepare_data(img, img_transform, cfg, device):
+    ori_shape = img.shape
+    img, img_shape, pad_shape, scale_factor = img_transform(
+        img, scale=cfg.data.test.img_scale)
+    img = to_tensor(img).to(device).unsqueeze(0)
+    img_meta = [
+        dict(
+            ori_shape=ori_shape,
+            img_shape=img_shape,
+            pad_shape=pad_shape,
+            scale_factor=scale_factor,
+            flip=False)
+    ]
+    return dict(img=[img], img_meta=[img_meta])
+
+
+def inference_detector(model, imgs, cfg, device='cuda:0'):
+
+    imgs = imgs if isinstance(imgs, list) else [imgs]
+    img_transform = ImageTransform(
+        **cfg.img_norm_cfg, size_divisor=cfg.data.test.size_divisor)
+    model = model.to(device)
+    model.eval()
+    for img in imgs:
+        img = mmcv.imread(img)
+        data = _prepare_data(img, img_transform, cfg, device)
+        with torch.no_grad():
+            result = model(**data, return_loss=False, rescale=True)
+        yield result
+
+
+def show_result(img, result, dataset='coco', score_thr=0.3):
+    class_names = get_classes(dataset)
+    labels = [
+        np.full(bbox.shape[0], i, dtype=np.int32)
+        for i, bbox in enumerate(result)
+    ]
+    labels = np.concatenate(labels)
+    bboxes = np.vstack(result)
+    mmcv.imshow_det_bboxes(
+        img.copy(),
+        bboxes,
+        labels,
+        class_names=class_names,
+        score_thr=score_thr)
diff --git a/mmdet/api/train.py b/mmdet/api/train.py
new file mode 100644
index 0000000..28469a2
--- /dev/null
+++ b/mmdet/api/train.py
@@ -0,0 +1,120 @@
+from __future__ import division
+
+import logging
+import random
+from collections import OrderedDict
+
+import numpy as np
+import torch
+from mmcv.runner import Runner, DistSamplerSeedHook
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+
+from mmdet import __version__
+from mmdet.core import (init_dist, DistOptimizerHook, CocoDistEvalRecallHook,
+                        CocoDistEvalmAPHook)
+from mmdet.datasets import build_dataloader
+from mmdet.models import RPN
+
+
+def parse_losses(losses):
+    log_vars = OrderedDict()
+    for loss_name, loss_value in losses.items():
+        if isinstance(loss_value, torch.Tensor):
+            log_vars[loss_name] = loss_value.mean()
+        elif isinstance(loss_value, list):
+            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+        else:
+            raise TypeError(
+                '{} is not a tensor or list of tensors'.format(loss_name))
+
+    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
+
+    log_vars['loss'] = loss
+    for name in log_vars:
+        log_vars[name] = log_vars[name].item()
+
+    return loss, log_vars
+
+
+def batch_processor(model, data, train_mode):
+    losses = model(**data)
+    loss, log_vars = parse_losses(losses)
+
+    outputs = dict(
+        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
+
+    return outputs
+
+
+def get_logger(log_level):
+    logging.basicConfig(
+        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
+    logger = logging.getLogger()
+    return logger
+
+
+def set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def train_detector(model, dataset, cfg):
+    # save mmdet version in checkpoint as meta data
+    cfg.checkpoint_config.meta = dict(
+        mmdet_version=__version__, config=cfg.text)
+
+    logger = get_logger(cfg.log_level)
+
+    # set random seed if specified
+    if cfg.seed is not None:
+        logger.info('Set random seed to {}'.format(cfg.seed))
+        set_random_seed(cfg.seed)
+
+    # init distributed environment if necessary
+    if cfg.launcher == 'none':
+        dist = False
+        logger.info('Non-distributed training.')
+    else:
+        dist = True
+        init_dist(cfg.launcher, **cfg.dist_params)
+        if torch.distributed.get_rank() != 0:
+            logger.setLevel('ERROR')
+        logger.info('Distributed training.')
+
+    # prepare data loaders
+    data_loaders = [
+        build_dataloader(dataset, cfg.data.imgs_per_gpu,
+                         cfg.data.workers_per_gpu, cfg.gpus, dist)
+    ]
+
+    # put model on gpus
+    if dist:
+        model = MMDistributedDataParallel(model.cuda())
+    else:
+        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
+
+    # build runner
+    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
+                    cfg.log_level)
+
+    # register hooks
+    optimizer_config = DistOptimizerHook(
+        **cfg.optimizer_config) if dist else cfg.optimizer_config
+    runner.register_training_hooks(cfg.lr_config, optimizer_config,
+                                   cfg.checkpoint_config, cfg.log_config)
+    if dist:
+        runner.register_hook(DistSamplerSeedHook())
+    # register eval hooks
+    if cfg.validate:
+        if isinstance(model.module, RPN):
+            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
+        elif cfg.data.val.type == 'CocoDataset':
+            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
+
+    if cfg.resume_from:
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
\ No newline at end of file
diff --git a/tools/train.py b/tools/train.py
index 237ec2b..839f27c 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,65 +1,12 @@
 from __future__ import division
 
 import argparse
-import logging
-import random
-from collections import OrderedDict
-
-import numpy as np
-import torch
 from mmcv import Config
-from mmcv.runner import Runner, obj_from_dict, DistSamplerSeedHook
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-
-from mmdet import datasets, __version__
-from mmdet.core import (init_dist, DistOptimizerHook, CocoDistEvalRecallHook,
-                        CocoDistEvalmAPHook)
-from mmdet.datasets import build_dataloader
-from mmdet.models import build_detector, RPN
-
-
-def parse_losses(losses):
-    log_vars = OrderedDict()
-    for loss_name, loss_value in losses.items():
-        if isinstance(loss_value, torch.Tensor):
-            log_vars[loss_name] = loss_value.mean()
-        elif isinstance(loss_value, list):
-            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
-        else:
-            raise TypeError(
-                '{} is not a tensor or list of tensors'.format(loss_name))
-
-    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
-
-    log_vars['loss'] = loss
-    for name in log_vars:
-        log_vars[name] = log_vars[name].item()
-
-    return loss, log_vars
-
-
-def batch_processor(model, data, train_mode):
-    losses = model(**data)
-    loss, log_vars = parse_losses(losses)
-
-    outputs = dict(
-        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
-
-    return outputs
-
-
-def get_logger(log_level):
-    logging.basicConfig(
-        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
-    logger = logging.getLogger()
-    return logger
+from mmcv.runner import obj_from_dict
 
-
-def set_random_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+from mmdet import datasets
+from mmdet.api import train_detector
+from mmdet.models import build_detector
 
 
 def parse_args():
@@ -86,71 +33,19 @@ def parse_args():
 
 def main():
     args = parse_args()
-
     cfg = Config.fromfile(args.config)
     if args.work_dir is not None:
         cfg.work_dir = args.work_dir
+    cfg.validate = args.validate
    cfg.gpus = args.gpus
-
-    # save mmdet version in checkpoint as meta data
-    cfg.checkpoint_config.meta = dict(
-        mmdet_version=__version__, config=cfg.text)
-
-    logger = get_logger(cfg.log_level)
-
-    # set random seed if specified
-    if args.seed is not None:
-        logger.info('Set random seed to {}'.format(args.seed))
-        set_random_seed(args.seed)
-
-    # init distributed environment if necessary
-    if args.launcher == 'none':
-        dist = False
-        logger.info('Non-distributed training.')
-    else:
-        dist = True
-        init_dist(args.launcher, **cfg.dist_params)
-        if torch.distributed.get_rank() != 0:
-            logger.setLevel('ERROR')
-        logger.info('Distributed training.')
-
-    # prepare data loaders
-    train_dataset = obj_from_dict(cfg.data.train, datasets)
-    data_loaders = [
-        build_dataloader(train_dataset, cfg.data.imgs_per_gpu,
-                         cfg.data.workers_per_gpu, cfg.gpus, dist)
-    ]
-
+    cfg.seed = args.seed
+    cfg.launcher = args.launcher
+    cfg.local_rank = args.local_rank
     # build model
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    if dist:
-        model = MMDistributedDataParallel(model.cuda())
-    else:
-        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
-
-    # build runner
-    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
-                    cfg.log_level)
-
-    # register hooks
-    optimizer_config = DistOptimizerHook(
-        **cfg.optimizer_config) if dist else cfg.optimizer_config
-    runner.register_training_hooks(cfg.lr_config, optimizer_config,
-                                   cfg.checkpoint_config, cfg.log_config)
-    if dist:
-        runner.register_hook(DistSamplerSeedHook())
-    # register eval hooks
-    if args.validate:
-        if isinstance(model.module, RPN):
-            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
-        elif cfg.data.val.type == 'CocoDataset':
-            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
-
-    if cfg.resume_from:
-        runner.resume(cfg.resume_from)
-    elif cfg.load_from:
-        runner.load_checkpoint(cfg.load_from)
-    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
+    train_dataset = obj_from_dict(cfg.data.train, datasets)
+    train_detector(model, train_dataset, cfg)
 
 
 if __name__ == '__main__':
-- 
GitLab
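
Usage note (not part of the patch): the sketch below shows one way the new high-level API could be exercised for inference, assuming a trained checkpoint is available. The config and checkpoint paths are hypothetical placeholders; load_checkpoint comes from mmcv.runner, inference_detector is a generator that yields one result per input image, and show_result (defined in mmdet/api/inference.py above but not exported via __all__) draws the detections. For training, tools/train.py above is the intended entry point: it builds the detector, constructs the dataset with obj_from_dict, and hands both to train_detector.

# Inference usage sketch for the new API (all paths below are hypothetical).
import mmcv
from mmcv.runner import load_checkpoint

from mmdet.api import inference_detector
from mmdet.api.inference import show_result
from mmdet.models import build_detector

cfg = mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')  # hypothetical config
cfg.model.pretrained = None  # weights come from the checkpoint, not a pretrained backbone

# build the detector in test mode and load trained weights
model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
load_checkpoint(model, 'checkpoints/faster_rcnn_r50_fpn_1x.pth')  # hypothetical checkpoint

# inference_detector yields one result per image; show_result visualizes each one
imgs = ['demo1.jpg', 'demo2.jpg']  # hypothetical test images
for img_path, result in zip(imgs, inference_detector(model, imgs, cfg)):
    show_result(mmcv.imread(img_path), result, dataset='coco', score_thr=0.3)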