From c52cdd627833fefc6bfc5ce5654c137e73c1983b Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Tue, 14 May 2019 20:04:45 -0700
Subject: [PATCH] Support param-wise optimizer settings (#634)

* support param-wise optimizer settings

* fix outdated docstring

* minor fix for param checking
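
An example config enabling the new `paramwise_options` field (values here are
illustrative; only bias_lr_mult, bias_decay_mult and norm_decay_mult are
recognized):

    optimizer = dict(
        type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001,
        paramwise_options=dict(
            bias_lr_mult=2., bias_decay_mult=0., norm_decay_mult=0.))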
---
 mmdet/apis/train.py | 104 +++++++++++++++++++++++++++++++++++++-------
 1 file changed, 88 insertions(+), 16 deletions(-)

diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 152aad8..34b6326 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -1,9 +1,10 @@
 from __future__ import division
 
+import re
 from collections import OrderedDict
 
 import torch
-from mmcv.runner import Runner, DistSamplerSeedHook
+from mmcv.runner import Runner, DistSamplerSeedHook, obj_from_dict
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 
 from mmdet import datasets
@@ -38,8 +39,9 @@ def batch_processor(model, data, train_mode):
     losses = model(**data)
     loss, log_vars = parse_losses(losses)
 
-    outputs = dict(
-        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
+    outputs = dict(loss=loss,
+                   log_vars=log_vars,
+                   num_samples=len(data['img'].data))
 
     return outputs
 
@@ -60,19 +62,89 @@ def train_detector(model,
         _non_dist_train(model, dataset, cfg, validate=validate)
 
 
+def build_optimizer(model, optimizer_cfg):
+    """Build optimizer from configs.
+
+    Args:
+        model (:obj:`nn.Module`): The model with parameters to be optimized.
+        optimizer_cfg (dict): The config dict of the optimizer.
+            Required fields are:
+                - type: class name of the optimizer.
+                - lr: base learning rate.
+            Optional fields are:
+                - any arguments of the corresponding optimizer type, e.g.,
+                  weight_decay, momentum, etc.
+                - paramwise_options: a dict with 3 accepted fields
+                  (bias_lr_mult, bias_decay_mult, norm_decay_mult).
+                  The lr and weight decay of all bias parameters, except
+                  those of normalization layers, will be multiplied by
+                  `bias_lr_mult` and `bias_decay_mult` respectively, and the
+                  weight decay of normalization layer weights and biases
+                  will be multiplied by `norm_decay_mult`.
+
+    Returns:
+        torch.optim.Optimizer: The initialized optimizer.
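+
+    Example:
+        >>> # illustrative values; any optimizer class in torch.optim works
+        >>> optimizer_cfg = dict(
+        ...     type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001,
+        ...     paramwise_options=dict(bias_lr_mult=2., bias_decay_mult=0.))
+        >>> optimizer = build_optimizer(model, optimizer_cfg)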
+    """
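+    # models wrapped by (MM)DataParallel keep the real model in `.module`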
+    if hasattr(model, 'module'):
+        model = model.module
+
+    optimizer_cfg = optimizer_cfg.copy()
+    paramwise_options = optimizer_cfg.pop('paramwise_options', None)
+    # if no paramwise option is specified, just use the global setting
+    if paramwise_options is None:
+        return obj_from_dict(optimizer_cfg, torch.optim,
+                             dict(params=model.parameters()))
+    else:
+        assert isinstance(paramwise_options, dict)
+        # get base lr and weight decay
+        base_lr = optimizer_cfg['lr']
+        base_wd = optimizer_cfg.get('weight_decay', None)
+        # weight_decay must be explicitly specified if mult is specified
+        if ('bias_decay_mult' in paramwise_options
+                or 'norm_decay_mult' in paramwise_options):
+            assert base_wd is not None
+        # get param-wise options
+        bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
+        bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
+        norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)
+        # set param-wise lr and weight decay
+        params = []
+        for name, param in model.named_parameters():
+            if not param.requires_grad:
+                continue
+
+            param_group = {'params': [param]}
+            # for norm layers, overwrite the weight decay of weight and bias
+            # TODO: obtain the norm layer prefixes dynamically
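+            # e.g. parameter names like 'bn1.weight' or 'gn.bias' match here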
+            if re.search(r'(bn|gn)(\d+)?\.(weight|bias)', name):
+                if base_wd is not None:
+                    param_group['weight_decay'] = base_wd * norm_decay_mult
+            # for other layers, overwrite both lr and weight decay of bias
+            elif name.endswith('.bias'):
+                param_group['lr'] = base_lr * bias_lr_mult
+                if base_wd is not None:
+                    param_group['weight_decay'] = base_wd * bias_decay_mult
+            # otherwise use the global settings
+
+            params.append(param_group)
+
+        optimizer_cls = getattr(torch.optim, optimizer_cfg.pop('type'))
+        return optimizer_cls(params, **optimizer_cfg)
+
+
 def _dist_train(model, dataset, cfg, validate=False):
     # prepare data loaders
     data_loaders = [
-        build_dataloader(
-            dataset,
-            cfg.data.imgs_per_gpu,
-            cfg.data.workers_per_gpu,
-            dist=True)
+        build_dataloader(dataset,
+                         cfg.data.imgs_per_gpu,
+                         cfg.data.workers_per_gpu,
+                         dist=True)
     ]
     # put model on gpus
     model = MMDistributedDataParallel(model.cuda())
     # build runner
-    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
+    optimizer = build_optimizer(model, cfg.optimizer)
+    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                     cfg.log_level)
     # register hooks
     optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
@@ -102,17 +174,17 @@ def _dist_train(model, dataset, cfg, validate=False):
 def _non_dist_train(model, dataset, cfg, validate=False):
     # prepare data loaders
     data_loaders = [
-        build_dataloader(
-            dataset,
-            cfg.data.imgs_per_gpu,
-            cfg.data.workers_per_gpu,
-            cfg.gpus,
-            dist=False)
+        build_dataloader(dataset,
+                         cfg.data.imgs_per_gpu,
+                         cfg.data.workers_per_gpu,
+                         cfg.gpus,
+                         dist=False)
     ]
     # put model on gpus
     model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
     # build runner
-    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
+    optimizer = build_optimizer(model, cfg.optimizer)
+    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                     cfg.log_level)
     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config)
-- 
GitLab