From b7894cbdcbe114e3e9efdd1a6a229419a552c807 Mon Sep 17 00:00:00 2001
From: valuefish <valuefish@gmail.com>
Date: Fri, 29 Nov 2019 23:28:58 +0800
Subject: [PATCH] add multi-node distributed test support (#1399)

* add multi-node distributed test support

* fix bug in htc.py when keep_all_stages is turned on

* remove an unused import in test.py

* reformat code in test.py

* support both cpu & gpu for gathering

* reformat

* clean code, add doc

* add docstring

* reformat docstring
---
 tools/test.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 64 insertions(+), 5 deletions(-)

diff --git a/tools/test.py b/tools/test.py
index e3ff487..64dd733 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -1,6 +1,7 @@
 import argparse
 import os
 import os.path as osp
+import pickle
 import shutil
 import tempfile
 
@@ -35,7 +36,25 @@ def single_gpu_test(model, data_loader, show=False):
     return results
 
 
-def multi_gpu_test(model, data_loader, tmpdir=None):
+def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+    """Test model with multiple gpus.
+
+    This method tests the model with multiple gpus and collects the results
+    under two different modes: gpu and cpu. By setting 'gpu_collect=True'
+    it encodes results to gpu tensors and uses gpu communication for result
+    collection. In cpu mode it saves the results on different gpus to 'tmpdir'
+    and collects them by the rank 0 worker.
+
+    Args:
+        model (nn.Module): Model to be tested.
+        data_loader (DataLoader): PyTorch data loader.
+        tmpdir (str): Path of directory to save the temporary results from
+            different gpus under cpu mode.
+        gpu_collect (bool): Option to use either gpu or cpu to collect results.
+
+    Returns:
+        list: The prediction results.
+    """
     model.eval()
     results = []
     dataset = data_loader.dataset
@@ -53,12 +72,14 @@
                 prog_bar.update()
 
     # collect results from all ranks
-    results = collect_results(results, len(dataset), tmpdir)
-
+    if gpu_collect:
+        results = collect_results_gpu(results, len(dataset))
+    else:
+        results = collect_results_cpu(results, len(dataset), tmpdir)
     return results
 
 
-def collect_results(result_part, size, tmpdir=None):
+def collect_results_cpu(result_part, size, tmpdir=None):
     rank, world_size = get_dist_info()
     # create a tmp dir if it is not specified
     if tmpdir is None:
@@ -100,6 +121,39 @@
     return ordered_results
 
 
+def collect_results_gpu(result_part, size):
+    rank, world_size = get_dist_info()
+    # dump result part to tensor with pickle
+    part_tensor = torch.tensor(
+        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+    # gather all result part tensor shape
+    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_list = [shape_tensor.clone() for _ in range(world_size)]
+    dist.all_gather(shape_list, shape_tensor)
+    # padding result part tensor to max length
+    shape_max = torch.tensor(shape_list).max()
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send[:shape_tensor[0]] = part_tensor
+    part_recv_list = [
+        part_tensor.new_zeros(shape_max) for _ in range(world_size)
+    ]
+    # gather all result part
+    dist.all_gather(part_recv_list, part_send)
+
+    if rank == 0:
+        part_list = []
+        for recv, shape in zip(part_recv_list, shape_list):
+            part_list.append(
+                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+        # sort the results
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        return ordered_results
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description='MMDet test detector')
     parser.add_argument('config', help='test config file path')
@@ -116,6 +170,10 @@
         choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
         help='eval types')
     parser.add_argument('--show', action='store_true', help='show results')
+    parser.add_argument(
+        '--gpu_collect',
+        action='store_true',
+        help='whether to use gpu to collect results')
    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
     parser.add_argument(
         '--launcher',
@@ -184,7 +242,8 @@ def main():
         outputs = single_gpu_test(model, data_loader, args.show)
     else:
         model = MMDistributedDataParallel(model.cuda())
-        outputs = multi_gpu_test(model, data_loader, args.tmpdir)
+        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
+                                 args.gpu_collect)
 
     rank, _ = get_dist_info()
     if args.out and rank == 0:
--
GitLab
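
Note: the core trick in collect_results_gpu above is that
torch.distributed.all_gather only moves tensors, so each rank's arbitrary
Python result objects are pickled into a uint8 CUDA tensor, padded to a
common length, gathered, then trimmed and unpickled. The sketch below
isolates that pattern on its own. It is a minimal illustration, not part of
this patch: the helper name gather_objects is hypothetical, and it assumes
the default process group is already initialized with a CUDA-capable
backend (e.g. NCCL) and that each rank has set its CUDA device.

    import pickle

    import torch
    import torch.distributed as dist


    def gather_objects(obj):
        """All-gather arbitrary picklable objects by way of CUDA tensors."""
        world_size = dist.get_world_size()
        # serialize this rank's object into a uint8 CUDA tensor
        part = torch.tensor(
            bytearray(pickle.dumps(obj)), dtype=torch.uint8, device='cuda')
        # exchange byte lengths so every rank can pad to the same size
        length = torch.tensor([part.numel()], device='cuda')
        lengths = [torch.zeros_like(length) for _ in range(world_size)]
        dist.all_gather(lengths, length)
        max_len = max(int(n.item()) for n in lengths)
        # pad to the max length, gather, then trim and unpickle each part
        padded = torch.zeros(max_len, dtype=torch.uint8, device='cuda')
        padded[:part.numel()] = part
        received = [torch.zeros_like(padded) for _ in range(world_size)]
        dist.all_gather(received, padded)
        return [
            pickle.loads(buf[:int(n.item())].cpu().numpy().tobytes())
            for buf, n in zip(received, lengths)
        ]

With this pattern, rank i's contribution comes back as element i of the
gathered list on every rank, which is why collect_results_gpu can zip the
parts back into dataset order and truncate to `size` to drop padded
samples. A distributed run of tools/test.py is launched as before (e.g.
with torch.distributed.launch and --launcher pytorch), now optionally
adding --gpu_collect to skip the tmpdir round-trip through the filesystem.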