Skip to content
Snippets Groups Projects
Commit fd67644c authored by Kai Chen's avatar Kai Chen
Browse files

use torch.distributed.barrier() instead of the self-implemented one

parent 64b1c8b6
No related branches found
No related tags found
No related merge requests found
import os
import os.path as osp
import shutil
import time
import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import Hook, obj_from_dict
from mmcv.parallel import scatter, collate
from pycocotools.cocoeval import COCOeval
......@@ -29,36 +28,6 @@ class DistEvalHook(Hook):
'dataset must be a Dataset object or a dict, not {}'.format(
type(dataset)))
self.interval = interval
self.lock_dir = None
def _barrier(self, rank, world_size):
"""Due to some issues with `torch.distributed.barrier()`, we have to
implement this ugly barrier function.
"""
if rank == 0:
for i in range(1, world_size):
tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
while not (osp.exists(tmp)):
time.sleep(1)
for i in range(1, world_size):
tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
os.remove(tmp)
else:
tmp = osp.join(self.lock_dir, '{}.pkl'.format(rank))
mmcv.dump([], tmp)
while osp.exists(tmp):
time.sleep(1)
def before_run(self, runner):
self.lock_dir = osp.join(runner.work_dir, '.lock_map_hook')
if runner.rank == 0:
if osp.exists(self.lock_dir):
shutil.rmtree(self.lock_dir)
mmcv.mkdir_or_exist(self.lock_dir)
def after_run(self, runner):
if runner.rank == 0:
shutil.rmtree(self.lock_dir)
def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval):
......@@ -84,7 +53,7 @@ class DistEvalHook(Hook):
if runner.rank == 0:
print('\n')
self._barrier(runner.rank, runner.world_size)
dist.barrier()
for i in range(1, runner.world_size):
tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(i))
tmp_results = mmcv.load(tmp_file)
......@@ -96,8 +65,8 @@ class DistEvalHook(Hook):
tmp_file = osp.join(runner.work_dir,
'temp_{}.pkl'.format(runner.rank))
mmcv.dump(results, tmp_file)
self._barrier(runner.rank, runner.world_size)
self._barrier(runner.rank, runner.world_size)
dist.barrier()
dist.barrier()
def evaluate(self):
raise NotImplementedError
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment