From c52cdd627833fefc6bfc5ce5654c137e73c1983b Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Tue, 14 May 2019 20:04:45 -0700
Subject: [PATCH] Support param-wise optimizer settings (#634)

* support param-wise optimizer settings

* fix outdated docstring

* minor fix for param checking
---
 mmdet/apis/train.py | 104 +++++++++++++++++++++++++++++++++++++-------
 1 file changed, 88 insertions(+), 16 deletions(-)

diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 152aad8..34b6326 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -1,9 +1,10 @@
 from __future__ import division
 
+import re
 from collections import OrderedDict
 
 import torch
-from mmcv.runner import Runner, DistSamplerSeedHook
+from mmcv.runner import Runner, DistSamplerSeedHook, obj_from_dict
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 
 from mmdet import datasets
@@ -38,8 +39,9 @@ def batch_processor(model, data, train_mode):
     losses = model(**data)
     loss, log_vars = parse_losses(losses)
 
-    outputs = dict(
-        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
+    outputs = dict(loss=loss,
+                   log_vars=log_vars,
+                   num_samples=len(data['img'].data))
 
     return outputs
 
@@ -60,19 +62,89 @@ def train_detector(model,
         _non_dist_train(model, dataset, cfg, validate=validate)
 
 
+def build_optimizer(model, optimizer_cfg):
+    """Build optimizer from configs.
+
+    Args:
+        model (:obj:`nn.Module`): The model with parameters to be optimized.
+        optimizer_cfg (dict): The config dict of the optimizer.
+            Positional fields are:
+                - type: class name of the optimizer.
+                - lr: base learning rate.
+            Optional fields are:
+                - any arguments of the corresponding optimizer type, e.g.,
+                  weight_decay, momentum, etc.
+                - paramwise_options: a dict with 3 accepted fields
+                  (bias_lr_mult, bias_decay_mult, norm_decay_mult).
+                  `bias_lr_mult` and `bias_decay_mult` will be multiplied to
+                  the lr and weight decay respectively for all bias parameters
+                  (except for the normalization layers), and
+                  `norm_decay_mult` will be multiplied to the weight decay
+                  for all weight and bias parameters of normalization layers.
+
+    Returns:
+        torch.optim.Optimizer: The initialized optimizer.
+    """
+    if hasattr(model, 'module'):
+        model = model.module
+
+    optimizer_cfg = optimizer_cfg.copy()
+    paramwise_options = optimizer_cfg.pop('paramwise_options', None)
+    # if no paramwise option is specified, just use the global setting
+    if paramwise_options is None:
+        return obj_from_dict(optimizer_cfg, torch.optim,
+                             dict(params=model.parameters()))
+    else:
+        assert isinstance(paramwise_options, dict)
+        # get base lr and weight decay
+        base_lr = optimizer_cfg['lr']
+        base_wd = optimizer_cfg.get('weight_decay', None)
+        # weight_decay must be explicitly specified if mult is specified
+        if ('bias_decay_mult' in paramwise_options
+                or 'norm_decay_mult' in paramwise_options):
+            assert base_wd is not None
+        # get param-wise options
+        bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
+        bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
+        norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)
+        # set param-wise lr and weight decay
+        params = []
+        for name, param in model.named_parameters():
+            if not param.requires_grad:
+                continue
+
+            param_group = {'params': [param]}
+            # for norm layers, overwrite the weight decay of weight and bias
+            # TODO: obtain the norm layer prefixes dynamically
+            if re.search(r'(bn|gn)(\d+)?.(weight|bias)', name):
+                if base_wd is not None:
+                    param_group['weight_decay'] = base_wd * norm_decay_mult
+            # for other layers, overwrite both lr and weight decay of bias
+            elif name.endswith('.bias'):
+                param_group['lr'] = base_lr * bias_lr_mult
+                if base_wd is not None:
+                    param_group['weight_decay'] = base_wd * bias_decay_mult
+            # otherwise use the global settings
+
+            params.append(param_group)
+
+        optimizer_cls = getattr(torch.optim, optimizer_cfg.pop('type'))
+        return optimizer_cls(params, **optimizer_cfg)
+
+
 def _dist_train(model, dataset, cfg, validate=False):
     # prepare data loaders
     data_loaders = [
-        build_dataloader(
-            dataset,
-            cfg.data.imgs_per_gpu,
-            cfg.data.workers_per_gpu,
-            dist=True)
+        build_dataloader(dataset,
+                         cfg.data.imgs_per_gpu,
+                         cfg.data.workers_per_gpu,
+                         dist=True)
     ]
     # put model on gpus
     model = MMDistributedDataParallel(model.cuda())
     # build runner
-    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
+    optimizer = build_optimizer(model, cfg.optimizer)
+    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                     cfg.log_level)
     # register hooks
     optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
@@ -102,17 +174,17 @@ def _dist_train(model, dataset, cfg, validate=False):
 def _non_dist_train(model, dataset, cfg, validate=False):
     # prepare data loaders
     data_loaders = [
-        build_dataloader(
-            dataset,
-            cfg.data.imgs_per_gpu,
-            cfg.data.workers_per_gpu,
-            cfg.gpus,
-            dist=False)
+        build_dataloader(dataset,
+                         cfg.data.imgs_per_gpu,
+                         cfg.data.workers_per_gpu,
+                         cfg.gpus,
+                         dist=False)
     ]
     # put model on gpus
     model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
     # build runner
-    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
+    optimizer = build_optimizer(model, cfg.optimizer)
+    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                     cfg.log_level)
     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config)
-- 
GitLab
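
Note (editorial, not part of the patch): a minimal sketch of how the new `paramwise_options` key might appear in `cfg.optimizer`, based on the docstring and asserts above. The SGD type and hyper-parameter values below are placeholders; only the three documented keys are interpreted by `build_optimizer()`.

    # hypothetical cfg.optimizer dict consumed by build_optimizer() above
    optimizer = dict(
        type='SGD',               # any optimizer class name under torch.optim
        lr=0.02,                  # base lr; non-norm biases get lr * bias_lr_mult
        momentum=0.9,
        weight_decay=0.0001,      # must be set when any *_decay_mult is given
        paramwise_options=dict(
            bias_lr_mult=2.,      # 2x lr for bias params outside norm layers
            bias_decay_mult=0.,   # no weight decay for those bias params
            norm_decay_mult=0.))  # no weight decay for norm layer weights/biases

With `paramwise_options` present, `build_optimizer()` expands this into one parameter group per parameter and passes the groups to the chosen `torch.optim` class; without it, the optimizer is built via `obj_from_dict` with `model.parameters()`, so the global lr and weight_decay apply to every parameter.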