diff --git a/mmdet/core/loss/losses.py b/mmdet/core/loss/losses.py index 1e32fca5606bd686bcd5c9ddea114971b77bed61..b2082002af82677d464a9e1bdddc9dc57a6129e0 100644 --- a/mmdet/core/loss/losses.py +++ b/mmdet/core/loss/losses.py @@ -92,6 +92,8 @@ def accuracy(pred, target, topk=1): if isinstance(topk, int): topk = (topk, ) return_single = True + else: + return_single = False maxk = max(topk) _, pred_label = pred.topk(maxk, 1, True, True) diff --git a/tools/test.py b/tools/test.py index d0537b3c64ebc51ddf410d8064278ba1da8eaa37..8aa23ea7ec8d1758fd06df6c47e765026dcd071e 100644 --- a/tools/test.py +++ b/tools/test.py @@ -23,7 +23,7 @@ def single_test(model, data_loader, show=False): if show: model.module.show_result(data, result, dataset.img_norm_cfg, - dataset.CLASSES) + dataset=dataset.CLASSES) batch_size = data['img'][0].size(0) for _ in range(batch_size):