From fd67644c369f5aebc38747f23f86323918e5575f Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Wed, 10 Apr 2019 22:13:05 -0700
Subject: [PATCH] use torch.distributed.barrier() instead of the
 self-implemented one

---
 mmdet/core/evaluation/eval_hooks.py | 39 +++--------------------------
 1 file changed, 4 insertions(+), 35 deletions(-)

diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py
index ba6adbb..140c1ed 100644
--- a/mmdet/core/evaluation/eval_hooks.py
+++ b/mmdet/core/evaluation/eval_hooks.py
@@ -1,11 +1,10 @@
 import os
 import os.path as osp
-import shutil
-import time
 
 import mmcv
 import numpy as np
 import torch
+import torch.distributed as dist
 from mmcv.runner import Hook, obj_from_dict
 from mmcv.parallel import scatter, collate
 from pycocotools.cocoeval import COCOeval
@@ -29,36 +28,6 @@ class DistEvalHook(Hook):
                 'dataset must be a Dataset object or a dict, not {}'.format(
                     type(dataset)))
         self.interval = interval
-        self.lock_dir = None
-
-    def _barrier(self, rank, world_size):
-        """Due to some issues with `torch.distributed.barrier()`, we have to
-        implement this ugly barrier function.
-        """
-        if rank == 0:
-            for i in range(1, world_size):
-                tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
-                while not (osp.exists(tmp)):
-                    time.sleep(1)
-            for i in range(1, world_size):
-                tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
-                os.remove(tmp)
-        else:
-            tmp = osp.join(self.lock_dir, '{}.pkl'.format(rank))
-            mmcv.dump([], tmp)
-            while osp.exists(tmp):
-                time.sleep(1)
-
-    def before_run(self, runner):
-        self.lock_dir = osp.join(runner.work_dir, '.lock_map_hook')
-        if runner.rank == 0:
-            if osp.exists(self.lock_dir):
-                shutil.rmtree(self.lock_dir)
-            mmcv.mkdir_or_exist(self.lock_dir)
-
-    def after_run(self, runner):
-        if runner.rank == 0:
-            shutil.rmtree(self.lock_dir)
 
     def after_train_epoch(self, runner):
         if not self.every_n_epochs(runner, self.interval):
@@ -84,7 +53,7 @@ class DistEvalHook(Hook):
 
         if runner.rank == 0:
             print('\n')
-            self._barrier(runner.rank, runner.world_size)
+            dist.barrier()
             for i in range(1, runner.world_size):
                 tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(i))
                 tmp_results = mmcv.load(tmp_file)
@@ -96,8 +65,8 @@ class DistEvalHook(Hook):
             tmp_file = osp.join(runner.work_dir,
                                 'temp_{}.pkl'.format(runner.rank))
             mmcv.dump(results, tmp_file)
-            self._barrier(runner.rank, runner.world_size)
-        self._barrier(runner.rank, runner.world_size)
+            dist.barrier()
+        dist.barrier()
 
     def evaluate(self):
         raise NotImplementedError
-- 
GitLab