From d6a724fbc606bb750a1682065b1260a6131c17e9 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Fri, 12 Oct 2018 15:27:33 +0800
Subject: [PATCH] update inference api

---
 mmdet/apis/__init__.py  |  4 ++--
 mmdet/apis/inference.py | 27 +++++++++++++++++++--------
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
index 39544f2..030b7de 100644
--- a/mmdet/apis/__init__.py
+++ b/mmdet/apis/__init__.py
@@ -1,8 +1,8 @@
 from .env import init_dist, get_root_logger, set_random_seed
 from .train import train_detector
-from .inference import inference_detector
+from .inference import inference_detector, show_result
 
 __all__ = [
     'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
-    'inference_detector'
+    'inference_detector', 'show_result'
 ]
diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
index 0addd59..a87323c 100644
--- a/mmdet/apis/inference.py
+++ b/mmdet/apis/inference.py
@@ -23,19 +23,29 @@ def _prepare_data(img, img_transform, cfg, device):
     return dict(img=[img], img_meta=[img_meta])
 
 
-def inference_detector(model, imgs, cfg, device='cuda:0'):
+def _inference_single(model, img, img_transform, cfg, device):
+    img = mmcv.imread(img)
+    data = _prepare_data(img, img_transform, cfg, device)
+    with torch.no_grad():
+        result = model(return_loss=False, rescale=True, **data)
+    return result
+
+
+def _inference_generator(model, imgs, img_transform, cfg, device):
+    for img in imgs:
+        yield _inference_single(model, img, img_transform, cfg, device)
 
-    imgs = imgs if isinstance(imgs, list) else [imgs]
+
+def inference_detector(model, imgs, cfg, device='cuda:0'):
     img_transform = ImageTransform(
         size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
     model = model.to(device)
     model.eval()
-    for img in imgs:
-        img = mmcv.imread(img)
-        data = _prepare_data(img, img_transform, cfg, device)
-        with torch.no_grad():
-            result = model(return_loss=False, rescale=True, **data)
-        yield result
+
+    if not isinstance(imgs, list):
+        return _inference_single(model, imgs, img_transform, cfg, device)
+    else:
+        return _inference_generator(model, imgs, img_transform, cfg, device)
 
 
 def show_result(img, result, dataset='coco', score_thr=0.3):
@@ -46,6 +56,7 @@ def show_result(img, result, dataset='coco', score_thr=0.3):
     ]
     labels = np.concatenate(labels)
     bboxes = np.vstack(result)
+    img = mmcv.imread(img)
     mmcv.imshow_det_bboxes(
         img.copy(),
         bboxes,
-- 
GitLab