diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py
index 6cf87ffe40b6d7fc13e6278807e6a6e8a6af7b27..1a074eec18578ae2abb60e3bc36797712a43ad0e 100644
--- a/mmdet/core/evaluation/eval_hooks.py
+++ b/mmdet/core/evaluation/eval_hooks.py
@@ -76,25 +76,9 @@ class DistEvalHook(Hook):
 class DistEvalmAPHook(DistEvalHook):
 
     def evaluate(self, runner, results):
-        gt_bboxes = []
-        gt_labels = []
-        gt_ignore = []
-        for i in range(len(self.dataset)):
-            ann = self.dataset.get_ann_info(i)
-            bboxes = ann['bboxes']
-            labels = ann['labels']
-            if 'bboxes_ignore' in ann:
-                ignore = np.concatenate([
-                    np.zeros(bboxes.shape[0], dtype=np.bool),
-                    np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
-                ])
-                gt_ignore.append(ignore)
-                bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
-                labels = np.concatenate([labels, ann['labels_ignore']])
-            gt_bboxes.append(bboxes)
-            gt_labels.append(labels)
-        if not gt_ignore:
-            gt_ignore = None
+        annotations = [
+            self.dataset.get_ann_info(i) for i in range(len(self.dataset))
+        ]
         # If the dataset is VOC2007, then use 11 points mAP evaluation.
         if hasattr(self.dataset, 'year') and self.dataset.year == 2007:
             ds_name = 'voc07'
@@ -102,13 +86,11 @@ class DistEvalmAPHook(DistEvalHook):
             ds_name = self.dataset.CLASSES
         mean_ap, eval_results = eval_map(
             results,
-            gt_bboxes,
-            gt_labels,
-            gt_ignore=gt_ignore,
+            annotations,
             scale_ranges=None,
             iou_thr=0.5,
             dataset=ds_name,
-            print_summary=True)
+            logger=runner.logger)
         runner.log_buffer.output['mAP'] = mean_ap
         runner.log_buffer.ready = True
diff --git a/mmdet/core/evaluation/mean_ap.py b/mmdet/core/evaluation/mean_ap.py
index 8f708987252b03b9a8cf21db157c6d83366fba49..eb877ee916dd2bbd5c7dbdb594ffa8d3a808ce0b 100644
--- a/mmdet/core/evaluation/mean_ap.py
+++ b/mmdet/core/evaluation/mean_ap.py
@@ -1,3 +1,6 @@
+import logging
+from multiprocessing import Pool
+
 import mmcv
 import numpy as np
 from terminaltables import AsciiTable
@@ -55,21 +58,33 @@ def average_precision(recalls, precisions, mode='area'):
 
 def tpfp_imagenet(det_bboxes,
                   gt_bboxes,
-                  gt_ignore,
-                  default_iou_thr,
+                  gt_bboxes_ignore=None,
+                  default_iou_thr=0.5,
                   area_ranges=None):
     """Check if detected bboxes are true positive or false positive.
 
     Args:
-        det_bbox (ndarray): the detected bbox
-        gt_bboxes (ndarray): ground truth bboxes of this image
-        gt_ignore (ndarray): indicate if gts are ignored for evaluation or not
-        default_iou_thr (float): the iou thresholds for medium and large bboxes
-        area_ranges (list or None): gt bbox area ranges
+        det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5).
+        gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
+        gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
+            of shape (k, 4). Default: None
+        default_iou_thr (float): IoU threshold to be considered as matched for
+            medium and large bboxes (small ones have special rules).
+            Default: 0.5.
+        area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
+            in the format [(min1, max1), (min2, max2), ...]. Default: None.
 
     Returns:
-        tuple: two arrays (tp, fp) whose elements are 0 and 1
+        tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
+            each array is (num_scales, m).
""" + # an indicator of ignored gts + gt_ignore_inds = np.concatenate( + (np.zeros(gt_bboxes.shape[0], dtype=np.bool), + np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool))) + # stack gt_bboxes and gt_bboxes_ignore for convenience + gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore)) + num_dets = det_bboxes.shape[0] num_gts = gt_bboxes.shape[0] if area_ranges is None: @@ -99,7 +114,7 @@ def tpfp_imagenet(det_bboxes, gt_covered = np.zeros(num_gts, dtype=bool) # if no area range is specified, gt_area_ignore is all False if min_area is None: - gt_area_ignore = np.zeros_like(gt_ignore, dtype=bool) + gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) else: gt_areas = gt_w * gt_h gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) @@ -122,7 +137,8 @@ def tpfp_imagenet(det_bboxes, # 4. it matches no gt but is beyond area range, tp = 0, fp = 0 if matched_gt >= 0: gt_covered[matched_gt] = 1 - if not (gt_ignore[matched_gt] or gt_area_ignore[matched_gt]): + if not (gt_ignore_inds[matched_gt] + or gt_area_ignore[matched_gt]): tp[k, i] = 1 elif min_area is None: fp[k, i] = 1 @@ -134,18 +150,34 @@ def tpfp_imagenet(det_bboxes, return tp, fp -def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None): +def tpfp_default(det_bboxes, + gt_bboxes, + gt_bboxes_ignore=None, + iou_thr=0.5, + area_ranges=None): """Check if detected bboxes are true positive or false positive. Args: - det_bbox (ndarray): the detected bbox - gt_bboxes (ndarray): ground truth bboxes of this image - gt_ignore (ndarray): indicate if gts are ignored for evaluation or not - iou_thr (float): the iou thresholds + det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5). + gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4). + gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, + of shape (k, 4). Default: None + iou_thr (float): IoU threshold to be considered as matched. + Default: 0.5. + area_ranges (list[tuple] | None): Range of bbox areas to be evaluated, + in the format [(min1, max1), (min2, max2), ...]. Default: None. Returns: - tuple: (tp, fp), two arrays whose elements are 0 and 1 + tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of + each array is (num_scales, m). 
""" + # an indicator of ignored gts + gt_ignore_inds = np.concatenate( + (np.zeros(gt_bboxes.shape[0], dtype=np.bool), + np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool))) + # stack gt_bboxes and gt_bboxes_ignore for convenience + gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore)) + num_dets = det_bboxes.shape[0] num_gts = gt_bboxes.shape[0] if area_ranges is None: @@ -155,6 +187,7 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None): # a certain scale tp = np.zeros((num_scales, num_dets), dtype=np.float32) fp = np.zeros((num_scales, num_dets), dtype=np.float32) + # if there is no gt bboxes in this image, then all det bboxes # within area range are false positives if gt_bboxes.shape[0] == 0: @@ -166,15 +199,19 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None): for i, (min_area, max_area) in enumerate(area_ranges): fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 return tp, fp + ious = bbox_overlaps(det_bboxes, gt_bboxes) + # for each det, the max iou with all gts ious_max = ious.max(axis=1) + # for each det, which gt overlaps most with it ious_argmax = ious.argmax(axis=1) + # sort all dets in descending order by scores sort_inds = np.argsort(-det_bboxes[:, -1]) for k, (min_area, max_area) in enumerate(area_ranges): gt_covered = np.zeros(num_gts, dtype=bool) # if no area range is specified, gt_area_ignore is all False if min_area is None: - gt_area_ignore = np.zeros_like(gt_ignore, dtype=bool) + gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) else: gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * ( gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1) @@ -182,7 +219,8 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None): for i in sort_inds: if ious_max[i] >= iou_thr: matched_gt = ious_argmax[i] - if not (gt_ignore[matched_gt] or gt_area_ignore[matched_gt]): + if not (gt_ignore_inds[matched_gt] + or gt_area_ignore[matched_gt]): if not gt_covered[matched_gt]: gt_covered[matched_gt] = True tp[k, i] = 1 @@ -199,88 +237,109 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None): return tp, fp -def get_cls_results(det_results, gt_bboxes, gt_labels, gt_ignore, class_id): - """Get det results and gt information of a certain class.""" - cls_dets = [det[class_id] - for det in det_results] # det bboxes of this class - cls_gts = [] # gt bboxes of this class - cls_gt_ignore = [] - for j in range(len(gt_bboxes)): - gt_bbox = gt_bboxes[j] - cls_inds = (gt_labels[j] == class_id + 1) - cls_gt = gt_bbox[cls_inds, :] if gt_bbox.shape[0] > 0 else gt_bbox - cls_gts.append(cls_gt) - if gt_ignore is None: - cls_gt_ignore.append(np.zeros(cls_gt.shape[0], dtype=np.int32)) +def get_cls_results(det_results, annotations, class_id): + """Get det results and gt information of a certain class. + + Args: + det_results (list[list]): Same as `eval_map()`. + annotations (list[dict]): Same as `eval_map()`. 
+
+    Returns:
+        tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes
+    """
+    cls_dets = [img_res[class_id] for img_res in det_results]
+    cls_gts = []
+    cls_gts_ignore = []
+    for ann in annotations:
+        gt_inds = ann['labels'] == (class_id + 1)
+        cls_gts.append(ann['bboxes'][gt_inds, :])
+
+        if ann.get('labels_ignore', None) is not None:
+            ignore_inds = ann['labels_ignore'] == (class_id + 1)
+            cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :])
         else:
-            cls_gt_ignore.append(gt_ignore[j][cls_inds])
-    return cls_dets, cls_gts, cls_gt_ignore
+            cls_gts_ignore.append(np.empty((0, 4), dtype=np.float32))
+
+    return cls_dets, cls_gts, cls_gts_ignore
 
 
 def eval_map(det_results,
-             gt_bboxes,
-             gt_labels,
-             gt_ignore=None,
+             annotations,
              scale_ranges=None,
             iou_thr=0.5,
             dataset=None,
-             print_summary=True):
+             logger='print',
+             nproc=4):
     """Evaluate mAP of a dataset.
 
     Args:
-        det_results (list): a list of list, [[cls1_det, cls2_det, ...], ...]
-        gt_bboxes (list): ground truth bboxes of each image, a list of K*4
-            array.
-        gt_labels (list): ground truth labels of each image, a list of K array
-        gt_ignore (list): gt ignore indicators of each image, a list of K array
-        scale_ranges (list, optional): [(min1, max1), (min2, max2), ...]
-        iou_thr (float): IoU threshold
-        dataset (None or str or list): dataset name or dataset classes, there
-            are minor differences in metrics for different datsets, e.g.
-            "voc07", "imagenet_det", etc.
-        print_summary (bool): whether to print the mAP summary
+        det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
+            The outer list indicates images, and the inner list indicates
+            per-class detected bboxes.
+        annotations (list[dict]): Ground truth annotations where each item of
+            the list indicates an image. Keys of annotations are:
+            - "bboxes": numpy array of shape (n, 4)
+            - "labels": numpy array of shape (n, )
+            - "bboxes_ignore" (optional): numpy array of shape (k, 4)
+            - "labels_ignore" (optional): numpy array of shape (k, )
+        scale_ranges (list[tuple] | None): Range of scales to be evaluated,
+            in the format [(min1, max1), (min2, max2), ...]. A range of
+            (32, 64) means the area range between (32**2, 64**2).
+            Default: None.
+        iou_thr (float): IoU threshold to be considered as matched.
+            Default: 0.5.
+        dataset (list[str] | str | None): Dataset name or dataset classes,
+            there are minor differences in metrics for different datasets,
+            e.g. "voc07", "imagenet_det", etc. Default: None.
+        logger (logging.Logger | 'print' | None): The way to print the mAP
+            summary. If a Logger is specified, then the summary will be logged
+            with `logger.info()`; if set to "print", then it will be simply
+            printed to stdout; if set to None, then no information will be
+            printed. Default: 'print'.
+        nproc (int): Processes used for computing TP and FP.
+            Default: 4.
 
     Returns:
         tuple: (mAP, [dict, dict, ...])
     """
-    assert len(det_results) == len(gt_bboxes) == len(gt_labels)
-    if gt_ignore is not None:
-        assert len(gt_ignore) == len(gt_labels)
-        for i in range(len(gt_ignore)):
-            assert len(gt_labels[i]) == len(gt_ignore[i])
+    assert len(det_results) == len(annotations)
+
+    num_imgs = len(det_results)
+    num_scales = len(scale_ranges) if scale_ranges is not None else 1
+    num_classes = len(det_results[0])  # positive class num
     area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
                    if scale_ranges is not None else None)
-    num_scales = len(scale_ranges) if scale_ranges is not None else 1
+
+    pool = Pool(nproc)
     eval_results = []
-    num_classes = len(det_results[0])  # positive class num
-    gt_labels = [
-        label if label.ndim == 1 else label[:, 0] for label in gt_labels
-    ]
     for i in range(num_classes):
         # get gt and det bboxes of this class
-        cls_dets, cls_gts, cls_gt_ignore = get_cls_results(
-            det_results, gt_bboxes, gt_labels, gt_ignore, i)
-        # calculate tp and fp for each image
-        tpfp_func = (
-            tpfp_imagenet if dataset in ['det', 'vid'] else tpfp_default)
-        tpfp = [
-            tpfp_func(cls_dets[j], cls_gts[j], cls_gt_ignore[j], iou_thr,
-                      area_ranges) for j in range(len(cls_dets))
-        ]
+        cls_dets, cls_gts, cls_gts_ignore = get_cls_results(
+            det_results, annotations, i)
+        # choose proper function according to datasets to compute tp and fp
+        if dataset in ['det', 'vid']:
+            tpfp_func = tpfp_imagenet
+        else:
+            tpfp_func = tpfp_default
+        # compute tp and fp for each image with multiple processes
+        tpfp = pool.starmap(
+            tpfp_func,
+            zip(cls_dets, cls_gts, cls_gts_ignore,
+                [iou_thr for _ in range(num_imgs)],
+                [area_ranges for _ in range(num_imgs)]))
         tp, fp = tuple(zip(*tpfp))
-        # calculate gt number of each scale, gts ignored or beyond scale
-        # are not counted
+        # calculate gt number of each scale
+        # ignored gts or gts beyond the specific scale are not counted
         num_gts = np.zeros(num_scales, dtype=int)
         for j, bbox in enumerate(cls_gts):
            if area_ranges is None:
-                num_gts[0] += np.sum(np.logical_not(cls_gt_ignore[j]))
+                num_gts[0] += bbox.shape[0]
            else:
                gt_areas = (bbox[:, 2] - bbox[:, 0] + 1) * (
                    bbox[:, 3] - bbox[:, 1] + 1)
                for k, (min_area, max_area) in enumerate(area_ranges):
-                    num_gts[k] += np.sum(
-                        np.logical_not(cls_gt_ignore[j])
-                        & (gt_areas >= min_area) & (gt_areas < max_area))
+                    num_gts[k] += np.sum((gt_areas >= min_area)
+                                         & (gt_areas < max_area))
         # sort all det bboxes by score, also sort tp and fp
         cls_dets = np.vstack(cls_dets)
         num_dets = cls_dets.shape[0]
@@ -324,37 +383,60 @@ def eval_map(det_results,
         if cls_result['num_gts'] > 0:
             aps.append(cls_result['ap'])
     mean_ap = np.array(aps).mean().item() if aps else 0.0
-    if print_summary:
-        print_map_summary(mean_ap, eval_results, dataset, area_ranges)
+    if logger is not None:
+        print_map_summary(
+            mean_ap, eval_results, dataset, area_ranges, logger=logger)
     return mean_ap, eval_results
 
 
-def print_map_summary(mean_ap, results, dataset=None, ranges=None):
+def print_map_summary(mean_ap,
+                      results,
+                      dataset=None,
+                      scale_ranges=None,
+                      logger=None):
     """Print mAP and results of each class.
 
+    A table will be printed to show the gts/dets/recall/AP of each class and
+    the mAP.
+
     Args:
-        mean_ap(float): calculated from `eval_map`
-        results(list): calculated from `eval_map`
-        dataset(None or str or list): dataset name or dataset classes.
-        ranges(list or Tuple): ranges of areas
+        mean_ap (float): Calculated from `eval_map()`.
+        results (list[dict]): Calculated from `eval_map()`.
+        dataset (list[str] | str | None): Dataset name or dataset classes.
+        scale_ranges (list[tuple] | None): Range of scales to be evaluated.
+        logger (logging.Logger | 'print' | None): The way to print the mAP
+            summary. If a Logger is specified, then the summary will be logged
+            with `logger.info()`; if set to "print", then it will be simply
+            printed to stdout; if set to None, then no information will be
+            printed. Default: None.
     """
-    num_scales = len(results[0]['ap']) if isinstance(results[0]['ap'],
-                                                     np.ndarray) else 1
-    if ranges is not None:
-        assert len(ranges) == num_scales
+
+    def _print(content):
+        if logger == 'print':
+            print(content)
+        elif isinstance(logger, logging.Logger):
+            logger.info(content)
+
+    if isinstance(results[0]['ap'], np.ndarray):
+        num_scales = len(results[0]['ap'])
+    else:
+        num_scales = 1
+
+    if scale_ranges is not None:
+        assert len(scale_ranges) == num_scales
+
+    assert logger is None or logger == 'print' or isinstance(
+        logger, logging.Logger)
 
     num_classes = len(results)
     recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
-    precisions = np.zeros((num_scales, num_classes), dtype=np.float32)
     aps = np.zeros((num_scales, num_classes), dtype=np.float32)
     num_gts = np.zeros((num_scales, num_classes), dtype=int)
     for i, cls_result in enumerate(results):
         if cls_result['recall'].size > 0:
             recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
-            precisions[:, i] = np.array(
-                cls_result['precision'], ndmin=2)[:, -1]
         aps[:, i] = cls_result['ap']
         num_gts[:, i] = cls_result['num_gts']
@@ -367,19 +449,19 @@ def print_map_summary(mean_ap, results, dataset=None, ranges=None):
     if not isinstance(mean_ap, list):
         mean_ap = [mean_ap]
-    header = ['class', 'gts', 'dets', 'recall', 'precision', 'ap']
+
+    header = ['class', 'gts', 'dets', 'recall', 'ap']
     for i in range(num_scales):
-        if ranges is not None:
-            print("Area range ", ranges[i])
+        if scale_ranges is not None:
+            _print('Scale range {}'.format(scale_ranges[i]))
         table_data = [header]
         for j in range(num_classes):
             row_data = [
                 label_names[j], num_gts[i, j], results[j]['num_dets'],
-                '{:.3f}'.format(recalls[i, j]),
-                '{:.3f}'.format(precisions[i, j]), '{:.3f}'.format(aps[i, j])
+                '{:.3f}'.format(recalls[i, j]), '{:.3f}'.format(aps[i, j])
             ]
             table_data.append(row_data)
-        table_data.append(['mAP', '', '', '', '', '{:.3f}'.format(mean_ap[i])])
+        table_data.append(['mAP', '', '', '', '{:.3f}'.format(mean_ap[i])])
         table = AsciiTable(table_data)
         table.inner_footing_row_border = True
-        print(table.table)
+        _print('\n' + table.table)
diff --git a/tools/test_robustness.py b/tools/test_robustness.py
index fb58deb952f2581cfb5a63254809f2cad1ccb215..2271f4c06daa825f06c3629e531d164de8e7cd03 100644
--- a/tools/test_robustness.py
+++ b/tools/test_robustness.py
@@ -76,41 +76,21 @@ def coco_eval_with_return(result_files,
 def voc_eval_with_return(result_file,
                          dataset,
                          iou_thr=0.5,
-                         print_summary=True,
+                         logger='print',
                          only_ap=True):
     det_results = mmcv.load(result_file)
-    gt_bboxes = []
-    gt_labels = []
-    gt_ignore = []
-    for i in range(len(dataset)):
-        ann = dataset.get_ann_info(i)
-        bboxes = ann['bboxes']
-        labels = ann['labels']
-        if 'bboxes_ignore' in ann:
-            ignore = np.concatenate([
-                np.zeros(bboxes.shape[0], dtype=np.bool),
-                np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
-            ])
-            gt_ignore.append(ignore)
-            bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
-            labels = np.concatenate([labels, ann['labels_ignore']])
-        gt_bboxes.append(bboxes)
-        gt_labels.append(labels)
-    if not gt_ignore:
-        gt_ignore = gt_ignore
+    annotations = [dataset.get_ann_info(i) for i in range(len(dataset))]
     if hasattr(dataset, 'year') and dataset.year == 2007:
         dataset_name = 'voc07'
     else:
         dataset_name = dataset.CLASSES
     mean_ap, eval_results = eval_map(
         det_results,
-        gt_bboxes,
-        gt_labels,
-        gt_ignore=gt_ignore,
+        annotations,
         scale_ranges=None,
         iou_thr=iou_thr,
         dataset=dataset_name,
-        print_summary=print_summary)
+        logger=logger)
 
     if only_ap:
         eval_results = [{
@@ -411,10 +391,11 @@ def main():
                 if eval_type == 'bbox':
                     test_dataset = mmcv.runner.obj_from_dict(
                         cfg.data.test, datasets)
+                    logger = 'print' if args.summaries else None
                     mean_ap, eval_results = \
                         voc_eval_with_return(
                             args.out, test_dataset,
-                            args.iou_thr, args.summaries)
+                            args.iou_thr, logger)
                     aggregated_results[corruption][
                         corruption_severity] = eval_results
                 else:
diff --git a/tools/voc_eval.py b/tools/voc_eval.py
index a86b13e856ff39418027dbd39f3287966ae0e171..be0bde6db991e60576cc233456abb8692133c8f9 100644
--- a/tools/voc_eval.py
+++ b/tools/voc_eval.py
@@ -1,46 +1,26 @@
 from argparse import ArgumentParser
 
 import mmcv
-import numpy as np
 
 from mmdet import datasets
 from mmdet.core import eval_map
 
 
-def voc_eval(result_file, dataset, iou_thr=0.5):
+def voc_eval(result_file, dataset, iou_thr=0.5, nproc=4):
     det_results = mmcv.load(result_file)
-    gt_bboxes = []
-    gt_labels = []
-    gt_ignore = []
-    for i in range(len(dataset)):
-        ann = dataset.get_ann_info(i)
-        bboxes = ann['bboxes']
-        labels = ann['labels']
-        if 'bboxes_ignore' in ann:
-            ignore = np.concatenate([
-                np.zeros(bboxes.shape[0], dtype=np.bool),
-                np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
-            ])
-            gt_ignore.append(ignore)
-            bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
-            labels = np.concatenate([labels, ann['labels_ignore']])
-        gt_bboxes.append(bboxes)
-        gt_labels.append(labels)
-    if not gt_ignore:
-        gt_ignore = None
+    annotations = [dataset.get_ann_info(i) for i in range(len(dataset))]
     if hasattr(dataset, 'year') and dataset.year == 2007:
         dataset_name = 'voc07'
     else:
         dataset_name = dataset.CLASSES
     eval_map(
         det_results,
-        gt_bboxes,
-        gt_labels,
-        gt_ignore=gt_ignore,
+        annotations,
         scale_ranges=None,
         iou_thr=iou_thr,
         dataset=dataset_name,
-        print_summary=True)
+        logger='print',
+        nproc=nproc)
 
 
 def main():
@@ -52,10 +32,15 @@ def main():
         type=float,
         default=0.5,
         help='IoU threshold for evaluation')
+    parser.add_argument(
+        '--nproc',
+        type=int,
+        default=4,
+        help='Processes to be used for computing mAP')
     args = parser.parse_args()
     cfg = mmcv.Config.fromfile(args.config)
     test_dataset = mmcv.runner.obj_from_dict(cfg.data.test, datasets)
-    voc_eval(args.result, test_dataset, args.iou_thr)
+    voc_eval(args.result, test_dataset, args.iou_thr, args.nproc)
 
 
 if __name__ == '__main__':
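
Note (illustrative, not part of the patch): a minimal usage sketch of the refactored eval_map() API introduced above. The toy boxes, scores, and single-class setup are made-up assumptions; only the argument format comes from the diff. Labels are 1-based in this version (compared against class_id + 1), and since TP/FP are now computed with a multiprocessing Pool, the call should run under a main guard.

    import numpy as np

    from mmdet.core import eval_map

    if __name__ == '__main__':
        # one image, one class; each per-class entry is an (m, 5) array of
        # [x1, y1, x2, y2, score]
        det_results = [[
            np.array([[10, 10, 50, 60, 0.9], [12, 30, 80, 95, 0.3]],
                     dtype=np.float32)
        ]]
        # annotations use the new dict format consumed by eval_map()
        annotations = [
            dict(
                bboxes=np.array([[11, 11, 52, 61]], dtype=np.float32),
                labels=np.array([1], dtype=np.int64),
                bboxes_ignore=np.zeros((0, 4), dtype=np.float32),
                labels_ignore=np.zeros((0, ), dtype=np.int64))
        ]
        mean_ap, eval_results = eval_map(
            det_results, annotations, iou_thr=0.5, logger='print', nproc=1)
        print(mean_ap)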