diff --git a/mmdet/core/eval/__init__.py b/mmdet/core/eval/__init__.py
index c46d860d4b1c027f47ab75090bd13d0332d88efb..b5df6595a0e7e1421e7e0bab322f1ad9d66041c9 100644
--- a/mmdet/core/eval/__init__.py
+++ b/mmdet/core/eval/__init__.py
@@ -1,14 +1,17 @@
 from .class_names import (voc_classes, imagenet_det_classes,
                           imagenet_vid_classes, coco_classes,
                           dataset_aliases, get_classes)
-from .coco_utils import coco_eval
+from .coco_utils import coco_eval, results2json
+from .eval_hooks import DistEvalHook, DistEvalRecallHook, CocoDistEvalmAPHook
 from .mean_ap import average_precision, eval_map, print_map_summary
 from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
                      plot_iou_recall)
 
 __all__ = [
     'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
-    'coco_classes', 'dataset_aliases', 'get_classes', 'average_precision',
-    'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary',
-    'plot_num_recall', 'plot_iou_recall', 'coco_eval'
+    'coco_classes', 'dataset_aliases', 'get_classes', 'coco_eval',
+    'results2json', 'DistEvalHook', 'DistEvalRecallHook',
+    'CocoDistEvalmAPHook', 'average_precision', 'eval_map',
+    'print_map_summary', 'eval_recalls', 'print_recall_summary',
+    'plot_num_recall', 'plot_iou_recall'
 ]
diff --git a/mmdet/core/eval/coco_utils.py b/mmdet/core/eval/coco_utils.py
index cff6f678e4fd60cc1c863dcd1ae32e4a3e4db2ab..719e70a75e099859d0d05a483df3b957bf523d4f 100644
--- a/mmdet/core/eval/coco_utils.py
+++ b/mmdet/core/eval/coco_utils.py
@@ -1,4 +1,5 @@
 import mmcv
+import numpy as np
 from pycocotools.coco import COCO
 from pycocotools.cocoeval import COCOeval
 
@@ -24,3 +25,77 @@ def coco_eval(result_file, result_types, coco, max_dets=(100, 300, 1000)):
     cocoEval.evaluate()
     cocoEval.accumulate()
     cocoEval.summarize()
+
+
+def xyxy2xywh(bbox):
+    _bbox = bbox.tolist()
+    return [
+        _bbox[0],
+        _bbox[1],
+        _bbox[2] - _bbox[0] + 1,
+        _bbox[3] - _bbox[1] + 1,
+    ]
+
+
+def proposal2json(dataset, results):
+    json_results = []
+    for idx in range(len(dataset)):
+        img_id = dataset.img_ids[idx]
+        bboxes = results[idx]
+        for i in range(bboxes.shape[0]):
+            data = dict()
+            data['image_id'] = img_id
+            data['bbox'] = xyxy2xywh(bboxes[i])
+            data['score'] = float(bboxes[i][4])
+            data['category_id'] = 1
+            json_results.append(data)
+    return json_results
+
+
+def det2json(dataset, results):
+    json_results = []
+    for idx in range(len(dataset)):
+        img_id = dataset.img_ids[idx]
+        result = results[idx]
+        for label in range(len(result)):
+            bboxes = result[label]
+            for i in range(bboxes.shape[0]):
+                data = dict()
+                data['image_id'] = img_id
+                data['bbox'] = xyxy2xywh(bboxes[i])
+                data['score'] = float(bboxes[i][4])
+                data['category_id'] = dataset.cat_ids[label]
+                json_results.append(data)
+    return json_results
+
+
+def segm2json(dataset, results):
+    json_results = []
+    for idx in range(len(dataset)):
+        img_id = dataset.img_ids[idx]
+        det, seg = results[idx]
+        for label in range(len(det)):
+            bboxes = det[label]
+            segms = seg[label]
+            for i in range(bboxes.shape[0]):
+                data = dict()
+                data['image_id'] = img_id
+                data['bbox'] = xyxy2xywh(bboxes[i])
+                data['score'] = float(bboxes[i][4])
+                data['category_id'] = dataset.cat_ids[label]
+                segms[i]['counts'] = segms[i]['counts'].decode()
+                data['segmentation'] = segms[i]
+                json_results.append(data)
+    return json_results
+
+
+def results2json(dataset, results, out_file):
+    if isinstance(results[0], list):
+        json_results = det2json(dataset, results)
+    elif isinstance(results[0], tuple):
+        json_results = segm2json(dataset, results)
+    elif isinstance(results[0], np.ndarray):
+        json_results = proposal2json(dataset, results)
+    else:
+        raise TypeError('invalid type of results')
+    mmcv.dump(json_results, out_file)
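
A note on the helpers above: COCO's JSON format stores boxes as [x, y, width, height], while detectors in this codebase produce [x1, y1, x2, y2, score] arrays, and results2json dispatches on the per-image result type (list of per-class arrays for detection, a (det, segm) tuple for instance segmentation, a plain ndarray for proposals). A minimal sketch of the coordinate conversion, with made-up box values:

    import numpy as np
    from mmdet.core.eval.coco_utils import xyxy2xywh

    bbox = np.array([10., 20., 49., 59., 0.9])  # x1, y1, x2, y2, score
    x, y, w, h = xyxy2xywh(bbox)  # [10.0, 20.0, 40.0, 40.0]; score is dropped
    # w = x2 - x1 + 1: pixel coordinates are treated as inclusive here
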
diff --git a/mmdet/core/eval/eval_hooks.py b/mmdet/core/eval/eval_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2393449bc68582bef7615163d3f24b568860d714
--- /dev/null
+++ b/mmdet/core/eval/eval_hooks.py
@@ -0,0 +1,168 @@
+import os
+import os.path as osp
+import shutil
+import time
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.torchpack import Hook, obj_from_dict
+from pycocotools.cocoeval import COCOeval
+from torch.utils.data import Dataset
+
+from .coco_utils import results2json
+from .recall import eval_recalls
+from ..parallel import scatter
+from mmdet import datasets
+from mmdet.datasets.loader import collate
+
+
+class DistEvalHook(Hook):
+
+    def __init__(self, dataset, interval=1):
+        if isinstance(dataset, Dataset):
+            self.dataset = dataset
+        elif isinstance(dataset, dict):
+            self.dataset = obj_from_dict(dataset, datasets,
+                                         {'test_mode': True})
+        else:
+            raise TypeError(
+                'dataset must be a Dataset object or a dict, not {}'.format(
+                    type(dataset)))
+        self.interval = interval
+        self.lock_dir = None
+
+    def _barrier(self, rank, world_size):
+        """Due to some issues with `torch.distributed.barrier()`, we have to
+        implement this ugly barrier function.
+        """
+        if rank == 0:
+            for i in range(1, world_size):
+                tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
+                while not (osp.exists(tmp)):
+                    time.sleep(1)
+            for i in range(1, world_size):
+                tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
+                os.remove(tmp)
+        else:
+            tmp = osp.join(self.lock_dir, '{}.pkl'.format(rank))
+            mmcv.dump([], tmp)
+            while osp.exists(tmp):
+                time.sleep(1)
+
+    def before_run(self, runner):
+        self.lock_dir = osp.join(runner.work_dir, '.lock_map_hook')
+        if runner.rank == 0:
+            if osp.exists(self.lock_dir):
+                shutil.rmtree(self.lock_dir)
+            mmcv.mkdir_or_exist(self.lock_dir)
+
+    def after_train_epoch(self, runner):
+        if not self.every_n_epochs(runner, self.interval):
+            return
+        runner.model.eval()
+        results = [None for _ in range(len(self.dataset))]
+        prog_bar = mmcv.ProgressBar(len(self.dataset))
+        for idx in range(runner.rank, len(self.dataset), runner.world_size):
+            data = self.dataset[idx]
+            data_gpu = scatter(
+                collate([data], samples_per_gpu=1),
+                [torch.cuda.current_device()])[0]
+
+            # compute output
+            with torch.no_grad():
+                result = runner.model(
+                    **data_gpu, return_loss=False, rescale=True)
+            results[idx] = result
+
+            batch_size = runner.world_size
+            for _ in range(batch_size):
+                prog_bar.update()
+
+        if runner.rank == 0:
+            print('\n')
+            self._barrier(runner.rank, runner.world_size)
+            for i in range(1, runner.world_size):
+                tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(i))
+                tmp_results = mmcv.load(tmp_file)
+                for idx in range(i, len(results), runner.world_size):
+                    results[idx] = tmp_results[idx]
+                os.remove(tmp_file)
+            self.evaluate(runner, results)
+        else:
+            tmp_file = osp.join(runner.work_dir,
+                                'temp_{}.pkl'.format(runner.rank))
+            mmcv.dump(results, tmp_file)
+            self._barrier(runner.rank, runner.world_size)
+        self._barrier(runner.rank, runner.world_size)
+
+    def evaluate(self):
+        raise NotImplementedError
+
+
+class DistEvalRecallHook(DistEvalHook):
+
+    def __init__(self,
+                 dataset,
+                 proposal_nums=(100, 300, 1000),
+                 iou_thrs=np.arange(0.5, 0.96, 0.05)):
+        super(DistEvalRecallHook, self).__init__(dataset)
+        self.proposal_nums = np.array(proposal_nums, dtype=np.int32)
+        self.iou_thrs = np.array(iou_thrs, dtype=np.float32)
+
+    def evaluate(self, runner, results):
+        # the official coco evaluation is too slow, here we use our own
+        # implementation instead, which may get slightly different results
+        gt_bboxes = []
+        for i in range(len(self.dataset)):
+            img_id = self.dataset.img_ids[i]
+            ann_ids = self.dataset.coco.getAnnIds(imgIds=img_id)
+            ann_info = self.dataset.coco.loadAnns(ann_ids)
+            if len(ann_info) == 0:
+                gt_bboxes.append(np.zeros((0, 4)))
+                continue
+            bboxes = []
+            for ann in ann_info:
+                if ann.get('ignore', False) or ann['iscrowd']:
+                    continue
+                x1, y1, w, h = ann['bbox']
+                bboxes.append([x1, y1, x1 + w - 1, y1 + h - 1])
+            bboxes = np.array(bboxes, dtype=np.float32)
+            if bboxes.shape[0] == 0:
+                bboxes = np.zeros((0, 4))
+            gt_bboxes.append(bboxes)
+
+        recalls = eval_recalls(
+            gt_bboxes,
+            results,
+            self.proposal_nums,
+            self.iou_thrs,
+            print_summary=False)
+        ar = recalls.mean(axis=1)
+        for i, num in enumerate(self.proposal_nums):
+            runner.log_buffer.output['AR@{}'.format(num)] = ar[i]
+        runner.log_buffer.ready = True
+
+
+class CocoDistEvalmAPHook(DistEvalHook):
+
+    def evaluate(self, runner, results):
+        tmp_file = osp.join(runner.work_dir, 'temp_0.json')
+        results2json(self.dataset, results, tmp_file)
+
+        res_types = ['bbox',
+                     'segm'] if runner.model.module.with_mask else ['bbox']
+        cocoGt = self.dataset.coco
+        cocoDt = cocoGt.loadRes(tmp_file)
+        imgIds = cocoGt.getImgIds()
+        for res_type in res_types:
+            iou_type = res_type
+            cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
+            cocoEval.params.imgIds = imgIds
+            cocoEval.evaluate()
+            cocoEval.accumulate()
+            cocoEval.summarize()
+            field = '{}_mAP'.format(res_type)
+            runner.log_buffer.output[field] = cocoEval.stats[0]
+        runner.log_buffer.ready = True
+        os.remove(tmp_file)
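
Two details of DistEvalHook are worth spelling out. Each rank evaluates a strided shard of the dataset and dumps its partial results to temp_<rank>.pkl; rank 0 then interleaves the shards back into a single list before calling evaluate(). A doctest-style illustration of the striding, with hypothetical sizes:

    >>> world_size = 4  # illustrative values only
    >>> [list(range(rank, 10, world_size)) for rank in range(world_size)]
    [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]

The file-based _barrier stands in for torch.distributed.barrier(): each non-zero rank drops a lock file and spins until rank 0 removes it, while rank 0 waits until every lock file exists, so no rank can run ahead of the merge.
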
diff --git a/mmdet/core/utils/__init__.py b/mmdet/core/utils/__init__.py
index 30c9c9e5c83797ddf7ea84d088ad29cb8e4b18cc..e04da6a9a5c82c02c924ecd7c763aaf8b1c02d4e 100644
--- a/mmdet/core/utils/__init__.py
+++ b/mmdet/core/utils/__init__.py
@@ -1,12 +1,9 @@
 from .dist_utils import (init_dist, reduce_grads, DistOptimizerHook,
                          DistSamplerSeedHook)
-from .hooks import (EmptyCacheHook, DistEvalHook, DistEvalRecallHook,
-                    CocoDistEvalmAPHook)
-from .misc import tensor2imgs, unmap, results2json, multi_apply
+from .hooks import EmptyCacheHook
+from .misc import tensor2imgs, unmap, multi_apply
 
 __all__ = [
     'init_dist', 'reduce_grads', 'DistOptimizerHook', 'DistSamplerSeedHook',
-    'EmptyCacheHook', 'DistEvalHook', 'DistEvalRecallHook',
-    'CocoDistEvalmAPHook', 'tensor2imgs', 'unmap', 'results2json',
-    'multi_apply'
+    'EmptyCacheHook', 'tensor2imgs', 'unmap', 'multi_apply'
 ]
diff --git a/mmdet/core/utils/dist_utils.py b/mmdet/core/utils/dist_utils.py
index 4bc986ca73fc6faacb12c6fdbc20f020d6bdb56f..2a5d7659df78055fc068224d789a5d391f73d8e0 100644
--- a/mmdet/core/utils/dist_utils.py
+++ b/mmdet/core/utils/dist_utils.py
@@ -39,7 +39,7 @@ def _init_dist_slurm(backend, **kwargs):
 
 
 # modified from https://github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py#L9
-def coalesce_all_reduce(tensors):
+def all_reduce_coalesced(tensors):
     buckets = OrderedDict()
     for tensor in tensors:
         tp = tensor.type()
@@ -64,7 +64,7 @@ def reduce_grads(model, coalesce=True):
         if param.requires_grad and param.grad is not None
     ]
     if coalesce:
-        coalesce_all_reduce(grads)
+        all_reduce_coalesced(grads)
     else:
         for tensor in grads:
             dist.all_reduce(tensor)
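
The rename to all_reduce_coalesced matches the apex helper this function was adapted from. The idea is to flatten many small same-typed gradient tensors into one buffer, run a single collective, and copy the reduced values back. A simplified per-bucket sketch (a reconstruction for illustration, not the repo's exact code; gradient averaging would additionally divide by the world size):

    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    def _all_reduce_bucket(bucket):
        # one collective for the whole bucket instead of one per tensor
        flat = _flatten_dense_tensors(bucket)
        dist.all_reduce(flat)
        for t, synced in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
            t.copy_(synced)
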
diff --git a/mmdet/core/utils/hooks.py b/mmdet/core/utils/hooks.py
index 9772d4d64f1f5860a7646c23c950b620a052ae20..72eb3438efac1a3d2c8114914c087f0b563272d9 100644
--- a/mmdet/core/utils/hooks.py
+++ b/mmdet/core/utils/hooks.py
@@ -1,17 +1,5 @@
-import os
-import os.path as osp
-import shutil
-import time
-
-import mmcv
-import numpy as np
 import torch
 from mmcv.torchpack import Hook
-from pycocotools.cocoeval import COCOeval
-
-from ..eval import eval_recalls
-from ..parallel import scatter
-from mmdet.datasets.loader import collate
 
 
@@ -21,220 +9,3 @@ class EmptyCacheHook(Hook):
 
     def after_epoch(self, runner):
         torch.cuda.empty_cache()
-
-
-class DistEvalHook(Hook):
-
-    def __init__(self, dataset, interval=1):
-        self.dataset = dataset
-        self.interval = interval
-        self.lock_dir = None
-
-    def _barrier(self, rank, world_size):
-        """Due to some issues with `torch.distributed.barrier()`, we have to
-        implement this ugly barrier function.
-        """
-        if rank == 0:
-            for i in range(1, world_size):
-                tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
-                while not (osp.exists(tmp)):
-                    time.sleep(1)
-            for i in range(1, world_size):
-                tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
-                os.remove(tmp)
-        else:
-            tmp = osp.join(self.lock_dir, '{}.pkl'.format(rank))
-            mmcv.dump([], tmp)
-            while osp.exists(tmp):
-                time.sleep(1)
-
-    def before_run(self, runner):
-        self.lock_dir = osp.join(runner.work_dir, '.lock_map_hook')
-        if runner.rank == 0:
-            if osp.exists(self.lock_dir):
-                shutil.rmtree(self.lock_dir)
-            mmcv.mkdir_or_exist(self.lock_dir)
-
-    def after_train_epoch(self, runner):
-        if not self.every_n_epochs(runner, self.interval):
-            return
-        runner.model.eval()
-        results = [None for _ in range(len(self.dataset))]
-        prog_bar = mmcv.ProgressBar(len(self.dataset))
-        for idx in range(runner.rank, len(self.dataset), runner.world_size):
-            data = self.dataset[idx]
-            device_id = torch.cuda.current_device()
-            imgs_data = tuple(
-                scatter(collate([data], samples_per_gpu=1), [device_id])[0])
-
-            # compute output
-            with torch.no_grad():
-                result = runner.model(
-                    *imgs_data,
-                    return_loss=False,
-                    return_bboxes=True,
-                    rescale=True)
-            results[idx] = result
-
-            batch_size = runner.world_size
-            for _ in range(batch_size):
-                prog_bar.update()
-
-        if runner.rank == 0:
-            print('\n')
-            self._barrier(runner.rank, runner.world_size)
-            for i in range(1, runner.world_size):
-                tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(i))
-                tmp_results = mmcv.load(tmp_file)
-                for idx in range(i, len(results), runner.world_size):
-                    results[idx] = tmp_results[idx]
-                os.remove(tmp_file)
-            self.evaluate(runner, results)
-        else:
-            tmp_file = osp.join(runner.work_dir,
-                                'temp_{}.pkl'.format(runner.rank))
-            mmcv.dump(results, tmp_file)
-            self._barrier(runner.rank, runner.world_size)
-        self._barrier(runner.rank, runner.world_size)
-
-    def evaluate(self):
-        raise NotImplementedError
-
-
-class CocoEvalMixin(object):
-
-    def _xyxy2xywh(self, bbox):
-        _bbox = bbox.tolist()
-        return [
-            _bbox[0],
-            _bbox[1],
-            _bbox[2] - _bbox[0] + 1,
-            _bbox[3] - _bbox[1] + 1,
-        ]
-
-    def det2json(self, dataset, results):
-        json_results = []
-        for idx in range(len(dataset)):
-            img_id = dataset.img_ids[idx]
-            result = results[idx]
-            for label in range(len(result)):
-                bboxes = result[label]
-                for i in range(bboxes.shape[0]):
-                    data = dict()
-                    data['image_id'] = img_id
-                    data['bbox'] = self._xyxy2xywh(bboxes[i])
-                    data['score'] = float(bboxes[i][4])
-                    data['category_id'] = dataset.cat_ids[label]
-                    json_results.append(data)
-        return json_results
-
-    def segm2json(self, dataset, results):
-        json_results = []
-        for idx in range(len(dataset)):
-            img_id = dataset.img_ids[idx]
-            det, seg = results[idx]
-            for label in range(len(det)):
-                bboxes = det[label]
-                segms = seg[label]
-                for i in range(bboxes.shape[0]):
-                    data = dict()
-                    data['image_id'] = img_id
-                    data['bbox'] = self._xyxy2xywh(bboxes[i])
-                    data['score'] = float(bboxes[i][4])
-                    data['category_id'] = dataset.cat_ids[label]
-                    segms[i]['counts'] = segms[i]['counts'].decode()
-                    data['segmentation'] = segms[i]
-                    json_results.append(data)
-        return json_results
-
-    def proposal2json(self, dataset, results):
-        json_results = []
-        for idx in range(len(dataset)):
-            img_id = dataset.img_ids[idx]
-            bboxes = results[idx]
-            for i in range(bboxes.shape[0]):
-                data = dict()
-                data['image_id'] = img_id
-                data['bbox'] = self._xyxy2xywh(bboxes[i])
-                data['score'] = float(bboxes[i][4])
-                data['category_id'] = 1
-                json_results.append(data)
-        return json_results
-
-    def results2json(self, dataset, results, out_file):
-        if isinstance(results[0], list):
-            json_results = self.det2json(dataset, results)
-        elif isinstance(results[0], tuple):
-            json_results = self.segm2json(dataset, results)
-        elif isinstance(results[0], np.ndarray):
-            json_results = self.proposal2json(dataset, results)
-        else:
-            raise TypeError('invalid type of results')
-        mmcv.dump(json_results, out_file, file_format='json')
-
-
-class DistEvalRecallHook(DistEvalHook):
-
-    def __init__(self,
-                 dataset,
-                 proposal_nums=(100, 300, 1000),
-                 iou_thrs=np.arange(0.5, 0.96, 0.05)):
-        super(DistEvalRecallHook, self).__init__(dataset)
-        self.proposal_nums = np.array(proposal_nums, dtype=np.int32)
-        self.iou_thrs = np.array(iou_thrs, dtype=np.float32)
-
-    def evaluate(self, runner, results):
-        # official coco evaluation is too slow, here we use our own
-        # implementation, which may get slightly different results
-        gt_bboxes = []
-        for i in range(len(self.dataset)):
-            img_id = self.dataset.img_ids[i]
-            ann_ids = self.dataset.coco.getAnnIds(imgIds=img_id)
-            ann_info = self.dataset.coco.loadAnns(ann_ids)
-            if len(ann_info) == 0:
-                gt_bboxes.append(np.zeros((0, 4)))
-                continue
-            bboxes = []
-            for ann in ann_info:
-                if ann.get('ignore', False) or ann['iscrowd']:
-                    continue
-                x1, y1, w, h = ann['bbox']
-                bboxes.append([x1, y1, x1 + w - 1, y1 + h - 1])
-            bboxes = np.array(bboxes, dtype=np.float32)
-            if bboxes.shape[0] == 0:
-                bboxes = np.zeros((0, 4))
-            gt_bboxes.append(bboxes)
-
-        recalls = eval_recalls(
-            gt_bboxes,
-            results,
-            self.proposal_nums,
-            self.iou_thrs,
-            print_summary=False)
-        ar = recalls.mean(axis=1)
-        for i, num in enumerate(self.proposal_nums):
-            runner.log_buffer.output['AR@{}'.format(num)] = ar[i]
-        runner.log_buffer.ready = True
-
-
-class CocoDistEvalmAPHook(DistEvalHook, CocoEvalMixin):
-
-    def evaluate(self, runner, results):
-        tmp_file = osp.join(runner.work_dir, 'temp_0.json')
-        self.results2json(self.dataset, results, tmp_file)
-
-        res_types = ['bbox', 'segm'] if runner.model.with_mask else ['bbox']
-        cocoGt = self.dataset.coco
-        cocoDt = cocoGt.loadRes(tmp_file)
-        imgIds = cocoGt.getImgIds()
-        for res_type in res_types:
-            iou_type = res_type
-            cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
-            cocoEval.params.imgIds = imgIds
-            cocoEval.evaluate()
-            cocoEval.accumulate()
-            cocoEval.summarize()
-            field = '{}_mAP'.format(res_type)
-            runner.log_buffer.output[field] = cocoEval.stats[0]
-        runner.log_buffer.ready = True
-        os.remove(tmp_file)
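
With the eval machinery moved out, hooks.py keeps only EmptyCacheHook, which releases the CUDA caching allocator's unused blocks between epochs. The mmcv.torchpack Hook base class these hooks share is a set of no-op callbacks (before_run, after_epoch, after_train_epoch, ...) plus helpers such as every_n_epochs; subclasses override only what they need. A hypothetical example hook, purely illustrative and not part of this patch:

    import time
    from mmcv.torchpack import Hook

    class EpochTimerHook(Hook):
        def before_epoch(self, runner):
            self._start = time.time()

        def after_epoch(self, runner):
            # same callback point as EmptyCacheHook.after_epoch above
            print('epoch took {:.1f}s'.format(time.time() - self._start))
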
diff --git a/mmdet/core/utils/misc.py b/mmdet/core/utils/misc.py
index d34ff94302c2a5215681a3e81e38ca9aee8070df..fd8211ef68de37a1631135c9f19823653c0176f1 100644
--- a/mmdet/core/utils/misc.py
+++ b/mmdet/core/utils/misc.py
@@ -4,6 +4,7 @@ import mmcv
 import numpy as np
 from six.moves import map, zip
 
+
 def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
     num_imgs = tensor.size(0)
     mean = np.array(mean, dtype=np.float32)
@@ -33,77 +34,3 @@ def unmap(data, count, inds, fill=0):
     ret = data.new_full(new_size, fill)
     ret[inds, :] = data
     return ret
-
-
-def xyxy2xywh(bbox):
-    _bbox = bbox.tolist()
-    return [
-        _bbox[0],
-        _bbox[1],
-        _bbox[2] - _bbox[0] + 1,
-        _bbox[3] - _bbox[1] + 1,
-    ]
-
-
-def proposal2json(dataset, results):
-    json_results = []
-    for idx in range(len(dataset)):
-        img_id = dataset.img_ids[idx]
-        bboxes = results[idx]
-        for i in range(bboxes.shape[0]):
-            data = dict()
-            data['image_id'] = img_id
-            data['bbox'] = xyxy2xywh(bboxes[i])
-            data['score'] = float(bboxes[i][4])
-            data['category_id'] = 1
-            json_results.append(data)
-    return json_results
-
-
-def det2json(dataset, results):
-    json_results = []
-    for idx in range(len(dataset)):
-        img_id = dataset.img_ids[idx]
-        result = results[idx]
-        for label in range(len(result)):
-            bboxes = result[label]
-            for i in range(bboxes.shape[0]):
-                data = dict()
-                data['image_id'] = img_id
-                data['bbox'] = xyxy2xywh(bboxes[i])
-                data['score'] = float(bboxes[i][4])
-                data['category_id'] = dataset.cat_ids[label]
-                json_results.append(data)
-    return json_results
-
-
-def segm2json(dataset, results):
-    json_results = []
-    for idx in range(len(dataset)):
-        img_id = dataset.img_ids[idx]
-        det, seg = results[idx]
-        for label in range(len(det)):
-            bboxes = det[label]
-            segms = seg[label]
-            for i in range(bboxes.shape[0]):
-                data = dict()
-                data['image_id'] = img_id
-                data['bbox'] = xyxy2xywh(bboxes[i])
-                data['score'] = float(bboxes[i][4])
-                data['category_id'] = dataset.cat_ids[label]
-                segms[i]['counts'] = segms[i]['counts'].decode()
-                data['segmentation'] = segms[i]
-                json_results.append(data)
-    return json_results
-
-
-def results2json(dataset, results, out_file):
-    if isinstance(results[0], list):
-        json_results = det2json(dataset, results)
-    elif isinstance(results[0], tuple):
-        json_results = segm2json(dataset, results)
-    elif isinstance(results[0], np.ndarray):
-        json_results = proposal2json(dataset, results)
-    else:
-        raise TypeError('invalid type of results')
-    mmcv.dump(json_results, out_file)
diff --git a/tools/configs/r50_fpn_frcnn_1x.py b/tools/configs/r50_fpn_frcnn_1x.py
index 156b8b2aa4e2ad7f70d8e1a44682128076e9e34b..23903e084e204812571945bbe83fd626d6bc5922 100644
--- a/tools/configs/r50_fpn_frcnn_1x.py
+++ b/tools/configs/r50_fpn_frcnn_1x.py
@@ -93,16 +93,26 @@ data = dict(
         flip_ratio=0.5,
         with_mask=False,
         with_crowd=True,
-        with_label=True,
-        test_mode=False),
-    test=dict(
+        with_label=True),
+    val=dict(
         type=dataset_type,
         ann_file=data_root + 'annotations/instances_val2017.json',
         img_prefix=data_root + 'val2017/',
         img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
         flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
         img_norm_cfg=img_norm_cfg,
         size_divisor=32,
+        flip_ratio=0,
         with_mask=False,
         with_label=False,
         test_mode=True))
@@ -128,7 +138,7 @@
 # runtime settings
 total_epochs = 12
 device_ids = range(8)
-dist_params = dict(backend='nccl', port='29500')
+dist_params = dict(backend='gloo')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_faster_rcnn_r50_1x'
 load_from = None
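
Each config gains a val section that mirrors test but keeps annotations (with_label=True and, where relevant, with_crowd=True), since the eval hooks need ground truth; test stays label-free with test_mode=True. The hooks accept this dict directly and construct the dataset themselves via obj_from_dict with a test_mode default (see DistEvalHook.__init__ above). A sketch, assuming cfg is the parsed config:

    from mmdet.core import CocoDistEvalmAPHook

    # the hook builds the dataset internally from the plain config dict
    runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))

Note also that dist_params switches from the 'nccl' backend to 'gloo'.
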
diff --git a/tools/configs/r50_fpn_maskrcnn_1x.py b/tools/configs/r50_fpn_maskrcnn_1x.py
index 5697bca4a58cf7343b348371d56e8f631bda7a5a..41c2a1476dd438c685bc882a6d1478c00133e4e3 100644
--- a/tools/configs/r50_fpn_maskrcnn_1x.py
+++ b/tools/configs/r50_fpn_maskrcnn_1x.py
@@ -106,16 +106,26 @@ data = dict(
         flip_ratio=0.5,
         with_mask=True,
         with_crowd=True,
-        with_label=True,
-        test_mode=False),
-    test=dict(
+        with_label=True),
+    val=dict(
         type=dataset_type,
         ann_file=data_root + 'annotations/instances_val2017.json',
         img_prefix=data_root + 'val2017/',
         img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
         flip_ratio=0,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
         img_norm_cfg=img_norm_cfg,
         size_divisor=32,
+        flip_ratio=0,
         with_mask=False,
         with_label=False,
         test_mode=True))
@@ -141,7 +151,7 @@
 # runtime settings
 total_epochs = 12
 device_ids = range(8)
-dist_params = dict(backend='nccl', port='29500')
+dist_params = dict(backend='gloo')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_mask_rcnn_r50_1x'
 load_from = None
diff --git a/tools/configs/r50_fpn_rpn_1x.py b/tools/configs/r50_fpn_rpn_1x.py
index a00cab9de8013455f498d96cb12c9801aafcc343..1f14f72235b73e7724fad619e4041c453c6e60bc 100644
--- a/tools/configs/r50_fpn_rpn_1x.py
+++ b/tools/configs/r50_fpn_rpn_1x.py
@@ -65,16 +65,26 @@ data = dict(
         flip_ratio=0.5,
         with_mask=False,
         with_crowd=False,
-        with_label=False,
-        test_mode=False),
-    test=dict(
+        with_label=False),
+    val=dict(
         type=dataset_type,
         ann_file=data_root + 'annotations/instances_val2017.json',
         img_prefix=data_root + 'val2017/',
         img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
         flip_ratio=0,
+        with_mask=False,
+        with_crowd=False,
+        with_label=False),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
         img_norm_cfg=img_norm_cfg,
         size_divisor=32,
+        flip_ratio=0,
         with_mask=False,
         with_label=False,
         test_mode=True))
@@ -103,5 +113,5 @@ dist_params = dict(backend='gloo')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_rpn_r50_1x'
 load_from = None
-resume_from = None
-workflow = [('train', 1)]
+resume_from = None
+workflow = [('train', 1), ('val', 1)]
diff --git a/tools/dist_train.sh b/tools/dist_train.sh
index 8b79c6158daee7592766f8d9fa42770d7eb17c90..7bb903bd3dcef5c98cd5bcd93ce3e8cfb9de9882 100755
--- a/tools/dist_train.sh
+++ b/tools/dist_train.sh
@@ -2,4 +2,4 @@
 
 PYTHON=${PYTHON:-"python"}
 
-$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train.py $1 --launcher pytorch
\ No newline at end of file
+$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train.py $1 --launcher pytorch $3
diff --git a/tools/test.py b/tools/test.py
index 0a43cdc316506ecf2b7addb8d11e3dc7dc30507b..4c87f4eeee74c3117f505c38eee13d94201380f5 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -59,7 +59,7 @@ def main():
     cfg.model.pretrained = None
     cfg.data.test.test_mode = True
 
-    dataset = obj_from_dict(cfg.data.test, datasets)
+    dataset = obj_from_dict(cfg.data.test, datasets, dict(test_mode=True))
     if args.gpus == 1:
         model = build_detector(
             cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
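
On the test.py change: obj_from_dict(info, parent, default_args) pops 'type' from the dict, looks the class up in the parent module, and fills in default_args for any keys the config omits, so dict(test_mode=True) guarantees test mode. A roughly equivalent expansion, assuming the configured type is CocoDataset:

    kwargs = dict(cfg.data.test)
    cls = getattr(datasets, kwargs.pop('type'))  # e.g. datasets.CocoDataset
    kwargs.setdefault('test_mode', True)         # defaults fill gaps only
    dataset = cls(**kwargs)

dist_train.sh likewise now forwards an optional third argument straight through to train.py.
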
diff --git a/tools/train.py b/tools/train.py
index fd47b1375622693e7e0fdd31c5e746df9dc5ff5a..6e1b9d91aaf53037e7fe246734ae8dfc7500c7df 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -9,9 +9,10 @@ from mmcv.torchpack import Runner, obj_from_dict
 
 from mmdet import datasets
 from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook,
-                        MMDataParallel, MMDistributedDataParallel)
+                        MMDataParallel, MMDistributedDataParallel,
+                        DistEvalRecallHook, CocoDistEvalmAPHook)
 from mmdet.datasets.loader import build_dataloader
-from mmdet.models import build_detector
+from mmdet.models import build_detector, RPN
 
 
 def parse_losses(losses):
@@ -109,6 +110,11 @@
                                    cfg.checkpoint_config, cfg.log_config)
     if dist:
         runner.register_hook(DistSamplerSeedHook())
+        # register eval hooks
+        if isinstance(model.module, RPN):
+            runner.register_hook(DistEvalRecallHook(cfg.data.val))
+        elif cfg.data.val.type == 'CocoDataset':
+            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
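
A closing note on the logged metrics: DistEvalRecallHook emits 'AR@100', 'AR@300' and 'AR@1000' for proposal-only (RPN) models, while CocoDistEvalmAPHook emits '{type}_mAP' taken from cocoEval.stats[0]. In pycocotools, stats is the 12-element vector filled by COCOeval.summarize(), and index 0 is AP averaged over IoU 0.50:0.95, all areas, maxDets=100, i.e. the headline COCO mAP:

    # after cocoEval.summarize()
    ap = float(cocoEval.stats[0])  # AP @ IoU=0.50:0.95, area=all, maxDets=100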