From 1229095fc088a655aecafd37cc6586955c3d3be8 Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Wed, 10 Oct 2018 20:34:08 +0800 Subject: [PATCH] fix flake8 error in python 2 --- mmdet/core/evaluation/eval_hooks.py | 2 +- mmdet/models/detectors/base.py | 8 +++----- tools/test.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py index bec25ef..1402f7f 100644 --- a/mmdet/core/evaluation/eval_hooks.py +++ b/mmdet/core/evaluation/eval_hooks.py @@ -74,7 +74,7 @@ class DistEvalHook(Hook): # compute output with torch.no_grad(): result = runner.model( - **data_gpu, return_loss=False, rescale=True) + return_loss=False, rescale=True, **data_gpu) results[idx] = result batch_size = runner.world_size diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py index d1b0fce..e617b0e 100644 --- a/mmdet/models/detectors/base.py +++ b/mmdet/models/detectors/base.py @@ -34,11 +34,9 @@ class BaseDetector(nn.Module): pass def extract_feats(self, imgs): - if isinstance(imgs, torch.Tensor): - return self.extract_feat(imgs) - elif isinstance(imgs, list): - for img in imgs: - yield self.extract_feat(img) + assert isinstance(imgs, list) + for img in imgs: + yield self.extract_feat(img) @abstractmethod def forward_train(self, imgs, img_metas, **kwargs): diff --git a/tools/test.py b/tools/test.py index 2552e7a..e1552e5 100644 --- a/tools/test.py +++ b/tools/test.py @@ -17,7 +17,7 @@ def single_test(model, data_loader, show=False): prog_bar = mmcv.ProgressBar(len(data_loader.dataset)) for i, data in enumerate(data_loader): with torch.no_grad(): - result = model(**data, return_loss=False, rescale=not show) + result = model(return_loss=False, rescale=not show, **data) results.append(result) if show: -- GitLab