From 7facc34fe765df932faf5a0d58aa745ac65c03ba Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Fri, 17 May 2019 02:13:29 -0700
Subject: [PATCH] Save class names in checkpoints and update the high-level
 inference APIs (#645)

* update the high-level inference api

* save classes in meta data and use it for visualization
---
 GETTING_STARTED.md             | 29 +++++------
 mmdet/apis/__init__.py         |  4 +-
 mmdet/apis/inference.py        | 94 ++++++++++++++++++++++++++++------
 mmdet/models/detectors/base.py |  8 +--
 tools/test.py                  |  5 +-
 tools/train.py                 | 12 +++--
 6 files changed, 105 insertions(+), 47 deletions(-)

diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md
index d965425..b5aac9d 100644
--- a/GETTING_STARTED.md
+++ b/GETTING_STARTED.md
@@ -62,28 +62,23 @@ python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
 Here is an example of building the model and test given images.
 
 ```python
-import mmcv
-from mmcv.runner import load_checkpoint
-from mmdet.models import build_detector
-from mmdet.apis import inference_detector, show_result
+from mmdet.apis import init_detector, inference_detector, show_result
 
-cfg = mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')
-cfg.model.pretrained = None
+config_file = 'configs/faster_rcnn_r50_fpn_1x.py'
+checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth'
 
-# construct the model and load checkpoint
-model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
-_ = load_checkpoint(model, 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth')
+# build the model from a config file and a checkpoint file
+model = init_detector(config_file, checkpoint_file)
 
-# test a single image
-img = mmcv.imread('test.jpg')
-result = inference_detector(model, img, cfg)
-show_result(img, result)
+# test a single image and show the results
+img = 'test.jpg'  # or img = mmcv.imread(img), which will only load it once
+result = inference_detector(model, img)
+show_result(img, result, model.CLASSES)
 
-# test a list of images
+# test a list of images and write the results to image files
 imgs = ['test1.jpg', 'test2.jpg']
-for i, result in enumerate(inference_detector(model, imgs, cfg, device='cuda:0')):
-    print(i, imgs[i])
-    show_result(imgs[i], result)
+for i, result in enumerate(inference_detector(model, imgs)):
+    show_result(imgs[i], result, model.CLASSES, out_file='result_{}.jpg'.format(i))
 ```
diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
index 030b7de..762f5ab 100644
--- a/mmdet/apis/__init__.py
+++ b/mmdet/apis/__init__.py
@@ -1,8 +1,8 @@
 from .env import init_dist, get_root_logger, set_random_seed
 from .train import train_detector
-from .inference import inference_detector, show_result
+from .inference import init_detector, inference_detector, show_result
 
 __all__ = [
     'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
-    'inference_detector', 'show_result'
+    'init_detector', 'inference_detector', 'show_result'
 ]
diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
index f7ecb2f..d2f455e 100644
--- a/mmdet/apis/inference.py
+++ b/mmdet/apis/inference.py
@@ -1,11 +1,71 @@
+import warnings
+
 import mmcv
 import numpy as np
 import pycocotools.mask as maskUtils
 import torch
+from mmcv.runner import load_checkpoint
 
 from mmdet.core import get_classes
 from mmdet.datasets import to_tensor
 from mmdet.datasets.transforms import ImageTransform
+from mmdet.models import build_detector
+
+
+def init_detector(config, checkpoint=None, device='cuda:0'):
+    """Initialize a detector from a config file.
+
+    Args:
+        config (str or :obj:`mmcv.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any weights.
+
+    Returns:
+        nn.Module: The constructed detector.
+    """
+    if isinstance(config, str):
+        config = mmcv.Config.fromfile(config)
+    elif not isinstance(config, mmcv.Config):
+        raise TypeError('config must be a filename or Config object, '
+                        'but got {}'.format(type(config)))
+    config.model.pretrained = None
+    model = build_detector(config.model, test_cfg=config.test_cfg)
+    if checkpoint is not None:
+        checkpoint = load_checkpoint(model, checkpoint)
+        if 'classes' in checkpoint['meta']:
+            model.CLASSES = checkpoint['meta']['classes']
+        else:
+            warnings.warn('Class names are not saved in the checkpoint\'s '
+                          'meta data, using COCO classes by default.')
+            model.CLASSES = get_classes('coco')
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
+
+
+def inference_detector(model, imgs):
+    """Inference image(s) with the detector.
+
+    Args:
+        model (nn.Module): The loaded detector.
+        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
+            images.
+
+    Returns:
+        If imgs is a list, a generator will be returned, otherwise the
+        detection result is returned directly.
+    """
+    cfg = model.cfg
+    img_transform = ImageTransform(
+        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
+
+    device = next(model.parameters()).device  # model device
+    if not isinstance(imgs, list):
+        return _inference_single(model, imgs, img_transform, device)
+    else:
+        return _inference_generator(model, imgs, img_transform, device)
 
 
 def _prepare_data(img, img_transform, cfg, device):
@@ -26,34 +86,34 @@ def _prepare_data(img, img_transform, cfg, device):
     return dict(img=[img], img_meta=[img_meta])
 
 
-def _inference_single(model, img, img_transform, cfg, device):
+def _inference_single(model, img, img_transform, device):
     img = mmcv.imread(img)
-    data = _prepare_data(img, img_transform, cfg, device)
+    data = _prepare_data(img, img_transform, model.cfg, device)
     with torch.no_grad():
         result = model(return_loss=False, rescale=True, **data)
     return result
 
 
-def _inference_generator(model, imgs, img_transform, cfg, device):
+def _inference_generator(model, imgs, img_transform, device):
     for img in imgs:
-        yield _inference_single(model, img, img_transform, cfg, device)
-
+        yield _inference_single(model, img, img_transform, device)
 
-def inference_detector(model, imgs, cfg, device='cuda:0'):
-    img_transform = ImageTransform(
-        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
-    model = model.to(device)
-    model.eval()
-
-    if not isinstance(imgs, list):
-        return _inference_single(model, imgs, img_transform, cfg, device)
-    else:
-        return _inference_generator(model, imgs, img_transform, cfg, device)
+# TODO: merge this method with the one in BaseDetector
+def show_result(img, result, class_names, score_thr=0.3, out_file=None):
+    """Visualize the detection results on the image.
 
-def show_result(img, result, dataset='coco', score_thr=0.3, out_file=None):
+    Args:
+        img (str or np.ndarray): Image filename or loaded image.
+        result (tuple[list] or list): The detection result, can be either
+            (bbox, segm) or just bbox.
+        class_names (list[str] or tuple[str]): A list of class names.
+        score_thr (float): The score threshold for visualizing bboxes and masks.
+        out_file (str, optional): If specified, the visualization result will
+            be written to the out file instead of shown in a window.
+    """
+    assert isinstance(class_names, (tuple, list))
     img = mmcv.imread(img)
-    class_names = get_classes(dataset)
     if isinstance(result, tuple):
         bbox_result, segm_result = result
     else:
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 8e2bbde..311ca90 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -89,7 +89,7 @@ class BaseDetector(nn.Module):
                     data,
                     result,
                     img_norm_cfg,
-                    dataset='coco',
+                    dataset=None,
                     score_thr=0.3):
         if isinstance(result, tuple):
             bbox_result, segm_result = result
@@ -101,9 +101,11 @@ class BaseDetector(nn.Module):
         imgs = tensor2imgs(img_tensor, **img_norm_cfg)
         assert len(imgs) == len(img_metas)
 
-        if isinstance(dataset, str):
+        if dataset is None:
+            class_names = self.CLASSES
+        elif isinstance(dataset, str):
             class_names = get_classes(dataset)
-        elif isinstance(dataset, (list, tuple)) or dataset is None:
+        elif isinstance(dataset, (list, tuple)):
             class_names = dataset
         else:
             raise TypeError(
diff --git a/tools/test.py b/tools/test.py
index f8c0e2d..dc94b6d 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -26,10 +26,7 @@ def single_gpu_test(model, data_loader, show=False):
         results.append(result)
 
         if show:
-            model.module.show_result(data,
-                                     result,
-                                     dataset.img_norm_cfg,
-                                     dataset=dataset.CLASSES)
+            model.module.show_result(data, result, dataset.img_norm_cfg)
 
         batch_size = data['img'][0].size(0)
         for _ in range(batch_size):
diff --git a/tools/train.py b/tools/train.py
index 73cb5f6..97a253f 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -52,10 +52,6 @@ def main():
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
     cfg.gpus = args.gpus
-    if cfg.checkpoint_config is not None:
-        # save mmdet version in checkpoints as meta data
-        cfg.checkpoint_config.meta = dict(
-            mmdet_version=__version__, config=cfg.text)
 
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
@@ -77,6 +73,14 @@ def main():
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
 
     train_dataset = get_dataset(cfg.data.train)
+    if cfg.checkpoint_config is not None:
+        # save mmdet version, config file content and class names in
+        # checkpoints as meta data
+        cfg.checkpoint_config.meta = dict(
+            mmdet_version=__version__, config=cfg.text,
+            classes=train_dataset.CLASSES)
+    # add an attribute for visualization convenience
+    model.CLASSES = train_dataset.CLASSES
     train_detector(
         model,
         train_dataset,
--
GitLab
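
As a quick sanity check of the new behavior, the sketch below loads a checkpoint written by the updated `tools/train.py` and confirms that the class names stored in its meta data are picked up by `init_detector`. The config and checkpoint paths are placeholders, not files shipped with this patch.

```python
import torch

from mmdet.apis import init_detector

# Placeholder paths: any config and a checkpoint trained after this patch.
config_file = 'configs/faster_rcnn_r50_fpn_1x.py'
checkpoint_file = 'work_dirs/faster_rcnn_r50_fpn_1x/latest.pth'

# mmcv checkpoints are plain dicts; 'meta' now also carries the class names
# saved by the updated checkpoint_config in tools/train.py.
ckpt = torch.load(checkpoint_file, map_location='cpu')
print(ckpt['meta']['mmdet_version'])
print(ckpt['meta']['classes'])  # class names of the training dataset

# init_detector reads the same meta entry and exposes it as model.CLASSES,
# so show_result no longer needs a hard-coded dataset name.
model = init_detector(config_file, checkpoint_file, device='cuda:0')
print(model.CLASSES)
```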