Skip to content
Snippets Groups Projects
Commit 1f3e2734 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by Kai Chen
Browse files

Support to view AP for each class (#1549)

* also support to view ap

* change string format

* eval class_wise in coco_eval

* reformat

* class_wise API from detectron

* reformat

* change code source

* reformat, use terminaltable
parent 1fe3e7df
No related branches found
No related tags found
No related merge requests found
import itertools
import mmcv
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from terminaltables import AsciiTable
from .recall import eval_recalls
def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000)):
def coco_eval(result_files,
result_types,
coco,
max_dets=(100, 300, 1000),
classwise=False):
for res_type in result_types:
assert res_type in [
'proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'
......@@ -43,6 +50,36 @@ def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000)):
cocoEval.accumulate()
cocoEval.summarize()
if classwise:
# Compute per-category AP
# from https://github.com/facebookresearch/detectron2/blob/03064eb5bafe4a3e5750cc7a16672daf5afe8435/detectron2/evaluation/coco_evaluation.py#L259-L283 # noqa
precisions = cocoEval.eval['precision']
catIds = coco.getCatIds()
# precision has dims (iou, recall, cls, area range, max dets)
assert len(catIds) == precisions.shape[2]
results_per_category = []
for idx, catId in enumerate(catIds):
# area range index 0: all area ranges
# max dets index -1: typically 100 per image
nm = coco.loadCats(catId)[0]
precision = precisions[:, :, idx, 0, -1]
precision = precision[precision > -1]
ap = np.mean(precision) if precision.size else float('nan')
results_per_category.append(
('{}'.format(nm['name']),
'{:0.3f}'.format(float(ap * 100))))
N_COLS = min(6, len(results_per_category) * 2)
results_flatten = list(itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (N_COLS // 2)
results_2d = itertools.zip_longest(
*[results_flatten[i::N_COLS] for i in range(N_COLS)])
table_data = [headers]
table_data += [result for result in results_2d]
table = AsciiTable(table_data)
print(table.table)
def fast_eval_recall(results,
coco,
......
......@@ -20,8 +20,10 @@ def main():
nargs='+',
default=[100, 300, 1000],
help='proposal numbers, only used for recall evaluation')
parser.add_argument(
'--classwise', action='store_true', help='whether eval class wise ap')
args = parser.parse_args()
coco_eval(args.result, args.types, args.ann, args.max_dets)
coco_eval(args.result, args.types, args.ann, args.max_dets, args.classwise)
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment