diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
index 39544f2bf41a46708943e6f1672fef0b1df03e11..030b7de41026755359535cc309e39c7b4e0efb66 100644
--- a/mmdet/apis/__init__.py
+++ b/mmdet/apis/__init__.py
@@ -1,8 +1,8 @@
 from .env import init_dist, get_root_logger, set_random_seed
 from .train import train_detector
-from .inference import inference_detector
+from .inference import inference_detector, show_result
 
 __all__ = [
     'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
-    'inference_detector'
+    'inference_detector', 'show_result'
 ]
diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
index 0addd598e5043220e9bbc1ad594c757fafa75481..a87323cee1aba6e97c75f2a563c5337bc5fe32ff 100644
--- a/mmdet/apis/inference.py
+++ b/mmdet/apis/inference.py
@@ -23,19 +23,29 @@ def _prepare_data(img, img_transform, cfg, device):
     return dict(img=[img], img_meta=[img_meta])
 
 
-def inference_detector(model, imgs, cfg, device='cuda:0'):
+def _inference_single(model, img, img_transform, cfg, device):
+    img = mmcv.imread(img)
+    data = _prepare_data(img, img_transform, cfg, device)
+    with torch.no_grad():
+        result = model(return_loss=False, rescale=True, **data)
+    return result
+
+
+def _inference_generator(model, imgs, img_transform, cfg, device):
+    for img in imgs:
+        yield _inference_single(model, img, img_transform, cfg, device)
 
-    imgs = imgs if isinstance(imgs, list) else [imgs]
 
+def inference_detector(model, imgs, cfg, device='cuda:0'):
     img_transform = ImageTransform(
         size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
 
     model = model.to(device)
     model.eval()
 
-    for img in imgs:
-        img = mmcv.imread(img)
-        data = _prepare_data(img, img_transform, cfg, device)
-        with torch.no_grad():
-            result = model(return_loss=False, rescale=True, **data)
-        yield result
+
+    if not isinstance(imgs, list):
+        return _inference_single(model, imgs, img_transform, cfg, device)
+    else:
+        return _inference_generator(model, imgs, img_transform, cfg, device)
 
 
 def show_result(img, result, dataset='coco', score_thr=0.3):
@@ -46,6 +56,7 @@ def show_result(img, result, dataset='coco', score_thr=0.3):
     ]
     labels = np.concatenate(labels)
     bboxes = np.vstack(result)
+    img = mmcv.imread(img)
     mmcv.imshow_det_bboxes(
         img.copy(),
         bboxes,
diff --git a/mmdet/models/builder.py b/mmdet/models/builder.py
index bdf0ac3d16f9aadb194f944b3f7c4dd1a741e8cd..ee5ae0b14b01e147f5f9199141709bdac4dbe0af 100644
--- a/mmdet/models/builder.py
+++ b/mmdet/models/builder.py
@@ -2,7 +2,7 @@ from mmcv.runner import obj_from_dict
 from torch import nn
 
 from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
-               mask_heads, detectors)
+               mask_heads)
 
 __all__ = [
     'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
@@ -48,4 +48,5 @@ def build_mask_head(cfg):
 
 
 def build_detector(cfg, train_cfg=None, test_cfg=None):
+    from . import detectors
     return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
diff --git a/mmdet/models/rpn_heads/rpn_head.py b/mmdet/models/rpn_heads/rpn_head.py
index e67d7ae973f05c60c8e226009cfb4234c0894f69..61e6e199ac0407bd23226701e3117c02ec16171d 100644
--- a/mmdet/models/rpn_heads/rpn_head.py
+++ b/mmdet/models/rpn_heads/rpn_head.py
@@ -48,8 +48,8 @@ class RPNHead(nn.Module):
         self.anchor_scales = anchor_scales
         self.anchor_ratios = anchor_ratios
         self.anchor_strides = anchor_strides
-        self.anchor_base_sizes = anchor_strides.copy(
-        ) if anchor_base_sizes is None else anchor_base_sizes
+        self.anchor_base_sizes = list(
+            anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
         self.target_means = target_means
         self.target_stds = target_stds
         self.use_sigmoid_cls = use_sigmoid_cls
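
Usage sketch (not part of the patch): after this change, inference_detector returns a single result when given one image and a generator when given a list, and show_result now reads the image itself, so a file path can be passed directly. The config path, checkpoint path, and the load_checkpoint helper below are illustrative assumptions, not taken from this diff.

# Illustrative usage only -- paths and the checkpoint loader are placeholders.
import mmcv
from mmcv.runner import load_checkpoint  # assumed helper for loading weights

from mmdet.apis import inference_detector, show_result
from mmdet.models import build_detector

cfg = mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')  # placeholder config
cfg.model.pretrained = None

model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
load_checkpoint(model, 'checkpoints/faster_rcnn_r50_fpn_1x.pth')  # placeholder weights

# Single image: the result is returned directly (no generator to iterate).
result = inference_detector(model, 'demo.jpg', cfg, device='cuda:0')
show_result('demo.jpg', result, dataset='coco', score_thr=0.3)

# List of images: a generator is returned, yielding one result per image.
imgs = ['img1.jpg', 'img2.jpg']
for img, result in zip(imgs, inference_detector(model, imgs, cfg)):
    show_result(img, result)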