diff --git a/mmdet/api/inference.py b/mmdet/api/inference.py index d452c665a9f1d74a13d709b92c817d9e5789a47c..0addd598e5043220e9bbc1ad594c757fafa75481 100644 --- a/mmdet/api/inference.py +++ b/mmdet/api/inference.py @@ -34,7 +34,7 @@ def inference_detector(model, imgs, cfg, device='cuda:0'): img = mmcv.imread(img) data = _prepare_data(img, img_transform, cfg, device) with torch.no_grad(): - result = model(**data, return_loss=False, rescale=True) + result = model(return_loss=False, rescale=True, **data) yield result