Skip to content
Snippets Groups Projects
Commit 69e93f6f authored by donglee's avatar donglee Committed by Kai Chen
Browse files

Add tta to HTC and Cascade RCNN (#1251)

* add TTA (test-time augmentation) to HTC and Cascade RCNN

* format file with yapf

* fix import error with isort

* Update htc.py

* Update cascade_rcnn.py

* fix bug

* delete some redundant code
parent ebf499f2
No related branches found
No related tags found
No related merge requests found
...@@ -3,8 +3,9 @@ from __future__ import division ...@@ -3,8 +3,9 @@ from __future__ import division
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmdet.core import (bbox2result, bbox2roi, build_assigner, build_sampler, from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
merge_aug_masks) build_sampler, merge_aug_bboxes, merge_aug_masks,
multiclass_nms)
from .. import builder from .. import builder
from ..registry import DETECTORS from ..registry import DETECTORS
from .base import BaseDetector from .base import BaseDetector
...@@ -399,8 +400,110 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -399,8 +400,110 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
return results return results
def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
    """Test with augmentations.

    For every augmented view, the proposals returned by
    ``aug_test_rpn`` are mapped into that view and passed through all
    cascade stages. The per-stage classification scores are averaged,
    the per-view boxes are merged with ``merge_aug_bboxes`` and filtered
    with ``multiclass_nms``. When ``self.with_mask`` is set, the mask
    predictions of every stage and every view are merged as well.

    NOTE(review): ``proposals`` and ``rescale`` are accepted but never
    read by this method. The merged boxes are in the original image size
    (see the inline comment below) and masks are produced for
    ``ori_shape``. The previous docstring said results fit the scale of
    ``imgs[0]`` when ``rescale`` is False -- confirm which is intended.

    Returns:
        ``bbox_result`` alone, or ``(bbox_result, segm_result)`` when
        ``self.with_mask`` is true.
    """
    # Features are re-extracted for each pass instead of being cached,
    # to save memory.
    proposal_list = self.aug_test_rpn(
        self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
    rcnn_test_cfg = self.test_cfg.rcnn

    aug_bboxes, aug_scores = [], []
    for feats, meta in zip(self.extract_feats(imgs), img_metas):
        # only one image in the batch, hence index 0
        view_meta = meta[0]
        img_shape = view_meta['img_shape']
        scale_factor = view_meta['scale_factor']
        flip = view_meta['flip']

        view_proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                      scale_factor, flip)
        rois = bbox2roi([view_proposals])

        stage_scores = []
        for stage in range(self.num_stages):
            roi_extractor = self.bbox_roi_extractor[stage]
            head = self.bbox_head[stage]
            num_levels = len(roi_extractor.featmap_strides)
            bbox_feats = roi_extractor(feats[:num_levels], rois)
            if self.with_shared_head:
                bbox_feats = self.shared_head(bbox_feats)
            cls_score, bbox_pred = head(bbox_feats)
            stage_scores.append(cls_score)
            # Every stage but the last refines the RoIs for the next one.
            if stage < self.num_stages - 1:
                rois = head.regress_by_class(
                    rois, cls_score.argmax(dim=1), bbox_pred, view_meta)

        # Average the classification scores over all stages.
        mean_score = sum(stage_scores) / float(len(stage_scores))
        bboxes, scores = self.bbox_head[-1].get_det_bboxes(
            rois,
            mean_score,
            bbox_pred,
            img_shape,
            scale_factor,
            rescale=False,
            cfg=None)
        aug_bboxes.append(bboxes)
        aug_scores.append(scores)

    # after merging, bboxes will be rescaled to the original image size
    merged_bboxes, merged_scores = merge_aug_bboxes(
        aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
    det_bboxes, det_labels = multiclass_nms(
        merged_bboxes, merged_scores, rcnn_test_cfg.score_thr,
        rcnn_test_cfg.nms, rcnn_test_cfg.max_per_img)
    bbox_result = bbox2result(det_bboxes, det_labels,
                              self.bbox_head[-1].num_classes)

    if not self.with_mask:
        return bbox_result

    if det_bboxes.shape[0] == 0:
        # No detections: return one empty list per entry.
        num_lists = self.mask_head[-1].num_classes - 1
        return bbox_result, [[] for _ in range(num_lists)]

    aug_masks, aug_img_metas = [], []
    for feats, meta in zip(self.extract_feats(imgs), img_metas):
        view_meta = meta[0]
        view_bboxes = bbox_mapping(det_bboxes[:, :4],
                                   view_meta['img_shape'],
                                   view_meta['scale_factor'],
                                   view_meta['flip'])
        mask_rois = bbox2roi([view_bboxes])
        for stage in range(self.num_stages):
            roi_extractor = self.mask_roi_extractor[stage]
            num_levels = len(roi_extractor.featmap_strides)
            mask_feats = roi_extractor(feats[:num_levels], mask_rois)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)
            mask_pred = self.mask_head[stage](mask_feats)
            # One entry per (view, stage) pair; the meta is repeated so
            # both lists stay aligned for merge_aug_masks.
            aug_masks.append(mask_pred.sigmoid().cpu().numpy())
            aug_img_metas.append(meta)

    merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
                                   self.test_cfg.rcnn)
    ori_shape = img_metas[0][0]['ori_shape']
    segm_result = self.mask_head[-1].get_seg_masks(
        merged_masks,
        det_bboxes,
        det_labels,
        rcnn_test_cfg,
        ori_shape,
        scale_factor=1.0,
        rescale=False)
    return bbox_result, segm_result
def show_result(self, data, result, **kwargs): def show_result(self, data, result, **kwargs):
if self.with_mask: if self.with_mask:
......
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from mmdet.core import (bbox2result, bbox2roi, build_assigner, build_sampler, from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
merge_aug_masks) build_sampler, merge_aug_bboxes, merge_aug_masks,
multiclass_nms)
from .. import builder from .. import builder
from ..registry import DETECTORS from ..registry import DETECTORS
from .cascade_rcnn import CascadeRCNN from .cascade_rcnn import CascadeRCNN
...@@ -431,5 +432,124 @@ class HybridTaskCascade(CascadeRCNN): ...@@ -431,5 +432,124 @@ class HybridTaskCascade(CascadeRCNN):
return results return results
def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
    """Test with augmentations.

    For every augmented view, the proposals returned by
    ``aug_test_rpn`` are mapped into that view and passed through all
    cascade stages via ``_bbox_forward_test``, optionally together with
    the semantic feature of that view. The per-stage classification
    scores are averaged, the per-view boxes are merged with
    ``merge_aug_bboxes`` and filtered with ``multiclass_nms``. When
    ``self.with_mask`` is set, the mask predictions of every stage and
    every view are merged as well.

    NOTE(review): ``proposals`` and ``rescale`` are accepted but never
    read by this method. The merged boxes are in the original image size
    (see the inline comment below) and masks are produced for
    ``ori_shape``. The previous docstring said results fit the scale of
    ``imgs[0]`` when ``rescale`` is False -- confirm which is intended.

    Returns:
        ``bbox_result`` alone, or ``(bbox_result, segm_result)`` when
        ``self.with_mask`` is true.
    """
    if self.with_semantic:
        # Keep only the second output of the semantic head per view.
        semantic_feats = [
            self.semantic_head(feat)[1]
            for feat in self.extract_feats(imgs)
        ]
    else:
        semantic_feats = [None] * len(img_metas)

    # Features are re-extracted for each pass instead of being cached,
    # to save memory.
    proposal_list = self.aug_test_rpn(
        self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
    rcnn_test_cfg = self.test_cfg.rcnn

    aug_bboxes, aug_scores = [], []
    for feats, meta, semantic in zip(
            self.extract_feats(imgs), img_metas, semantic_feats):
        # only one image in the batch, hence index 0
        view_meta = meta[0]
        img_shape = view_meta['img_shape']
        scale_factor = view_meta['scale_factor']
        flip = view_meta['flip']

        view_proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                      scale_factor, flip)
        rois = bbox2roi([view_proposals])

        stage_scores = []
        for stage in range(self.num_stages):
            head = self.bbox_head[stage]
            cls_score, bbox_pred = self._bbox_forward_test(
                stage, feats, rois, semantic_feat=semantic)
            stage_scores.append(cls_score)
            # Every stage but the last refines the RoIs for the next one.
            if stage < self.num_stages - 1:
                rois = head.regress_by_class(
                    rois, cls_score.argmax(dim=1), bbox_pred, view_meta)

        # Average the classification scores over all stages.
        mean_score = sum(stage_scores) / float(len(stage_scores))
        bboxes, scores = self.bbox_head[-1].get_det_bboxes(
            rois,
            mean_score,
            bbox_pred,
            img_shape,
            scale_factor,
            rescale=False,
            cfg=None)
        aug_bboxes.append(bboxes)
        aug_scores.append(scores)

    # after merging, bboxes will be rescaled to the original image size
    merged_bboxes, merged_scores = merge_aug_bboxes(
        aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
    det_bboxes, det_labels = multiclass_nms(
        merged_bboxes, merged_scores, rcnn_test_cfg.score_thr,
        rcnn_test_cfg.nms, rcnn_test_cfg.max_per_img)
    bbox_result = bbox2result(det_bboxes, det_labels,
                              self.bbox_head[-1].num_classes)

    if not self.with_mask:
        return bbox_result

    if det_bboxes.shape[0] == 0:
        # No detections: return one empty list per entry.
        num_lists = self.mask_head[-1].num_classes - 1
        return bbox_result, [[] for _ in range(num_lists)]

    aug_masks, aug_img_metas = [], []
    for feats, meta, semantic in zip(
            self.extract_feats(imgs), img_metas, semantic_feats):
        view_meta = meta[0]
        view_bboxes = bbox_mapping(det_bboxes[:, :4],
                                   view_meta['img_shape'],
                                   view_meta['scale_factor'],
                                   view_meta['flip'])
        mask_rois = bbox2roi([view_bboxes])

        # Only the last mask RoI extractor is used; its features are
        # shared by every mask head below.
        roi_extractor = self.mask_roi_extractor[-1]
        num_levels = len(roi_extractor.featmap_strides)
        mask_feats = roi_extractor(feats[:num_levels], mask_rois)

        if self.with_semantic:
            mask_semantic_feat = self.semantic_roi_extractor([semantic],
                                                             mask_rois)
            target_size = mask_feats.shape[-2:]
            # Pool the semantic RoI features when their spatial size
            # differs from the mask features.
            if mask_semantic_feat.shape[-2:] != target_size:
                mask_semantic_feat = F.adaptive_avg_pool2d(
                    mask_semantic_feat, target_size)
            mask_feats += mask_semantic_feat

        last_feat = None
        for stage in range(self.num_stages):
            head = self.mask_head[stage]
            if self.mask_info_flow:
                # Feed the previous stage's feature into the next head.
                mask_pred, last_feat = head(mask_feats, last_feat)
            else:
                mask_pred = head(mask_feats)
            # One entry per (view, stage) pair; the meta is repeated so
            # both lists stay aligned for merge_aug_masks.
            aug_masks.append(mask_pred.sigmoid().cpu().numpy())
            aug_img_metas.append(meta)

    merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
                                   self.test_cfg.rcnn)
    ori_shape = img_metas[0][0]['ori_shape']
    segm_result = self.mask_head[-1].get_seg_masks(
        merged_masks,
        det_bboxes,
        det_labels,
        rcnn_test_cfg,
        ori_shape,
        scale_factor=1.0,
        rescale=False)
    return bbox_result, segm_result
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment