diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 47320c69eaa2557730aa3d9c34a82c3913bf6df0..c146b04feab9e354fb7953a253a2711e7c778ca2 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -24,15 +24,24 @@ def set_random_seed(seed):
     torch.cuda.manual_seed_all(seed)
 
 
-def get_root_logger(log_level=logging.INFO):
-    logger = logging.getLogger()
-    if not logger.hasHandlers():
-        logging.basicConfig(
-            format='%(asctime)s - %(levelname)s - %(message)s',
-            level=log_level)
+def get_root_logger(log_file=None, log_level=logging.INFO):
+    logger = logging.getLogger('mmdet')
+    # if the logger has been initialized, just return it
+    if logger.hasHandlers():
+        return logger
+
+    logging.basicConfig(
+        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
     rank, _ = get_dist_info()
     if rank != 0:
         logger.setLevel('ERROR')
+    elif log_file is not None:
+        file_handler = logging.FileHandler(log_file, 'w')
+        file_handler.setFormatter(
+            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+        file_handler.setLevel(log_level)
+        logger.addHandler(file_handler)
+
     return logger
 
 
@@ -75,15 +84,26 @@ def train_detector(model,
                    cfg,
                    distributed=False,
                    validate=False,
-                   logger=None):
-    if logger is None:
-        logger = get_root_logger(cfg.log_level)
+                   timestamp=None):
+    logger = get_root_logger(cfg.log_level)
 
     # start training
     if distributed:
-        _dist_train(model, dataset, cfg, validate=validate)
+        _dist_train(
+            model,
+            dataset,
+            cfg,
+            validate=validate,
+            logger=logger,
+            timestamp=timestamp)
     else:
-        _non_dist_train(model, dataset, cfg, validate=validate)
+        _non_dist_train(
+            model,
+            dataset,
+            cfg,
+            validate=validate,
+            logger=logger,
+            timestamp=timestamp)
 
 
 def build_optimizer(model, optimizer_cfg):
@@ -166,7 +186,12 @@ def build_optimizer(model, optimizer_cfg):
     return optimizer_cls(params, **optimizer_cfg)
 
 
-def _dist_train(model, dataset, cfg, validate=False):
+def _dist_train(model,
+                dataset,
+                cfg,
+                validate=False,
+                logger=None,
+                timestamp=None):
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
     data_loaders = [
@@ -179,8 +204,10 @@ def _dist_train(model, dataset, cfg, validate=False):
 
     # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
-    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
-                    cfg.log_level)
+    runner = Runner(
+        model, batch_processor, optimizer, cfg.work_dir, logger=logger)
+    # an ugly workaround to make the .log and .log.json filenames the same
+    runner.timestamp = timestamp
 
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
@@ -218,7 +245,12 @@ def _dist_train(model, dataset, cfg, validate=False):
     runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
 
 
-def _non_dist_train(model, dataset, cfg, validate=False):
+def _non_dist_train(model,
+                    dataset,
+                    cfg,
+                    validate=False,
+                    logger=None,
+                    timestamp=None):
     if validate:
         raise NotImplementedError('Built-in validation is not implemented '
                                   'yet in not-distributed training. Use '
@@ -239,8 +271,10 @@ def _non_dist_train(model, dataset, cfg, validate=False):
 
     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
-    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
-                    cfg.log_level)
+    runner = Runner(
+        model, batch_processor, optimizer, cfg.work_dir, logger=logger)
+    # an ugly workaround to make the .log and .log.json filenames the same
+    runner.timestamp = timestamp
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
     if fp16_cfg is not None:
diff --git a/mmdet/models/backbones/hrnet.py b/mmdet/models/backbones/hrnet.py
index c73fed01d22cc0f174f349e5d097b1ceda9b2867..7b7b469041b8ccb497f4bd9426e85699b2067e6f 100644
--- a/mmdet/models/backbones/hrnet.py
+++ b/mmdet/models/backbones/hrnet.py
@@ -1,5 +1,3 @@
-import logging
-
 import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
@@ -462,7 +460,8 @@ class HRNet(nn.Module):
 
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index ac14bc2d2d39cf27ffd1d0893ae1dbe1831b0e0a..3343c5c504e39b72dcd118bbcc513148e39a7e5b 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -1,5 +1,3 @@
-import logging
-
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import constant_init, kaiming_init
@@ -495,7 +493,8 @@ class ResNet(nn.Module):
 
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
diff --git a/mmdet/models/backbones/ssd_vgg.py b/mmdet/models/backbones/ssd_vgg.py
index b199444b9c128935517f00d113f4e478e71bea77..8cbe42cca9cb37855b1433a45a5a8110465f543c 100644
--- a/mmdet/models/backbones/ssd_vgg.py
+++ b/mmdet/models/backbones/ssd_vgg.py
@@ -1,5 +1,3 @@
-import logging
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -75,7 +73,8 @@ class SSDVGG(VGG):
 
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.features.modules():
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 6a33dab4c59ec8551dc3fa5e5a8965954c063d56..7d0929a760b742662563e88b6a0d8d20e6a3693d 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -1,4 +1,3 @@
-import logging
 from abc import ABCMeta, abstractmethod
 
 import mmcv
@@ -9,11 +8,9 @@ import torch.nn as nn
 
 from mmdet.core import auto_fp16, get_classes, tensor2imgs
 
-class BaseDetector(nn.Module):
+class BaseDetector(nn.Module, metaclass=ABCMeta):
     """Base class for detectors"""
 
-    __metaclass__ = ABCMeta
-
     def __init__(self):
         super(BaseDetector, self).__init__()
         self.fp16_enabled = False
@@ -61,9 +58,8 @@ class BaseDetector(nn.Module):
         """
         pass
 
-    @abstractmethod
     async def async_simple_test(self, img, img_meta, **kwargs):
-        pass
+        raise NotImplementedError
 
     @abstractmethod
     def simple_test(self, img, img_meta, **kwargs):
@@ -75,7 +71,8 @@ class BaseDetector(nn.Module):
 
     def init_weights(self, pretrained=None):
         if pretrained is not None:
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             logger.info('load model from: {}'.format(pretrained))
 
     async def aforward_test(self, *, img, img_meta, **kwargs):
diff --git a/mmdet/models/shared_heads/res_layer.py b/mmdet/models/shared_heads/res_layer.py
index cbc77ac98d72b2a3c28c06e8feb4cc9aa3a109c7..33b962bb7a12d521587803253d46dcb2a787a74c 100644
--- a/mmdet/models/shared_heads/res_layer.py
+++ b/mmdet/models/shared_heads/res_layer.py
@@ -1,5 +1,3 @@
-import logging
-
 import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
@@ -47,7 +45,8 @@ class ResLayer(nn.Module):
 
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
diff --git a/tools/train.py b/tools/train.py
index e3bbbde6a0b76d399a04dac256f9747f480fdb64..5958d2409b810c344c867a59ff173a58d8b881d3 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,7 +1,10 @@
 from __future__ import division
 import argparse
 import os
+import os.path as osp
+import time
 
+import mmcv
 import torch
 from mmcv import Config
 from mmcv.runner import init_dist
@@ -71,11 +74,17 @@ def main():
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)
 
-    # init logger before other steps
-    logger = get_root_logger(cfg.log_level)
+    # create work_dir
+    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+    # init the logger before other steps
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
+    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
+
+    # log some basic info
     logger.info('Distributed training: {}'.format(distributed))
     logger.info('MMDetection Version: {}'.format(__version__))
-    logger.info('Config: {}'.format(cfg.text))
+    logger.info('Config:\n{}'.format(cfg.text))
 
     # set random seeds
     if args.seed is not None:
@@ -103,7 +112,7 @@ def main():
         cfg,
         distributed=distributed,
         validate=args.validate,
-        logger=logger)
+        timestamp=timestamp)
 
 
 if __name__ == '__main__':
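
Usage note: a minimal sketch of how the pieces above are meant to fit together in a downstream script, mirroring the flow in tools/train.py. The work_dir value is a hypothetical placeholder (normally it comes from cfg.work_dir); everything else uses only the functions touched in this diff.

# Sketch: exercising get_root_logger() and the shared timestamp introduced
# above. `work_dir` is a hypothetical stand-in for cfg.work_dir.
import logging
import os.path as osp
import time

import mmcv

from mmdet.apis import get_root_logger

work_dir = './work_dirs/example'
mmcv.mkdir_or_exist(osp.abspath(work_dir))

# One timestamp names the .log file here and, passed on as runner.timestamp,
# the .log.json file that the runner's logger hooks write alongside it.
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(work_dir, '{}.log'.format(timestamp))
logger = get_root_logger(log_file=log_file, log_level=logging.INFO)
logger.info('logged to the console and to {}'.format(log_file))

# Repeated calls are no-ops that return the same named logger, which is why
# init_weights() in the backbones can fetch it lazily at load time.
assert get_root_logger() is logging.getLogger('mmdet')

Two design points worth noting: the local `from mmdet.apis import get_root_logger` inside each `init_weights` (rather than a top-level import) is presumably there to avoid a circular import between mmdet.models and mmdet.apis; and since workers with rank != 0 are set to ERROR level and never receive a file handler, only the rank-0 process writes the log file.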