Forked from
nikhil_rayaprolu / food-round2
85 commits behind the upstream repository.
-
Jon Crall authored
* Fix simple_test for base and fast_rcnn detector * Add test of detector forward methods * Revert forward_test change and add docs indicating expected inputs for forward * fix linter
Jon Crall authored: * Fix simple_test for base and fast_rcnn detector * Add test of detector forward methods * Revert forward_test change and add docs indicating expected inputs for forward * fix linter
fast_rcnn.py 2.27 KiB
from ..registry import DETECTORS
from .two_stage import TwoStageDetector
@DETECTORS.register_module
class FastRCNN(TwoStageDetector):
    """Two-stage detector whose test path takes externally supplied proposals.

    The constructor forwards its arguments to ``TwoStageDetector`` without
    passing any RPN-related argument, and ``forward_test`` is overridden so
    that ``proposals`` are handed to ``simple_test`` / ``aug_test``.
    """

    def __init__(self,
                 backbone,
                 bbox_roi_extractor,
                 bbox_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 shared_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 pretrained=None):
        """Forward every argument, by keyword, to ``TwoStageDetector``.

        All parameters are passed through unchanged; their exact meaning is
        defined by the base class.
        """
        super(FastRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            shared_head=shared_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            mask_roi_extractor=mask_roi_extractor,
            mask_head=mask_head,
            pretrained=pretrained)

    def forward_test(self, imgs, img_metas, proposals, **kwargs):
        """Validate test-time inputs and dispatch to the right test routine.

        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            proposals (List[List[Tensor | None]]): predefined proposals for
                each test-time augmentation and each item.
            **kwargs: forwarded unchanged to ``simple_test`` / ``aug_test``.

        Returns:
            The result of ``self.simple_test`` when there is exactly one
            augmentation, otherwise the result of ``self.aug_test``.

        Raises:
            TypeError: if ``imgs`` or ``img_metas`` is not a list.
            ValueError: if ``imgs`` and ``img_metas`` differ in length.
            AssertionError: if the first augmentation does not contain
                exactly one image.
        """
        for var, name in ((imgs, 'imgs'), (img_metas, 'img_metas')):
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))

        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        # Raised explicitly (instead of using an ``assert`` statement) so the
        # restriction is still enforced under ``python -O``; the exception
        # type is kept as AssertionError so existing callers see no change.
        if imgs_per_gpu != 1:
            raise AssertionError(
                'forward_test only supports imgs_per_gpu == 1, '
                'but got {}'.format(imgs_per_gpu))

        if num_augs == 1:
            # Single augmentation: unwrap the outer lists before delegating.
            return self.simple_test(imgs[0], img_metas[0], proposals[0],
                                    **kwargs)
        return self.aug_test(imgs, img_metas, proposals, **kwargs)