diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 47320c69eaa2557730aa3d9c34a82c3913bf6df0..c146b04feab9e354fb7953a253a2711e7c778ca2 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -24,15 +24,24 @@ def set_random_seed(seed):
     torch.cuda.manual_seed_all(seed)
 
 
-def get_root_logger(log_level=logging.INFO):
-    logger = logging.getLogger()
-    if not logger.hasHandlers():
-        logging.basicConfig(
-            format='%(asctime)s - %(levelname)s - %(message)s',
-            level=log_level)
+def get_root_logger(log_file=None, log_level=logging.INFO):
+    logger = logging.getLogger('mmdet')
+    # if the logger has been initialized, just return it
+    if logger.hasHandlers():
+        return logger
+
+    logging.basicConfig(
+        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
     rank, _ = get_dist_info()
     if rank != 0:
         logger.setLevel('ERROR')
+    elif log_file is not None:
+        file_handler = logging.FileHandler(log_file, 'w')
+        file_handler.setFormatter(
+            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+        file_handler.setLevel(log_level)
+        logger.addHandler(file_handler)
+
     return logger
 
 
@@ -75,15 +84,26 @@ def train_detector(model,
                    cfg,
                    distributed=False,
                    validate=False,
-                   logger=None):
-    if logger is None:
-        logger = get_root_logger(cfg.log_level)
+                   timestamp=None):
+    logger = get_root_logger(cfg.log_level)
 
     # start training
     if distributed:
-        _dist_train(model, dataset, cfg, validate=validate)
+        _dist_train(
+            model,
+            dataset,
+            cfg,
+            validate=validate,
+            logger=logger,
+            timestamp=timestamp)
     else:
-        _non_dist_train(model, dataset, cfg, validate=validate)
+        _non_dist_train(
+            model,
+            dataset,
+            cfg,
+            validate=validate,
+            logger=logger,
+            timestamp=timestamp)
 
 
 def build_optimizer(model, optimizer_cfg):
@@ -166,7 +186,12 @@ def build_optimizer(model, optimizer_cfg):
         return optimizer_cls(params, **optimizer_cfg)
 
 
-def _dist_train(model, dataset, cfg, validate=False):
+def _dist_train(model,
+                dataset,
+                cfg,
+                validate=False,
+                logger=None,
+                timestamp=None):
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
     data_loaders = [
@@ -179,8 +204,10 @@ def _dist_train(model, dataset, cfg, validate=False):
 
     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
-    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
-                    cfg.log_level)
+    runner = Runner(
+        model, batch_processor, optimizer, cfg.work_dir, logger=logger)
+    # an ugly workaround to make the .log and .log.json filenames the same
+    runner.timestamp = timestamp
 
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
@@ -218,7 +245,12 @@ def _dist_train(model, dataset, cfg, validate=False):
     runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
 
 
-def _non_dist_train(model, dataset, cfg, validate=False):
+def _non_dist_train(model,
+                    dataset,
+                    cfg,
+                    validate=False,
+                    logger=None,
+                    timestamp=None):
     if validate:
         raise NotImplementedError('Built-in validation is not implemented '
                                   'yet in not-distributed training. Use '
@@ -239,8 +271,10 @@ def _non_dist_train(model, dataset, cfg, validate=False):
 
     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
-    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
-                    cfg.log_level)
+    runner = Runner(
+        model, batch_processor, optimizer, cfg.work_dir, logger=logger)
+    # an ugly workaround to make the .log and .log.json filenames the same
+    runner.timestamp = timestamp
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
     if fp16_cfg is not None:
diff --git a/mmdet/models/backbones/hrnet.py b/mmdet/models/backbones/hrnet.py
index c73fed01d22cc0f174f349e5d097b1ceda9b2867..7b7b469041b8ccb497f4bd9426e85699b2067e6f 100644
--- a/mmdet/models/backbones/hrnet.py
+++ b/mmdet/models/backbones/hrnet.py
@@ -1,5 +1,3 @@
-import logging
-
 import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
@@ -462,7 +460,8 @@ class HRNet(nn.Module):
 
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index ac14bc2d2d39cf27ffd1d0893ae1dbe1831b0e0a..3343c5c504e39b72dcd118bbcc513148e39a7e5b 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -1,5 +1,3 @@
-import logging
-
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import constant_init, kaiming_init
@@ -495,7 +493,8 @@ class ResNet(nn.Module):
 
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
diff --git a/mmdet/models/backbones/ssd_vgg.py b/mmdet/models/backbones/ssd_vgg.py
index b199444b9c128935517f00d113f4e478e71bea77..8cbe42cca9cb37855b1433a45a5a8110465f543c 100644
--- a/mmdet/models/backbones/ssd_vgg.py
+++ b/mmdet/models/backbones/ssd_vgg.py
@@ -1,5 +1,3 @@
-import logging
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -75,7 +73,8 @@ class SSDVGG(VGG):
 
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.features.modules():
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 6a33dab4c59ec8551dc3fa5e5a8965954c063d56..7d0929a760b742662563e88b6a0d8d20e6a3693d 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -1,4 +1,3 @@
-import logging
 from abc import ABCMeta, abstractmethod
 
 import mmcv
@@ -9,11 +8,9 @@ import torch.nn as nn
 from mmdet.core import auto_fp16, get_classes, tensor2imgs
 
 
-class BaseDetector(nn.Module):
+class BaseDetector(nn.Module, metaclass=ABCMeta):
     """Base class for detectors"""
 
-    __metaclass__ = ABCMeta
-
     def __init__(self):
         super(BaseDetector, self).__init__()
         self.fp16_enabled = False
@@ -61,9 +58,8 @@ class BaseDetector(nn.Module):
         """
         pass
 
-    @abstractmethod
     async def async_simple_test(self, img, img_meta, **kwargs):
-        pass
+        raise NotImplementedError
 
     @abstractmethod
     def simple_test(self, img, img_meta, **kwargs):
@@ -75,7 +71,8 @@ class BaseDetector(nn.Module):
 
     def init_weights(self, pretrained=None):
         if pretrained is not None:
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             logger.info('load model from: {}'.format(pretrained))
 
     async def aforward_test(self, *, img, img_meta, **kwargs):
diff --git a/mmdet/models/shared_heads/res_layer.py b/mmdet/models/shared_heads/res_layer.py
index cbc77ac98d72b2a3c28c06e8feb4cc9aa3a109c7..33b962bb7a12d521587803253d46dcb2a787a74c 100644
--- a/mmdet/models/shared_heads/res_layer.py
+++ b/mmdet/models/shared_heads/res_layer.py
@@ -1,5 +1,3 @@
-import logging
-
 import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
@@ -47,7 +45,8 @@ class ResLayer(nn.Module):
 
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
diff --git a/tools/train.py b/tools/train.py
index e3bbbde6a0b76d399a04dac256f9747f480fdb64..5958d2409b810c344c867a59ff173a58d8b881d3 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,7 +1,10 @@
 from __future__ import division
 import argparse
 import os
+import os.path as osp
+import time
 
+import mmcv
 import torch
 from mmcv import Config
 from mmcv.runner import init_dist
@@ -71,11 +74,17 @@ def main():
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)
 
-    # init logger before other steps
-    logger = get_root_logger(cfg.log_level)
+    # create work_dir
+    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+    # init the logger before other steps
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
+    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
+
+    # log some basic info
     logger.info('Distributed training: {}'.format(distributed))
     logger.info('MMDetection Version: {}'.format(__version__))
-    logger.info('Config: {}'.format(cfg.text))
+    logger.info('Config:\n{}'.format(cfg.text))
 
     # set random seeds
     if args.seed is not None:
@@ -103,7 +112,7 @@ def main():
         cfg,
         distributed=distributed,
         validate=args.validate,
-        logger=logger)
+        timestamp=timestamp)
 
 
 if __name__ == '__main__':