diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py
index 6cf87ffe40b6d7fc13e6278807e6a6e8a6af7b27..1a074eec18578ae2abb60e3bc36797712a43ad0e 100644
--- a/mmdet/core/evaluation/eval_hooks.py
+++ b/mmdet/core/evaluation/eval_hooks.py
@@ -76,25 +76,9 @@ class DistEvalHook(Hook):
 class DistEvalmAPHook(DistEvalHook):
 
     def evaluate(self, runner, results):
-        gt_bboxes = []
-        gt_labels = []
-        gt_ignore = []
-        for i in range(len(self.dataset)):
-            ann = self.dataset.get_ann_info(i)
-            bboxes = ann['bboxes']
-            labels = ann['labels']
-            if 'bboxes_ignore' in ann:
-                ignore = np.concatenate([
-                    np.zeros(bboxes.shape[0], dtype=np.bool),
-                    np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
-                ])
-                gt_ignore.append(ignore)
-                bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
-                labels = np.concatenate([labels, ann['labels_ignore']])
-            gt_bboxes.append(bboxes)
-            gt_labels.append(labels)
-        if not gt_ignore:
-            gt_ignore = None
+        annotations = [
+            self.dataset.get_ann_info(i) for i in range(len(self.dataset))
+        ]
         # If the dataset is VOC2007, then use 11 points mAP evaluation.
         if hasattr(self.dataset, 'year') and self.dataset.year == 2007:
             ds_name = 'voc07'
@@ -102,13 +86,11 @@ class DistEvalmAPHook(DistEvalHook):
             ds_name = self.dataset.CLASSES
         mean_ap, eval_results = eval_map(
             results,
-            gt_bboxes,
-            gt_labels,
-            gt_ignore=gt_ignore,
+            annotations,
             scale_ranges=None,
             iou_thr=0.5,
             dataset=ds_name,
-            print_summary=True)
+            logger=runner.logger)
         runner.log_buffer.output['mAP'] = mean_ap
         runner.log_buffer.ready = True
 
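Note: the hook now assumes only that `dataset.get_ann_info(i)` returns the annotation dict format documented in `eval_map()` below. A minimal sketch of one such dict (values invented for illustration; the ignore fields are optional):

    import numpy as np

    ann = dict(
        bboxes=np.array([[10, 10, 50, 60]], dtype=np.float32),       # (n, 4)
        labels=np.array([1], dtype=np.int64),                        # (n, )
        bboxes_ignore=np.array([[0, 0, 15, 15]], dtype=np.float32),  # (k, 4), optional
        labels_ignore=np.array([2], dtype=np.int64),                 # (k, ), optional
    )
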
diff --git a/mmdet/core/evaluation/mean_ap.py b/mmdet/core/evaluation/mean_ap.py
index 8f708987252b03b9a8cf21db157c6d83366fba49..eb877ee916dd2bbd5c7dbdb594ffa8d3a808ce0b 100644
--- a/mmdet/core/evaluation/mean_ap.py
+++ b/mmdet/core/evaluation/mean_ap.py
@@ -1,3 +1,6 @@
+import logging
+from multiprocessing import Pool
+
 import mmcv
 import numpy as np
 from terminaltables import AsciiTable
@@ -55,21 +58,33 @@ def average_precision(recalls, precisions, mode='area'):
 
 def tpfp_imagenet(det_bboxes,
                   gt_bboxes,
-                  gt_ignore,
-                  default_iou_thr,
+                  gt_bboxes_ignore=None,
+                  default_iou_thr=0.5,
                   area_ranges=None):
     """Check if detected bboxes are true positive or false positive.
 
     Args:
-        det_bbox (ndarray): the detected bbox
-        gt_bboxes (ndarray): ground truth bboxes of this image
-        gt_ignore (ndarray): indicate if gts are ignored for evaluation or not
-        default_iou_thr (float): the iou thresholds for medium and large bboxes
-        area_ranges (list or None): gt bbox area ranges
+        det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5).
+        gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
+        gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
+            of shape (k, 4). Default: None.
+        default_iou_thr (float): IoU threshold to be considered as matched for
+            medium and large bboxes (small ones have special rules).
+            Default: 0.5.
+        area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
+            in the format [(min1, max1), (min2, max2), ...]. Default: None.
 
     Returns:
-        tuple: two arrays (tp, fp) whose elements are 0 and 1
+        tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
+            each array is (num_scales, m).
     """
+    # treat the default of None as an empty (0, 4) array of ignored gts
+    if gt_bboxes_ignore is None:
+        gt_bboxes_ignore = np.empty((0, 4), dtype=np.float32)
+    # an indicator of ignored gts
+    gt_ignore_inds = np.concatenate(
+        (np.zeros(gt_bboxes.shape[0], dtype=bool),
+         np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
+    # stack gt_bboxes and gt_bboxes_ignore for convenience
+    gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
+
     num_dets = det_bboxes.shape[0]
     num_gts = gt_bboxes.shape[0]
     if area_ranges is None:
@@ -99,7 +114,7 @@ def tpfp_imagenet(det_bboxes,
         gt_covered = np.zeros(num_gts, dtype=bool)
         # if no area range is specified, gt_area_ignore is all False
         if min_area is None:
-            gt_area_ignore = np.zeros_like(gt_ignore, dtype=bool)
+            gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
         else:
             gt_areas = gt_w * gt_h
             gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
@@ -122,7 +137,8 @@ def tpfp_imagenet(det_bboxes,
             # 4. it matches no gt but is beyond area range, tp = 0, fp = 0
             if matched_gt >= 0:
                 gt_covered[matched_gt] = 1
-                if not (gt_ignore[matched_gt] or gt_area_ignore[matched_gt]):
+                if not (gt_ignore_inds[matched_gt]
+                        or gt_area_ignore[matched_gt]):
                     tp[k, i] = 1
             elif min_area is None:
                 fp[k, i] = 1
@@ -134,18 +150,34 @@ def tpfp_imagenet(det_bboxes,
     return tp, fp
 
 
-def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
+def tpfp_default(det_bboxes,
+                 gt_bboxes,
+                 gt_bboxes_ignore=None,
+                 iou_thr=0.5,
+                 area_ranges=None):
     """Check if detected bboxes are true positive or false positive.
 
     Args:
-        det_bbox (ndarray): the detected bbox
-        gt_bboxes (ndarray): ground truth bboxes of this image
-        gt_ignore (ndarray): indicate if gts are ignored for evaluation or not
-        iou_thr (float): the iou thresholds
+        det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5).
+        gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
+        gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
+            of shape (k, 4). Default: None.
+        iou_thr (float): IoU threshold to be considered as matched.
+            Default: 0.5.
+        area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
+            in the format [(min1, max1), (min2, max2), ...]. Default: None.
 
     Returns:
-        tuple: (tp, fp), two arrays whose elements are 0 and 1
+        tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
+            each array is (num_scales, m).
     """
+    # treat the default of None as an empty (0, 4) array of ignored gts
+    if gt_bboxes_ignore is None:
+        gt_bboxes_ignore = np.empty((0, 4), dtype=np.float32)
+    # an indicator of ignored gts
+    gt_ignore_inds = np.concatenate(
+        (np.zeros(gt_bboxes.shape[0], dtype=bool),
+         np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
+    # stack gt_bboxes and gt_bboxes_ignore for convenience
+    gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
+
     num_dets = det_bboxes.shape[0]
     num_gts = gt_bboxes.shape[0]
     if area_ranges is None:
@@ -155,6 +187,7 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
     # a certain scale
     tp = np.zeros((num_scales, num_dets), dtype=np.float32)
     fp = np.zeros((num_scales, num_dets), dtype=np.float32)
+
     # if there is no gt bboxes in this image, then all det bboxes
     # within area range are false positives
     if gt_bboxes.shape[0] == 0:
@@ -166,15 +199,19 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
             for i, (min_area, max_area) in enumerate(area_ranges):
                 fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
         return tp, fp
+
     ious = bbox_overlaps(det_bboxes, gt_bboxes)
+    # for each det, the max iou with all gts
     ious_max = ious.max(axis=1)
+    # for each det, which gt overlaps most with it
     ious_argmax = ious.argmax(axis=1)
+    # sort all dets in descending order by scores
     sort_inds = np.argsort(-det_bboxes[:, -1])
     for k, (min_area, max_area) in enumerate(area_ranges):
         gt_covered = np.zeros(num_gts, dtype=bool)
         # if no area range is specified, gt_area_ignore is all False
         if min_area is None:
-            gt_area_ignore = np.zeros_like(gt_ignore, dtype=bool)
+            gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
         else:
             gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * (
                 gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1)
@@ -182,7 +219,8 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
         for i in sort_inds:
             if ious_max[i] >= iou_thr:
                 matched_gt = ious_argmax[i]
-                if not (gt_ignore[matched_gt] or gt_area_ignore[matched_gt]):
+                if not (gt_ignore_inds[matched_gt]
+                        or gt_area_ignore[matched_gt]):
                     if not gt_covered[matched_gt]:
                         gt_covered[matched_gt] = True
                         tp[k, i] = 1
@@ -199,88 +237,109 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
     return tp, fp
 
 
-def get_cls_results(det_results, gt_bboxes, gt_labels, gt_ignore, class_id):
-    """Get det results and gt information of a certain class."""
-    cls_dets = [det[class_id]
-                for det in det_results]  # det bboxes of this class
-    cls_gts = []  # gt bboxes of this class
-    cls_gt_ignore = []
-    for j in range(len(gt_bboxes)):
-        gt_bbox = gt_bboxes[j]
-        cls_inds = (gt_labels[j] == class_id + 1)
-        cls_gt = gt_bbox[cls_inds, :] if gt_bbox.shape[0] > 0 else gt_bbox
-        cls_gts.append(cls_gt)
-        if gt_ignore is None:
-            cls_gt_ignore.append(np.zeros(cls_gt.shape[0], dtype=np.int32))
+def get_cls_results(det_results, annotations, class_id):
+    """Get det results and gt information of a certain class.
+
+    Args:
+        det_results (list[list]): Same as `eval_map()`.
+        annotations (list[dict]): Same as `eval_map()`.
+        class_id (int): ID of a specific class.
+
+    Returns:
+        tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes
+    """
+    cls_dets = [img_res[class_id] for img_res in det_results]
+    cls_gts = []
+    cls_gts_ignore = []
+    for ann in annotations:
+        gt_inds = ann['labels'] == (class_id + 1)
+        cls_gts.append(ann['bboxes'][gt_inds, :])
+
+        if ann.get('labels_ignore', None) is not None:
+            ignore_inds = ann['labels_ignore'] == (class_id + 1)
+            cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :])
         else:
-            cls_gt_ignore.append(gt_ignore[j][cls_inds])
-    return cls_dets, cls_gts, cls_gt_ignore
+            cls_gts_ignore.append(np.empty((0, 4), dtype=np.float32))
+
+    return cls_dets, cls_gts, cls_gts_ignore
 
 
 def eval_map(det_results,
-             gt_bboxes,
-             gt_labels,
-             gt_ignore=None,
+             annotations,
              scale_ranges=None,
              iou_thr=0.5,
              dataset=None,
-             print_summary=True):
+             logger='print',
+             nproc=4):
     """Evaluate mAP of a dataset.
 
     Args:
-        det_results (list): a list of list, [[cls1_det, cls2_det, ...], ...]
-        gt_bboxes (list): ground truth bboxes of each image, a list of K*4
-            array.
-        gt_labels (list): ground truth labels of each image, a list of K array
-        gt_ignore (list): gt ignore indicators of each image, a list of K array
-        scale_ranges (list, optional): [(min1, max1), (min2, max2), ...]
-        iou_thr (float): IoU threshold
-        dataset (None or str or list): dataset name or dataset classes, there
-            are minor differences in metrics for different datsets, e.g.
-            "voc07", "imagenet_det", etc.
-        print_summary (bool): whether to print the mAP summary
+        det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
+            The outer list indicates images, and the inner list indicates
+            per-class detected bboxes.
+        annotations (list[dict]): Ground truth annotations where each item of
+            the list indicates an image. Keys of annotations are:
+                - "bboxes": numpy array of shape (n, 4)
+                - "labels": numpy array of shape (n, )
+                - "bboxes_ignore" (optional): numpy array of shape (k, 4)
+                - "labels_ignore" (optional): numpy array of shape (k, )
+        scale_ranges (list[tuple] | None): Range of scales to be evaluated,
+            in the format [(min1, max1), (min2, max2), ...]. A range of
+            (32, 64) means the area range between (32**2, 64**2).
+            Default: None.
+        iou_thr (float): IoU threshold to be considered as matched.
+            Default: 0.5.
+        dataset (list[str] | str | None): Dataset name or dataset classes,
+            there are minor differences in metrics for different datasets, e.g.
+            "voc07", "imagenet_det", etc. Default: None.
+        logger (logging.Logger | 'print' | None): The way to print the mAP
+            summary. If a Logger is specified, then the summary will be logged
+            with `logger.info()`; if set to "print", then it will be simply
+            printed to stdout; if set to None, then no information will be
+            printed. Default: 'print'.
+        nproc (int): Processes used for computing TP and FP.
+            Default: 4.
 
     Returns:
         tuple: (mAP, [dict, dict, ...])
     """
-    assert len(det_results) == len(gt_bboxes) == len(gt_labels)
-    if gt_ignore is not None:
-        assert len(gt_ignore) == len(gt_labels)
-        for i in range(len(gt_ignore)):
-            assert len(gt_labels[i]) == len(gt_ignore[i])
+    assert len(det_results) == len(annotations)
+
+    num_imgs = len(det_results)
+    num_scales = len(scale_ranges) if scale_ranges is not None else 1
+    num_classes = len(det_results[0])  # positive class num
     area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
                    if scale_ranges is not None else None)
-    num_scales = len(scale_ranges) if scale_ranges is not None else 1
+
+    pool = Pool(nproc)
     eval_results = []
-    num_classes = len(det_results[0])  # positive class num
-    gt_labels = [
-        label if label.ndim == 1 else label[:, 0] for label in gt_labels
-    ]
     for i in range(num_classes):
         # get gt and det bboxes of this class
-        cls_dets, cls_gts, cls_gt_ignore = get_cls_results(
-            det_results, gt_bboxes, gt_labels, gt_ignore, i)
-        # calculate tp and fp for each image
-        tpfp_func = (
-            tpfp_imagenet if dataset in ['det', 'vid'] else tpfp_default)
-        tpfp = [
-            tpfp_func(cls_dets[j], cls_gts[j], cls_gt_ignore[j], iou_thr,
-                      area_ranges) for j in range(len(cls_dets))
-        ]
+        cls_dets, cls_gts, cls_gts_ignore = get_cls_results(
+            det_results, annotations, i)
+        # choose proper function according to datasets to compute tp and fp
+        if dataset in ['det', 'vid']:
+            tpfp_func = tpfp_imagenet
+        else:
+            tpfp_func = tpfp_default
+        # compute tp and fp for each image with multiple processes
+        tpfp = pool.starmap(
+            tpfp_func,
+            zip(cls_dets, cls_gts, cls_gts_ignore,
+                [iou_thr for _ in range(num_imgs)],
+                [area_ranges for _ in range(num_imgs)]))
         tp, fp = tuple(zip(*tpfp))
-        # calculate gt number of each scale, gts ignored or beyond scale
-        # are not counted
+        # calculate gt number of each scale
+        # ignored gts or gts beyond the specific scale are not counted
         num_gts = np.zeros(num_scales, dtype=int)
         for j, bbox in enumerate(cls_gts):
             if area_ranges is None:
-                num_gts[0] += np.sum(np.logical_not(cls_gt_ignore[j]))
+                num_gts[0] += bbox.shape[0]
             else:
                 gt_areas = (bbox[:, 2] - bbox[:, 0] + 1) * (
                     bbox[:, 3] - bbox[:, 1] + 1)
                 for k, (min_area, max_area) in enumerate(area_ranges):
-                    num_gts[k] += np.sum(
-                        np.logical_not(cls_gt_ignore[j])
-                        & (gt_areas >= min_area) & (gt_areas < max_area))
+                    num_gts[k] += np.sum((gt_areas >= min_area)
+                                         & (gt_areas < max_area))
         # sort all det bboxes by score, also sort tp and fp
         cls_dets = np.vstack(cls_dets)
         num_dets = cls_dets.shape[0]
@@ -324,37 +383,60 @@ def eval_map(det_results,
             if cls_result['num_gts'] > 0:
                 aps.append(cls_result['ap'])
         mean_ap = np.array(aps).mean().item() if aps else 0.0
-    if print_summary:
-        print_map_summary(mean_ap, eval_results, dataset, area_ranges)
+    pool.close()
+    if logger is not None:
+        print_map_summary(
+            mean_ap, eval_results, dataset, area_ranges, logger=logger)
 
     return mean_ap, eval_results
 
 
-def print_map_summary(mean_ap, results, dataset=None, ranges=None):
+def print_map_summary(mean_ap,
+                      results,
+                      dataset=None,
+                      scale_ranges=None,
+                      logger=None):
     """Print mAP and results of each class.
 
+    A table will be printed to show the gts/dets/recall/AP of each class and
+    the mAP.
+
     Args:
-        mean_ap(float): calculated from `eval_map`
-        results(list): calculated from `eval_map`
-        dataset(None or str or list): dataset name or dataset classes.
-        ranges(list or Tuple): ranges of areas
+        mean_ap (float): Calculated from `eval_map()`.
+        results (list[dict]): Calculated from `eval_map()`.
+        dataset (list[str] | str | None): Dataset name or dataset classes.
+        scale_ranges (list[tuple] | None): Range of scales to be evaluated.
+        logger (logging.Logger | 'print' | None): The way to print the mAP
+            summary. If a Logger is specified, then the summary will be logged
+            with `logger.info()`; if set to "print", then it will be simply
+            printed to stdout; if set to None, then no information will be
+            printed. Default: None.
     """
-    num_scales = len(results[0]['ap']) if isinstance(results[0]['ap'],
-                                                     np.ndarray) else 1
-    if ranges is not None:
-        assert len(ranges) == num_scales
+
+    def _print(content):
+        if logger == 'print':
+            print(content)
+        elif isinstance(logger, logging.Logger):
+            logger.info(content)
+
+    if isinstance(results[0]['ap'], np.ndarray):
+        num_scales = len(results[0]['ap'])
+    else:
+        num_scales = 1
+
+    if scale_ranges is not None:
+        assert len(scale_ranges) == num_scales
+
+    assert logger is None or logger == 'print' or isinstance(
+        logger, logging.Logger)
 
     num_classes = len(results)
 
     recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
-    precisions = np.zeros((num_scales, num_classes), dtype=np.float32)
     aps = np.zeros((num_scales, num_classes), dtype=np.float32)
     num_gts = np.zeros((num_scales, num_classes), dtype=int)
     for i, cls_result in enumerate(results):
         if cls_result['recall'].size > 0:
             recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
-            precisions[:, i] = np.array(
-                cls_result['precision'], ndmin=2)[:, -1]
         aps[:, i] = cls_result['ap']
         num_gts[:, i] = cls_result['num_gts']
 
@@ -367,19 +449,19 @@ def print_map_summary(mean_ap, results, dataset=None, ranges=None):
 
     if not isinstance(mean_ap, list):
         mean_ap = [mean_ap]
-    header = ['class', 'gts', 'dets', 'recall', 'precision', 'ap']
+
+    header = ['class', 'gts', 'dets', 'recall', 'ap']
     for i in range(num_scales):
-        if ranges is not None:
-            print("Area range ", ranges[i])
+        if scale_ranges is not None:
+            _print('Scale range {}'.format(scale_ranges[i]))
         table_data = [header]
         for j in range(num_classes):
             row_data = [
                 label_names[j], num_gts[i, j], results[j]['num_dets'],
-                '{:.3f}'.format(recalls[i, j]),
-                '{:.3f}'.format(precisions[i, j]), '{:.3f}'.format(aps[i, j])
+                '{:.3f}'.format(recalls[i, j]), '{:.3f}'.format(aps[i, j])
             ]
             table_data.append(row_data)
-        table_data.append(['mAP', '', '', '', '', '{:.3f}'.format(mean_ap[i])])
+        table_data.append(['mAP', '', '', '', '{:.3f}'.format(mean_ap[i])])
         table = AsciiTable(table_data)
         table.inner_footing_row_border = True
-        print(table.table)
+        _print('\n' + table.table)
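
As a sanity check of the new keyword-argument interface of `tpfp_default`, here is a minimal sketch with toy boxes (invented for illustration, not taken from the test suite); it relies only on the function as defined above:

    import numpy as np
    from mmdet.core.evaluation.mean_ap import tpfp_default

    dets = np.array([[10, 10, 50, 50, 0.9],       # IoU ~0.81 with the gt -> tp
                     [100, 100, 150, 150, 0.8]],  # matches nothing -> fp
                    dtype=np.float32)
    gts = np.array([[12, 12, 48, 48]], dtype=np.float32)
    tp, fp = tpfp_default(dets, gts, np.empty((0, 4), dtype=np.float32), iou_thr=0.5)
    # both arrays have shape (num_scales, m) == (1, 2):
    # tp == [[1., 0.]], fp == [[0., 1.]]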
diff --git a/tools/test_robustness.py b/tools/test_robustness.py
index fb58deb952f2581cfb5a63254809f2cad1ccb215..2271f4c06daa825f06c3629e531d164de8e7cd03 100644
--- a/tools/test_robustness.py
+++ b/tools/test_robustness.py
@@ -76,41 +76,21 @@ def coco_eval_with_return(result_files,
 def voc_eval_with_return(result_file,
                          dataset,
                          iou_thr=0.5,
-                         print_summary=True,
+                         logger='print',
                          only_ap=True):
     det_results = mmcv.load(result_file)
-    gt_bboxes = []
-    gt_labels = []
-    gt_ignore = []
-    for i in range(len(dataset)):
-        ann = dataset.get_ann_info(i)
-        bboxes = ann['bboxes']
-        labels = ann['labels']
-        if 'bboxes_ignore' in ann:
-            ignore = np.concatenate([
-                np.zeros(bboxes.shape[0], dtype=np.bool),
-                np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
-            ])
-            gt_ignore.append(ignore)
-            bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
-            labels = np.concatenate([labels, ann['labels_ignore']])
-        gt_bboxes.append(bboxes)
-        gt_labels.append(labels)
-    if not gt_ignore:
-        gt_ignore = None
+    annotations = [dataset.get_ann_info(i) for i in range(len(dataset))]
     if hasattr(dataset, 'year') and dataset.year == 2007:
         dataset_name = 'voc07'
     else:
         dataset_name = dataset.CLASSES
     mean_ap, eval_results = eval_map(
         det_results,
-        gt_bboxes,
-        gt_labels,
-        gt_ignore=gt_ignore,
+        annotations,
         scale_ranges=None,
         iou_thr=iou_thr,
         dataset=dataset_name,
-        print_summary=print_summary)
+        logger=logger)
 
     if only_ap:
         eval_results = [{
@@ -411,10 +391,11 @@ def main():
                             if eval_type == 'bbox':
                                 test_dataset = mmcv.runner.obj_from_dict(
                                     cfg.data.test, datasets)
+                                logger = 'print' if args.summaries else None
                                 mean_ap, eval_results = \
                                     voc_eval_with_return(
                                         args.out, test_dataset,
-                                        args.iou_thr, args.summaries)
+                                        args.iou_thr, logger)
                                 aggregated_results[corruption][
                                     corruption_severity] = eval_results
                             else:
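The old `print_summary` flag maps onto the new `logger` argument as `True -> 'print'` and `False -> None`; a `logging.Logger` instance is the third accepted value. A small sketch of the three modes (the logger name `mmdet.eval` is just an example):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger('mmdet.eval')

    # eval_map(..., logger=None)     -> summary suppressed
    # eval_map(..., logger='print')  -> summary printed to stdout
    # eval_map(..., logger=logger)   -> summary emitted via logger.info()
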
diff --git a/tools/voc_eval.py b/tools/voc_eval.py
index a86b13e856ff39418027dbd39f3287966ae0e171..be0bde6db991e60576cc233456abb8692133c8f9 100644
--- a/tools/voc_eval.py
+++ b/tools/voc_eval.py
@@ -1,46 +1,26 @@
 from argparse import ArgumentParser
 
 import mmcv
-import numpy as np
 
 from mmdet import datasets
 from mmdet.core import eval_map
 
 
-def voc_eval(result_file, dataset, iou_thr=0.5):
+def voc_eval(result_file, dataset, iou_thr=0.5, nproc=4):
     det_results = mmcv.load(result_file)
-    gt_bboxes = []
-    gt_labels = []
-    gt_ignore = []
-    for i in range(len(dataset)):
-        ann = dataset.get_ann_info(i)
-        bboxes = ann['bboxes']
-        labels = ann['labels']
-        if 'bboxes_ignore' in ann:
-            ignore = np.concatenate([
-                np.zeros(bboxes.shape[0], dtype=np.bool),
-                np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
-            ])
-            gt_ignore.append(ignore)
-            bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
-            labels = np.concatenate([labels, ann['labels_ignore']])
-        gt_bboxes.append(bboxes)
-        gt_labels.append(labels)
-    if not gt_ignore:
-        gt_ignore = None
+    annotations = [dataset.get_ann_info(i) for i in range(len(dataset))]
     if hasattr(dataset, 'year') and dataset.year == 2007:
         dataset_name = 'voc07'
     else:
         dataset_name = dataset.CLASSES
     eval_map(
         det_results,
-        gt_bboxes,
-        gt_labels,
-        gt_ignore=gt_ignore,
+        annotations,
         scale_ranges=None,
         iou_thr=iou_thr,
         dataset=dataset_name,
-        print_summary=True)
+        logger='print',
+        nproc=nproc)
 
 
 def main():
@@ -52,10 +32,15 @@ def main():
         type=float,
         default=0.5,
         help='IoU threshold for evaluation')
+    parser.add_argument(
+        '--nproc',
+        type=int,
+        default=4,
+        help='Processes to be used for computing mAP')
     args = parser.parse_args()
     cfg = mmcv.Config.fromfile(args.config)
     test_dataset = mmcv.runner.obj_from_dict(cfg.data.test, datasets)
-    voc_eval(args.result, test_dataset, args.iou_thr)
+    voc_eval(args.result, test_dataset, args.iou_thr, args.nproc)
 
 
 if __name__ == '__main__':
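
For completeness, a self-contained toy call of the refactored `eval_map` (one image, two classes, fabricated boxes); since `eval_map` now spawns worker processes, run it under an `if __name__ == '__main__':` guard:

    import numpy as np
    from mmdet.core import eval_map

    if __name__ == '__main__':
        det_results = [[
            np.array([[10, 10, 50, 50, 0.9]], dtype=np.float32),  # class 1 dets, (m, 5)
            np.empty((0, 5), dtype=np.float32),                   # class 2: no dets
        ]]
        annotations = [dict(
            bboxes=np.array([[12, 12, 48, 48]], dtype=np.float32),
            labels=np.array([1], dtype=np.int64),
        )]
        mean_ap, eval_results = eval_map(
            det_results, annotations, iou_thr=0.5, logger='print', nproc=2)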