Skip to content
Snippets Groups Projects
Commit d6a724fb authored by Kai Chen's avatar Kai Chen
Browse files

update inference api

parent 2507eb6f
No related branches found
No related tags found
No related merge requests found
from .env import init_dist, get_root_logger, set_random_seed from .env import init_dist, get_root_logger, set_random_seed
from .train import train_detector from .train import train_detector
from .inference import inference_detector from .inference import inference_detector, show_result
__all__ = [ __all__ = [
'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector', 'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
'inference_detector' 'inference_detector', 'show_result'
] ]
...@@ -23,19 +23,29 @@ def _prepare_data(img, img_transform, cfg, device): ...@@ -23,19 +23,29 @@ def _prepare_data(img, img_transform, cfg, device):
return dict(img=[img], img_meta=[img_meta]) return dict(img=[img], img_meta=[img_meta])
def inference_detector(model, imgs, cfg, device='cuda:0'): def _inference_single(model, img, img_transform, cfg, device):
img = mmcv.imread(img)
data = _prepare_data(img, img_transform, cfg, device)
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
return result
def _inference_generator(model, imgs, img_transform, cfg, device):
for img in imgs:
yield _inference_single(model, img, img_transform, cfg, device)
imgs = imgs if isinstance(imgs, list) else [imgs]
def inference_detector(model, imgs, cfg, device='cuda:0'):
img_transform = ImageTransform( img_transform = ImageTransform(
size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg) size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
model = model.to(device) model = model.to(device)
model.eval() model.eval()
for img in imgs:
img = mmcv.imread(img) if not isinstance(imgs, list):
data = _prepare_data(img, img_transform, cfg, device) return _inference_single(model, imgs, img_transform, cfg, device)
with torch.no_grad(): else:
result = model(return_loss=False, rescale=True, **data) return _inference_generator(model, imgs, img_transform, cfg, device)
yield result
def show_result(img, result, dataset='coco', score_thr=0.3): def show_result(img, result, dataset='coco', score_thr=0.3):
...@@ -46,6 +56,7 @@ def show_result(img, result, dataset='coco', score_thr=0.3): ...@@ -46,6 +56,7 @@ def show_result(img, result, dataset='coco', score_thr=0.3):
] ]
labels = np.concatenate(labels) labels = np.concatenate(labels)
bboxes = np.vstack(result) bboxes = np.vstack(result)
img = mmcv.imread(img)
mmcv.imshow_det_bboxes( mmcv.imshow_det_bboxes(
img.copy(), img.copy(),
bboxes, bboxes,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment