From 7facc34fe765df932faf5a0d58aa745ac65c03ba Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Fri, 17 May 2019 02:13:29 -0700
Subject: [PATCH] Save class names in checkpoints and update the high-level
 inference APIs (#645)

* update the high-level inference api

* save classes in meta data and use it for visualization
---
 GETTING_STARTED.md             | 29 +++++------
 mmdet/apis/__init__.py         |  4 +-
 mmdet/apis/inference.py        | 94 ++++++++++++++++++++++++++++------
 mmdet/models/detectors/base.py |  8 +--
 tools/test.py                  |  5 +-
 tools/train.py                 | 12 +++--
 6 files changed, 105 insertions(+), 47 deletions(-)

diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md
index d965425..b5aac9d 100644
--- a/GETTING_STARTED.md
+++ b/GETTING_STARTED.md
@@ -62,28 +62,23 @@ python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
 Here is an example of building the model and test given images.
 
 ```python
-import mmcv
-from mmcv.runner import load_checkpoint
-from mmdet.models import build_detector
-from mmdet.apis import inference_detector, show_result
+from mmdet.apis import init_detector, inference_detector, show_result
 
-cfg = mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')
-cfg.model.pretrained = None
+config_file = 'configs/faster_rcnn_r50_fpn_1x.py'
+checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth'
 
-# construct the model and load checkpoint
-model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
-_ = load_checkpoint(model, 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth')
+# build the model from a config file and a checkpoint file
+model = init_detector(config_file, checkpoint_file)
 
-# test a single image
-img = mmcv.imread('test.jpg')
-result = inference_detector(model, img, cfg)
-show_result(img, result)
+# test a single image and show the results
+img = 'test.jpg'  # or img = mmcv.imread(img), which will only load it once
+result = inference_detector(model, img)
+show_result(img, result, model.CLASSES)
 
-# test a list of images
+# test a list of images and write the results to image files
 imgs = ['test1.jpg', 'test2.jpg']
-for i, result in enumerate(inference_detector(model, imgs, cfg, device='cuda:0')):
-    print(i, imgs[i])
-    show_result(imgs[i], result)
+for i, result in enumerate(inference_detector(model, imgs, device='cuda:0')):
+    show_result(imgs[i], result, model.CLASSES, out_file='result_{}.jpg'.format(i))
 ```
 
 
diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
index 030b7de..762f5ab 100644
--- a/mmdet/apis/__init__.py
+++ b/mmdet/apis/__init__.py
@@ -1,8 +1,8 @@
 from .env import init_dist, get_root_logger, set_random_seed
 from .train import train_detector
-from .inference import inference_detector, show_result
+from .inference import init_detector, inference_detector, show_result
 
 __all__ = [
     'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
-    'inference_detector', 'show_result'
+    'init_detector', 'inference_detector', 'show_result'
 ]
diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
index f7ecb2f..d2f455e 100644
--- a/mmdet/apis/inference.py
+++ b/mmdet/apis/inference.py
@@ -1,11 +1,71 @@
+import warnings
+
 import mmcv
 import numpy as np
 import pycocotools.mask as maskUtils
 import torch
+from mmcv.runner import load_checkpoint
 
 from mmdet.core import get_classes
 from mmdet.datasets import to_tensor
 from mmdet.datasets.transforms import ImageTransform
+from mmdet.models import build_detector
+
+
+def init_detector(config, checkpoint=None, device='cuda:0'):
+    """Initialize a detector from config file.
+
+    Args:
+        config (str or :obj:`mmcv.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any weights.
+
+    Returns:
+        nn.Module: The constructed detector.
+    """
+    if isinstance(config, str):
+        config = mmcv.Config.fromfile(config)
+    elif not isinstance(config, mmcv.Config):
+        raise TypeError('config must be a filename or Config object, '
+                        'but got {}'.format(type(config)))
+    config.model.pretrained = None
+    model = build_detector(config.model, test_cfg=config.test_cfg)
+    if checkpoint is not None:
+        checkpoint = load_checkpoint(model, checkpoint)
+        if 'CLASSES' in checkpoint['meta']:
+            model.CLASSES = checkpoint['meta']['classes']
+        else:
+            warnings.warn('Class names are not saved in the checkpoint\'s '
+                          'meta data, use COCO classes by default.')
+            model.CLASSES = get_classes('coco')
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
+
+
+def inference_detector(model, imgs):
+    """Inference image(s) with the detector.
+
+    Args:
+        model (nn.Module): The loaded detector.
+        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
+            images.
+
+    Returns:
+        If imgs is a str, a generator will be returned, otherwise return the
+        detection results directly.
+    """
+    cfg = model.cfg
+    img_transform = ImageTransform(
+        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
+
+    device = next(model.parameters()).device  # model device
+    if not isinstance(imgs, list):
+        return _inference_single(model, imgs, img_transform, device)
+    else:
+        return _inference_generator(model, imgs, img_transform, device)
 
 
 def _prepare_data(img, img_transform, cfg, device):
@@ -26,34 +86,34 @@ def _prepare_data(img, img_transform, cfg, device):
     return dict(img=[img], img_meta=[img_meta])
 
 
-def _inference_single(model, img, img_transform, cfg, device):
+def _inference_single(model, img, img_transform, device):
     img = mmcv.imread(img)
-    data = _prepare_data(img, img_transform, cfg, device)
+    data = _prepare_data(img, img_transform, model.cfg, device)
     with torch.no_grad():
         result = model(return_loss=False, rescale=True, **data)
     return result
 
 
-def _inference_generator(model, imgs, img_transform, cfg, device):
+def _inference_generator(model, imgs, img_transform, device):
     for img in imgs:
-        yield _inference_single(model, img, img_transform, cfg, device)
-
+        yield _inference_single(model, img, img_transform, device)
 
-def inference_detector(model, imgs, cfg, device='cuda:0'):
-    img_transform = ImageTransform(
-        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
-    model = model.to(device)
-    model.eval()
-
-    if not isinstance(imgs, list):
-        return _inference_single(model, imgs, img_transform, cfg, device)
-    else:
-        return _inference_generator(model, imgs, img_transform, cfg, device)
 
+# TODO: merge this method with the one in BaseDetector
+def show_result(img, result, class_names, score_thr=0.3, out_file=None):
+    """Visualize the detection results on the image.
 
-def show_result(img, result, dataset='coco', score_thr=0.3, out_file=None):
+    Args:
+        img (str or np.ndarray): Image filename or loaded image.
+        result (tuple[list] or list): The detection result, can be either
+            (bbox, segm) or just bbox.
+        class_names (list[str] or tuple[str]): A list of class names.
+        score_thr (float): The threshold to visualize the bboxes and masks.
+        out_file (str, optional): If specified, the visualization result will
+            be written to the out file instead of shown in a window.
+    """
+    assert isinstance(class_names, (tuple, list))
     img = mmcv.imread(img)
-    class_names = get_classes(dataset)
     if isinstance(result, tuple):
         bbox_result, segm_result = result
     else:
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 8e2bbde..311ca90 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -89,7 +89,7 @@ class BaseDetector(nn.Module):
                     data,
                     result,
                     img_norm_cfg,
-                    dataset='coco',
+                    dataset=None,
                     score_thr=0.3):
         if isinstance(result, tuple):
             bbox_result, segm_result = result
@@ -101,9 +101,11 @@ class BaseDetector(nn.Module):
         imgs = tensor2imgs(img_tensor, **img_norm_cfg)
         assert len(imgs) == len(img_metas)
 
-        if isinstance(dataset, str):
+        if dataset is None:
+            class_names = self.CLASSES
+        elif isinstance(dataset, str):
             class_names = get_classes(dataset)
-        elif isinstance(dataset, (list, tuple)) or dataset is None:
+        elif isinstance(dataset, (list, tuple)):
             class_names = dataset
         else:
             raise TypeError(
diff --git a/tools/test.py b/tools/test.py
index f8c0e2d..dc94b6d 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -26,10 +26,7 @@ def single_gpu_test(model, data_loader, show=False):
         results.append(result)
 
         if show:
-            model.module.show_result(data,
-                                     result,
-                                     dataset.img_norm_cfg,
-                                     dataset=dataset.CLASSES)
+            model.module.show_result(data, result, dataset.img_norm_cfg)
 
         batch_size = data['img'][0].size(0)
         for _ in range(batch_size):
diff --git a/tools/train.py b/tools/train.py
index 73cb5f6..97a253f 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -52,10 +52,6 @@ def main():
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
     cfg.gpus = args.gpus
-    if cfg.checkpoint_config is not None:
-        # save mmdet version in checkpoints as meta data
-        cfg.checkpoint_config.meta = dict(
-            mmdet_version=__version__, config=cfg.text)
 
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
@@ -77,6 +73,14 @@ def main():
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
 
     train_dataset = get_dataset(cfg.data.train)
+    if cfg.checkpoint_config is not None:
+        # save mmdet version, config file content and class names in
+        # checkpoints as meta data
+        cfg.checkpoint_config.meta = dict(
+            mmdet_version=__version__, config=cfg.text,
+            classes=train_dataset.CLASSES)
+    # add an attribute for visualization convenience
+    model.CLASSES = train_dataset.CLASSES
     train_detector(
         model,
         train_dataset,
-- 
GitLab