From fd67644c369f5aebc38747f23f86323918e5575f Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Wed, 10 Apr 2019 22:13:05 -0700 Subject: [PATCH] use torch.distributed.barrier() instead of the self-implemented one --- mmdet/core/evaluation/eval_hooks.py | 39 +++-------------------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py index ba6adbb..140c1ed 100644 --- a/mmdet/core/evaluation/eval_hooks.py +++ b/mmdet/core/evaluation/eval_hooks.py @@ -1,11 +1,10 @@ import os import os.path as osp -import shutil -import time import mmcv import numpy as np import torch +import torch.distributed as dist from mmcv.runner import Hook, obj_from_dict from mmcv.parallel import scatter, collate from pycocotools.cocoeval import COCOeval @@ -29,36 +28,6 @@ class DistEvalHook(Hook): 'dataset must be a Dataset object or a dict, not {}'.format( type(dataset))) self.interval = interval - self.lock_dir = None - - def _barrier(self, rank, world_size): - """Due to some issues with `torch.distributed.barrier()`, we have to - implement this ugly barrier function. - """ - if rank == 0: - for i in range(1, world_size): - tmp = osp.join(self.lock_dir, '{}.pkl'.format(i)) - while not (osp.exists(tmp)): - time.sleep(1) - for i in range(1, world_size): - tmp = osp.join(self.lock_dir, '{}.pkl'.format(i)) - os.remove(tmp) - else: - tmp = osp.join(self.lock_dir, '{}.pkl'.format(rank)) - mmcv.dump([], tmp) - while osp.exists(tmp): - time.sleep(1) - - def before_run(self, runner): - self.lock_dir = osp.join(runner.work_dir, '.lock_map_hook') - if runner.rank == 0: - if osp.exists(self.lock_dir): - shutil.rmtree(self.lock_dir) - mmcv.mkdir_or_exist(self.lock_dir) - - def after_run(self, runner): - if runner.rank == 0: - shutil.rmtree(self.lock_dir) def after_train_epoch(self, runner): if not self.every_n_epochs(runner, self.interval): @@ -84,7 +53,7 @@ class DistEvalHook(Hook): if runner.rank == 0: print('\n') - self._barrier(runner.rank, runner.world_size) + dist.barrier() for i in range(1, runner.world_size): tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(i)) tmp_results = mmcv.load(tmp_file) @@ -96,8 +65,8 @@ class DistEvalHook(Hook): tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(runner.rank)) mmcv.dump(results, tmp_file) - self._barrier(runner.rank, runner.world_size) - self._barrier(runner.rank, runner.world_size) + dist.barrier() + dist.barrier() def evaluate(self): raise NotImplementedError -- GitLab