Skip to content
Snippets Groups Projects
Unverified Commit 54b54d88 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #19 from hellock/dev

Update inference APIs
parents abc440fc 459d5ebc
No related branches found
No related tags found
No related merge requests found
from .env import init_dist, get_root_logger, set_random_seed
from .train import train_detector
from .inference import inference_detector
from .inference import inference_detector, show_result
__all__ = [
'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
'inference_detector'
'inference_detector', 'show_result'
]
......@@ -23,19 +23,29 @@ def _prepare_data(img, img_transform, cfg, device):
return dict(img=[img], img_meta=[img_meta])
def inference_detector(model, imgs, cfg, device='cuda:0'):
def _inference_single(model, img, img_transform, cfg, device):
img = mmcv.imread(img)
data = _prepare_data(img, img_transform, cfg, device)
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
return result
def _inference_generator(model, imgs, img_transform, cfg, device):
for img in imgs:
yield _inference_single(model, img, img_transform, cfg, device)
imgs = imgs if isinstance(imgs, list) else [imgs]
def inference_detector(model, imgs, cfg, device='cuda:0'):
img_transform = ImageTransform(
size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
model = model.to(device)
model.eval()
for img in imgs:
img = mmcv.imread(img)
data = _prepare_data(img, img_transform, cfg, device)
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
yield result
if not isinstance(imgs, list):
return _inference_single(model, imgs, img_transform, cfg, device)
else:
return _inference_generator(model, imgs, img_transform, cfg, device)
def show_result(img, result, dataset='coco', score_thr=0.3):
......@@ -46,6 +56,7 @@ def show_result(img, result, dataset='coco', score_thr=0.3):
]
labels = np.concatenate(labels)
bboxes = np.vstack(result)
img = mmcv.imread(img)
mmcv.imshow_det_bboxes(
img.copy(),
bboxes,
......
......@@ -2,7 +2,7 @@ from mmcv.runner import obj_from_dict
from torch import nn
from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
mask_heads, detectors)
mask_heads)
__all__ = [
'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
......@@ -48,4 +48,5 @@ def build_mask_head(cfg):
def build_detector(cfg, train_cfg=None, test_cfg=None):
from . import detectors
return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
......@@ -48,8 +48,8 @@ class RPNHead(nn.Module):
self.anchor_scales = anchor_scales
self.anchor_ratios = anchor_ratios
self.anchor_strides = anchor_strides
self.anchor_base_sizes = anchor_strides.copy(
) if anchor_base_sizes is None else anchor_base_sizes
self.anchor_base_sizes = list(
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment