diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index f68de4f2f6078017d080fae196cbd07a7a6b99f2..97c0dc69ebf1b1815ef99bfd51aa7f60ca4e9fd0 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -1,4 +1,3 @@ -import logging import random import re from collections import OrderedDict @@ -7,35 +6,14 @@ import numpy as np import torch import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel -from mmcv.runner import (DistSamplerSeedHook, Runner, get_dist_info, - obj_from_dict) +from mmcv.runner import DistSamplerSeedHook, Runner, obj_from_dict from mmdet import datasets from mmdet.core import (CocoDistEvalmAPHook, CocoDistEvalRecallHook, DistEvalmAPHook, DistOptimizerHook, Fp16OptimizerHook) from mmdet.datasets import DATASETS, build_dataloader from mmdet.models import RPN - - -def get_root_logger(log_file=None, log_level=logging.INFO): - logger = logging.getLogger('mmdet') - # if the logger has been initialized, just return it - if logger.hasHandlers(): - return logger - - logging.basicConfig( - format='%(asctime)s - %(levelname)s - %(message)s', level=log_level) - rank, _ = get_dist_info() - if rank != 0: - logger.setLevel('ERROR') - elif log_file is not None: - file_handler = logging.FileHandler(log_file, 'w') - file_handler.setFormatter( - logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) - file_handler.setLevel(log_level) - logger.addHandler(file_handler) - - return logger +from mmdet.utils import get_root_logger def set_random_seed(seed, deterministic=False): diff --git a/mmdet/core/evaluation/mean_ap.py b/mmdet/core/evaluation/mean_ap.py index eb877ee916dd2bbd5c7dbdb594ffa8d3a808ce0b..4e3cd5d07654d9cd7b9ba4f7b55b71f2f34f31c9 100644 --- a/mmdet/core/evaluation/mean_ap.py +++ b/mmdet/core/evaluation/mean_ap.py @@ -1,10 +1,10 @@ -import logging from multiprocessing import Pool import mmcv import numpy as np from terminaltables import AsciiTable +from mmdet.utils import print_log from .bbox_overlaps import bbox_overlaps from .class_names import get_classes @@ -268,7 +268,7 @@ def eval_map(det_results, scale_ranges=None, iou_thr=0.5, dataset=None, - logger='default', + logger=None, nproc=4): """Evaluate mAP of a dataset. @@ -291,11 +291,8 @@ def eval_map(det_results, dataset (list[str] | str | None): Dataset name or dataset classes, there are minor differences in metrics for different datsets, e.g. "voc07", "imagenet_det", etc. Default: None. - logger (logging.Logger | 'print' | None): The way to print the mAP - summary. If a Logger is specified, then the summary will be logged - with `logger.info()`; if set to "print", then it will be simply - printed to stdout; if set to None, then no information will be - printed. Default: 'print'. + logger (logging.Logger | str | None): The way to print the mAP + summary. See `mmdet.utils.print_log()` for details. Default: None. nproc (int): Processes used for computing TP and FP. Default: 4. @@ -383,9 +380,9 @@ def eval_map(det_results, if cls_result['num_gts'] > 0: aps.append(cls_result['ap']) mean_ap = np.array(aps).mean().item() if aps else 0.0 - if logger is not None: - print_map_summary( - mean_ap, eval_results, dataset, area_ranges, logger=logger) + + print_map_summary( + mean_ap, eval_results, dataset, area_ranges, logger=logger) return mean_ap, eval_results @@ -405,18 +402,12 @@ def print_map_summary(mean_ap, results (list[dict]): Calculated from `eval_map()`. dataset (list[str] | str | None): Dataset name or dataset classes. scale_ranges (list[tuple] | None): Range of scales to be evaluated. - logger (logging.Logger | 'print' | None): The way to print the mAP - summary. If a Logger is specified, then the summary will be logged - with `logger.info()`; if set to "print", then it will be simply - printed to stdout; if set to None, then no information will be - printed. Default: 'print'. + logger (logging.Logger | str | None): The way to print the mAP + summary. See `mmdet.utils.print_log()` for details. Default: None. """ - def _print(content): - if logger == 'print': - print(content) - elif isinstance(logger, logging.Logger): - logger.info(content) + if logger == 'silent': + return if isinstance(results[0]['ap'], np.ndarray): num_scales = len(results[0]['ap']) @@ -426,9 +417,6 @@ def print_map_summary(mean_ap, if scale_ranges is not None: assert len(scale_ranges) == num_scales - assert logger is None or logger == 'print' or isinstance( - logger, logging.Logger) - num_classes = len(results) recalls = np.zeros((num_scales, num_classes), dtype=np.float32) @@ -453,7 +441,7 @@ def print_map_summary(mean_ap, header = ['class', 'gts', 'dets', 'recall', 'ap'] for i in range(num_scales): if scale_ranges is not None: - _print('Scale range ', scale_ranges[i]) + print_log('Scale range {}'.format(scale_ranges[i]), logger=logger) table_data = [header] for j in range(num_classes): row_data = [ @@ -464,4 +452,4 @@ def print_map_summary(mean_ap, table_data.append(['mAP', '', '', '', '{:.3f}'.format(mean_ap[i])]) table = AsciiTable(table_data) table.inner_footing_row_border = True - _print('\n' + table.table) + print_log('\n' + table.table, logger=logger) diff --git a/mmdet/models/backbones/hrnet.py b/mmdet/models/backbones/hrnet.py index 7b7b469041b8ccb497f4bd9426e85699b2067e6f..0f7a082cf0e5b00afcd67f078daf0a819f63ee8c 100644 --- a/mmdet/models/backbones/hrnet.py +++ b/mmdet/models/backbones/hrnet.py @@ -3,6 +3,7 @@ from mmcv.cnn import constant_init, kaiming_init from mmcv.runner import load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm +from mmdet.utils import get_root_logger from ..registry import BACKBONES from ..utils import build_conv_layer, build_norm_layer from .resnet import BasicBlock, Bottleneck @@ -460,7 +461,6 @@ class HRNet(nn.Module): def init_weights(self, pretrained=None): if isinstance(pretrained, str): - from mmdet.apis import get_root_logger logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py index d2f3b7c517eae3c0c6efeb9d7e4ab65389c840aa..ab6913e82720477940b2dfbfb6d144a1529740a2 100644 --- a/mmdet/models/backbones/resnet.py +++ b/mmdet/models/backbones/resnet.py @@ -6,6 +6,7 @@ from torch.nn.modules.batchnorm import _BatchNorm from mmdet.models.plugins import GeneralizedAttention from mmdet.ops import ContextBlock +from mmdet.utils import get_root_logger from ..registry import BACKBONES from ..utils import build_conv_layer, build_norm_layer @@ -468,7 +469,6 @@ class ResNet(nn.Module): def init_weights(self, pretrained=None): if isinstance(pretrained, str): - from mmdet.apis import get_root_logger logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: diff --git a/mmdet/models/backbones/ssd_vgg.py b/mmdet/models/backbones/ssd_vgg.py index 8cbe42cca9cb37855b1433a45a5a8110465f543c..c7615e2a70dcd3f4a153a4dd01038412acdb94bf 100644 --- a/mmdet/models/backbones/ssd_vgg.py +++ b/mmdet/models/backbones/ssd_vgg.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from mmcv.cnn import VGG, constant_init, kaiming_init, normal_init, xavier_init from mmcv.runner import load_checkpoint +from mmdet.utils import get_root_logger from ..registry import BACKBONES @@ -73,7 +74,6 @@ class SSDVGG(VGG): def init_weights(self, pretrained=None): if isinstance(pretrained, str): - from mmdet.apis import get_root_logger logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py index 60662929450a60e456e224f012986c05357cc6b6..8d38b23fc5eb01db1e62384fa52f8ada4b795113 100644 --- a/mmdet/models/detectors/base.py +++ b/mmdet/models/detectors/base.py @@ -6,6 +6,7 @@ import pycocotools.mask as maskUtils import torch.nn as nn from mmdet.core import auto_fp16, get_classes, tensor2imgs +from mmdet.utils import print_log class BaseDetector(nn.Module, metaclass=ABCMeta): @@ -71,9 +72,7 @@ class BaseDetector(nn.Module, metaclass=ABCMeta): def init_weights(self, pretrained=None): if pretrained is not None: - from mmdet.apis import get_root_logger - logger = get_root_logger() - logger.info('load model from: {}'.format(pretrained)) + print_log('load model from: {}'.format(pretrained), logger='root') async def aforward_test(self, *, img, img_meta, **kwargs): for var, name in [(img, 'img'), (img_meta, 'img_meta')]: diff --git a/mmdet/models/shared_heads/res_layer.py b/mmdet/models/shared_heads/res_layer.py index 33b962bb7a12d521587803253d46dcb2a787a74c..e1a1ba0d76b34d6199ba397916e6b29ade8e0a74 100644 --- a/mmdet/models/shared_heads/res_layer.py +++ b/mmdet/models/shared_heads/res_layer.py @@ -3,6 +3,7 @@ from mmcv.cnn import constant_init, kaiming_init from mmcv.runner import load_checkpoint from mmdet.core import auto_fp16 +from mmdet.utils import get_root_logger from ..backbones import ResNet, make_res_layer from ..registry import SHARED_HEADS @@ -45,7 +46,6 @@ class ResLayer(nn.Module): def init_weights(self, pretrained=None): if isinstance(pretrained, str): - from mmdet.apis import get_root_logger logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: diff --git a/mmdet/ops/dcn/deform_conv.py b/mmdet/ops/dcn/deform_conv.py index 24e5b08cbfd2a9581fae59fe56ed419ea9441e31..5ba5a5e8fc0c810006282b6072227480b759c1f5 100644 --- a/mmdet/ops/dcn/deform_conv.py +++ b/mmdet/ops/dcn/deform_conv.py @@ -6,6 +6,7 @@ from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair, _single +from mmdet.utils import print_log from . import deform_conv_cuda @@ -297,10 +298,10 @@ class DeformConvPack(DeformConv): '_offset.bias') if version is not None and version > 1: - from mmdet.apis import get_root_logger - logger = get_root_logger() - logger.info('DeformConvPack {} is upgraded to version 2.'.format( - prefix.rstrip('.'))) + print_log( + 'DeformConvPack {} is upgraded to version 2.'.format( + prefix.rstrip('.')), + logger='root') super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, @@ -420,11 +421,10 @@ class ModulatedDeformConvPack(ModulatedDeformConv): '_offset.bias') if version is not None and version > 1: - from mmdet.apis import get_root_logger - logger = get_root_logger() - logger.info( + print_log( 'ModulatedDeformConvPack {} is upgraded to version 2.'.format( - prefix.rstrip('.'))) + prefix.rstrip('.')), + logger='root') super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py index f65e3b2fbcf1d87c6c184f4ef63e9ad40537a23e..537a34a13ac6f7b210fce536a3f38aec0c83e203 100644 --- a/mmdet/utils/__init__.py +++ b/mmdet/utils/__init__.py @@ -1,4 +1,8 @@ from .flops_counter import get_model_complexity_info +from .logger import get_root_logger, print_log from .registry import Registry, build_from_cfg -__all__ = ['Registry', 'build_from_cfg', 'get_model_complexity_info'] +__all__ = [ + 'Registry', 'build_from_cfg', 'get_model_complexity_info', + 'get_root_logger', 'print_log' +] diff --git a/mmdet/utils/logger.py b/mmdet/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6a1396b95fe3e66530cfd4144a241b0c692b6f --- /dev/null +++ b/mmdet/utils/logger.py @@ -0,0 +1,66 @@ +import logging + +from mmcv.runner import get_dist_info + + +def get_root_logger(log_file=None, log_level=logging.INFO): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name, + e.g., "mmdet". + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(__name__.split('.')[0]) # i.e., mmdet + # if the logger has been initialized, just return it + if logger.hasHandlers(): + return logger + + format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + logging.basicConfig(format=format_str, level=log_level) + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + file_handler = logging.FileHandler(log_file, 'w') + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + + return logger + + +def print_log(msg, logger=None, level=logging.INFO): + """Print a log message. + + Args: + msg (str): The message to be logged. + logger (logging.Logger | str | None): The logger to be used. Some + special loggers are: + - "root": the root logger obtained with `get_root_logger()`. + - "silent": no message will be printed. + - None: The `print()` method will be used to print log messages. + level (int): Logging level. Only available when `logger` is a Logger + object or "root". + """ + if logger is None: + print(msg) + elif logger == 'root': + _logger = get_root_logger() + _logger.log(level, msg) + elif isinstance(logger, logging.Logger): + logger.log(level, msg) + elif logger != 'silent': + raise TypeError( + 'logger should be either a logging.Logger object, "root", ' + '"silent" or None, but got {}'.format(logger)) diff --git a/tools/train.py b/tools/train.py index 2931a28c92858bdd3cbb4fe866e7a291bd5f7f7d..7f89795d5d6a2660db2b4bfb04123134cd210481 100644 --- a/tools/train.py +++ b/tools/train.py @@ -10,9 +10,10 @@ from mmcv import Config from mmcv.runner import init_dist from mmdet import __version__ -from mmdet.apis import get_root_logger, set_random_seed, train_detector +from mmdet.apis import set_random_seed, train_detector from mmdet.datasets import build_dataset from mmdet.models import build_detector +from mmdet.utils import get_root_logger def parse_args():