diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md
index e869f32837c7b5def6e63dffc846e1229a3af059..27edf5a3e3b52f8970f9170d8d07669eeeb77666 100644
--- a/GETTING_STARTED.md
+++ b/GETTING_STARTED.md
@@ -14,19 +14,17 @@ and also some high-level apis for easier integration to other projects.
 - [x] multiple GPU testing
 - [x] visualize detection results

-You can use the following command to test a dataset.
+You can use the following commands to test a dataset.

 ```shell
-python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--gpus ${GPU_NUM}] [--proc_per_gpu ${PROC_NUM}] [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]
-```
+# single-gpu testing
+python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]

-Positional arguments:
-- `CONFIG_FILE`: Path to the config file of the corresponding model.
-- `CHECKPOINT_FILE`: Path to the checkpoint file.
+# multi-gpu testing
+./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]
+```

 Optional arguments:
-- `GPU_NUM`: Number of GPUs used for testing. (default: 1)
-- `PROC_NUM`: Number of processes on each GPU. (default: 1)
 - `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
 - `EVAL_METRICS`: Items to be evaluated on the results. Allowed values are: `proposal_fast`, `proposal`, `bbox`, `segm`, `keypoints`.
 - `--show`: If specified, detection results will be ploted on the images and shown in a new window. Only applicable for single GPU testing.
@@ -51,12 +49,12 @@ python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
     --out results.pkl --eval bbox mask
 ```

-3. Test Mask R-CNN with 8 GPUs and 2 processes per GPU, and evaluate the bbox and mask AP.
+3. Test Mask R-CNN with 8 GPUs, and evaluate the bbox and mask AP.

 ```shell
-python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
+./tools/dist_test.sh configs/mask_rcnn_r50_fpn_1x.py \
     checkpoints/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth \
-    --gpus 8 --proc_per_gpu 2 --out results.pkl --eval bbox mask
+    8 --out results.pkl --eval bbox mask
 ```

 ### High-level APIs for testing images.
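A minimal sketch (not part of the patch) of loading the `results.pkl` file written by `--out`. The per-image layout below is an assumption for a bbox-only detector; mask models such as Mask R-CNN return a `(bbox_results, segm_results)` tuple per image instead of a plain list.

```python
# Hypothetical usage: inspect the pickle written by `--out results.pkl`.
# The per-image structure is an assumption for a bbox-only detector.
import mmcv

results = mmcv.load('results.pkl')    # one entry per test image
print(len(results))                   # number of images evaluated
first = results[0]                    # assumed: list of per-class (N, 5) arrays [x1, y1, x2, y2, score]
print([cls_dets.shape for cls_dets in first])
```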
diff --git a/mmdet/datasets/loader/build_loader.py b/mmdet/datasets/loader/build_loader.py
index 3e10e2399d5a20c30500e9a28db662eb78211c84..248f8f8900dd5217d4128f6058146f968f6c65bc 100644
--- a/mmdet/datasets/loader/build_loader.py
+++ b/mmdet/datasets/loader/build_loader.py
@@ -4,7 +4,7 @@ from mmcv.runner import get_dist_info
 from mmcv.parallel import collate
 from torch.utils.data import DataLoader

-from .sampler import GroupSampler, DistributedGroupSampler
+from .sampler import GroupSampler, DistributedGroupSampler, DistributedSampler

 # https://github.com/pytorch/pytorch/issues/973
 import resource
@@ -18,27 +18,31 @@ def build_dataloader(dataset,
                      num_gpus=1,
                      dist=True,
                      **kwargs):
+    shuffle = kwargs.get('shuffle', True)
     if dist:
         rank, world_size = get_dist_info()
-        sampler = DistributedGroupSampler(dataset, imgs_per_gpu, world_size,
-                                          rank)
+        if shuffle:
+            sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
+                                              world_size, rank)
+        else:
+            sampler = DistributedSampler(dataset,
+                                         world_size,
+                                         rank,
+                                         shuffle=False)
         batch_size = imgs_per_gpu
         num_workers = workers_per_gpu
     else:
-        if not kwargs.get('shuffle', True):
-            sampler = None
-        else:
-            sampler = GroupSampler(dataset, imgs_per_gpu)
+        sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
         batch_size = num_gpus * imgs_per_gpu
         num_workers = num_gpus * workers_per_gpu

-    data_loader = DataLoader(
-        dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        num_workers=num_workers,
-        collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
-        pin_memory=False,
-        **kwargs)
+    data_loader = DataLoader(dataset,
+                             batch_size=batch_size,
+                             sampler=sampler,
+                             num_workers=num_workers,
+                             collate_fn=partial(collate,
+                                                samples_per_gpu=imgs_per_gpu),
+                             pin_memory=False,
+                             **kwargs)

     return data_loader
diff --git a/mmdet/datasets/loader/sampler.py b/mmdet/datasets/loader/sampler.py
index 5c060cd926ea50d232d0f765b86933ca8fad0969..d2eef23bf780b97ee5839617f75b4b1fc79d1826 100644
--- a/mmdet/datasets/loader/sampler.py
+++ b/mmdet/datasets/loader/sampler.py
@@ -5,7 +5,34 @@ import torch
 import numpy as np
 from torch.distributed import get_world_size, get_rank
-from torch.utils.data.sampler import Sampler
+from torch.utils.data import Sampler
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+
+class DistributedSampler(_DistributedSampler):
+
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make it evenly divisible
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)


 class GroupSampler(Sampler):

@@ -112,8 +139,8 @@ class DistributedGroupSampler(Sampler):

         indices = [
             indices[j] for i in list(
-                torch.randperm(
-                    len(indices) // self.samples_per_gpu, generator=g))
+                torch.randperm(len(indices) // self.samples_per_gpu,
+                               generator=g))
             for j in range(i * self.samples_per_gpu, (i + 1) *
                            self.samples_per_gpu)
         ]
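The `DistributedSampler` added above pads the index list to a multiple of the world size and then strides it by rank. A standalone sketch of that pad-then-stride partitioning, with made-up dataset and world sizes (illustrative only, not part of the patch):

```python
# Illustration of the pad-then-stride subsampling used by DistributedSampler.
# Dataset size and number of replicas below are made up for the example.
import math


def subsample(num_items, num_replicas, rank):
    num_samples = int(math.ceil(num_items / num_replicas))
    total_size = num_samples * num_replicas
    indices = list(range(num_items))
    indices += indices[:total_size - len(indices)]  # pad so every rank gets an equal share
    return indices[rank:total_size:num_replicas]    # rank-strided subsample


for rank in range(3):
    print(rank, subsample(10, 3, rank))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 0]   <- last entry is padding
# 2 [2, 5, 8, 1]   <- last entry is padding
```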
diff --git a/tools/dist_test.sh b/tools/dist_test.sh
new file mode 100755
index 0000000000000000000000000000000000000000..5f6abf1a2d31292689a0fd88a5dbf98558b04098
--- /dev/null
+++ b/tools/dist_test.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+
+PYTHON=${PYTHON:-"python"}
+
+CONFIG=$1
+CHECKPOINT=$2
+GPUS=$3
+
+$PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \
+    $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
diff --git a/tools/dist_train.sh b/tools/dist_train.sh
index fa68297226b874596a54b9c819f03584008093e6..a6ed4858c6c237d5a4b7e09fab9c4979c710d500 100755
--- a/tools/dist_train.sh
+++ b/tools/dist_train.sh
@@ -2,4 +2,8 @@

 PYTHON=${PYTHON:-"python"}

-$PYTHON -m torch.distributed.launch --nproc_per_node=$2 $(dirname "$0")/train.py $1 --launcher pytorch ${@:3}
+CONFIG=$1
+GPUS=$2
+
+$PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \
+    $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
diff --git a/tools/slurm_test.sh b/tools/slurm_test.sh
index be2ab9d6a5befcc673023989dfc32578b5612a04..8950bc81609495f8d86186a1e567cb63de1282ad 100755
--- a/tools/slurm_test.sh
+++ b/tools/slurm_test.sh
@@ -7,16 +7,17 @@ JOB_NAME=$2
 CONFIG=$3
 CHECKPOINT=$4
 GPUS=${GPUS:-8}
-CPUS_PER_TASK=${CPUS_PER_TASK:-32}
+GPUS_PER_NODE=${GPUS_PER_NODE:-8}
+CPUS_PER_TASK=${CPUS_PER_TASK:-5}
 PY_ARGS=${@:5}
 SRUN_ARGS=${SRUN_ARGS:-""}

 srun -p ${PARTITION} \
     --job-name=${JOB_NAME} \
-    --gres=gpu:${GPUS} \
-    --ntasks=1 \
-    --ntasks-per-node=1 \
+    --gres=gpu:${GPUS_PER_NODE} \
+    --ntasks=${GPUS} \
+    --ntasks-per-node=${GPUS_PER_NODE} \
     --cpus-per-task=${CPUS_PER_TASK} \
     --kill-on-bad-exit=1 \
     ${SRUN_ARGS} \
-    python tools/test.py ${CONFIG} ${CHECKPOINT} --gpus ${GPUS} ${PY_ARGS}
+    python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
diff --git a/tools/test.py b/tools/test.py
index 8aa23ea7ec8d1758fd06df6c47e765026dcd071e..f8c0e2d76f35db86d711ce446ad5e54fea6c4804 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -1,17 +1,21 @@
 import argparse
+import os.path as osp
+import shutil
+import tempfile

-import torch
 import mmcv
-from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
-from mmcv.parallel import scatter, collate, MMDataParallel
+import torch
+import torch.distributed as dist
+from mmcv.runner import load_checkpoint, get_dist_info
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

-from mmdet import datasets
+from mmdet.apis import init_dist
 from mmdet.core import results2json, coco_eval
-from mmdet.datasets import build_dataloader
+from mmdet.datasets import build_dataloader, get_dataset
-from mmdet.models import build_detector, detectors
+from mmdet.models import build_detector


-def single_test(model, data_loader, show=False):
+def single_gpu_test(model, data_loader, show=False):
     model.eval()
     results = []
     dataset = data_loader.dataset
@@ -22,7 +26,9 @@ def single_test(model, data_loader, show=False):
         results.append(result)

         if show:
-            model.module.show_result(data, result, dataset.img_norm_cfg,
+            model.module.show_result(data,
+                                     result,
+                                     dataset.img_norm_cfg,
                                      dataset=dataset.CLASSES)

         batch_size = data['img'][0].size(0)
@@ -31,22 +37,76 @@ def single_test(model, data_loader, show=False):
     return results


-def _data_func(data, device_id):
-    data = scatter(collate([data], samples_per_gpu=1), [device_id])[0]
-    return dict(return_loss=False, rescale=True, **data)
+def multi_gpu_test(model, data_loader, tmpdir=None):
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    rank, world_size = get_dist_info()
+    if rank == 0:
+        prog_bar = mmcv.ProgressBar(len(dataset))
+    for i, data in enumerate(data_loader):
+        with torch.no_grad():
+            result = model(return_loss=False, rescale=True, **data)
+        results.append(result)
+
+        if rank == 0:
+            batch_size = data['img'][0].size(0)
+            for _ in range(batch_size * world_size):
+                prog_bar.update()
+
+    # collect results from all ranks
+    results = collect_results(results, len(dataset), tmpdir)
+
+    return results
+
+
+def collect_results(result_part, size, tmpdir=None):
+    rank, world_size = get_dist_info()
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512
+        # 32 is whitespace
+        dir_tensor = torch.full((MAX_LEN, ),
+                                32,
+                                dtype=torch.uint8,
+                                device='cuda')
+        if rank == 0:
+            tmpdir = tempfile.mkdtemp()
+            tmpdir = torch.tensor(bytearray(tmpdir.encode()),
+                                  dtype=torch.uint8,
+                                  device='cuda')
+            dir_tensor[:len(tmpdir)] = tmpdir
+        dist.broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    else:
+        mmcv.mkdir_or_exist(tmpdir)
+    # dump the part result to the dir
+    mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
+    dist.barrier()
+    # collect all parts
+    if rank != 0:
+        return None
+    else:
+        # load results of all parts from tmp dir
+        part_list = []
+        for i in range(world_size):
+            part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
+            part_list.append(mmcv.load(part_file))
+        # sort the results
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        # remove tmp dir
+        shutil.rmtree(tmpdir)
+        return ordered_results


 def parse_args():
     parser = argparse.ArgumentParser(description='MMDet test detector')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
-    parser.add_argument(
-        '--gpus', default=1, type=int, help='GPU number used for testing')
-    parser.add_argument(
-        '--proc_per_gpu',
-        default=1,
-        type=int,
-        help='Number of processes per GPU')
     parser.add_argument('--out', help='output result file')
     parser.add_argument(
         '--eval',
@@ -55,6 +115,12 @@ def parse_args():
         choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
         help='eval types')
     parser.add_argument('--show', action='store_true', help='show results')
+    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
+    parser.add_argument('--launcher',
+                        choices=['none', 'pytorch', 'slurm', 'mpi'],
+                        default='none',
+                        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     return args

@@ -72,36 +138,36 @@ def main():
     cfg.model.pretrained = None
     cfg.data.test.test_mode = True

-    dataset = obj_from_dict(cfg.data.test, datasets, dict(test_mode=True))
-    if args.gpus == 1:
-        model = build_detector(
-            cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
-        load_checkpoint(model, args.checkpoint)
+    # init distributed env first, since logger depends on the dist info.
+    if args.launcher == 'none':
+        distributed = False
+    else:
+        distributed = True
+        init_dist(args.launcher, **cfg.dist_params)
+
+    # build the dataloader
+    # TODO: support multiple images per gpu (only minor changes are needed)
+    dataset = get_dataset(cfg.data.test)
+    data_loader = build_dataloader(dataset,
+                                   imgs_per_gpu=1,
+                                   workers_per_gpu=cfg.data.workers_per_gpu,
+                                   dist=distributed,
+                                   shuffle=False)
+
+    # build the model and load checkpoint
+    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+    load_checkpoint(model, args.checkpoint, map_location='cpu')
+
+    if not distributed:
         model = MMDataParallel(model, device_ids=[0])
-
-        data_loader = build_dataloader(
-            dataset,
-            imgs_per_gpu=1,
-            workers_per_gpu=cfg.data.workers_per_gpu,
-            num_gpus=1,
-            dist=False,
-            shuffle=False)
-        outputs = single_test(model, data_loader, args.show)
+        outputs = single_gpu_test(model, data_loader, args.show)
     else:
-        model_args = cfg.model.copy()
-        model_args.update(train_cfg=None, test_cfg=cfg.test_cfg)
-        model_type = getattr(detectors, model_args.pop('type'))
-        outputs = parallel_test(
-            model_type,
-            model_args,
-            args.checkpoint,
-            dataset,
-            _data_func,
-            range(args.gpus),
-            workers_per_gpu=args.proc_per_gpu)
-
-    if args.out:
-        print('writing results to {}'.format(args.out))
+        model = MMDistributedDataParallel(model.cuda())
+        outputs = multi_gpu_test(model, data_loader, args.tmpdir)
+
+    rank, _ = get_dist_info()
+    if args.out and rank == 0:
+        print('\nwriting results to {}'.format(args.out))
         mmcv.dump(outputs, args.out)
         eval_types = args.eval
         if eval_types:
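To make the gather step in `collect_results` above concrete, here is a standalone sketch (not part of the patch, with made-up part contents) of how `zip(*part_list)` interleaves the rank-strided partitions back into dataset order and how the padded tail is dropped:

```python
# Illustration of the result re-ordering performed by collect_results;
# the part contents and dataset size are made up for this example.
size = 10                                 # true dataset length
part_list = [
    ['img0', 'img3', 'img6', 'img9'],     # results written by rank 0
    ['img1', 'img4', 'img7', 'img0'],     # rank 1 (last entry is padding)
    ['img2', 'img5', 'img8', 'img1'],     # rank 2 (last entry is padding)
]

ordered_results = []
for res in zip(*part_list):               # one sample from each rank per step
    ordered_results.extend(list(res))
ordered_results = ordered_results[:size]  # drop padded duplicates at the tail
print(ordered_results)                    # ['img0', 'img1', ..., 'img9']
```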