diff --git a/tools/train.py b/tools/train.py index ee2012fcbdbdfdbfd47f3d73c5092195939ee124..8c3290a89a3aa4686aed87c888136c17e3c69e9f 100644 --- a/tools/train.py +++ b/tools/train.py @@ -35,6 +35,10 @@ def parse_args(): default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--autoscale-lr', + action='store_true', + help='automatically scale lr with the number of gpus') args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) @@ -56,6 +60,10 @@ def main(): cfg.resume_from = args.resume_from cfg.gpus = args.gpus + if args.autoscale_lr: + # apply the linear scaling rule (https://arxiv.org/abs/1706.02677) + cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8 + # init distributed env first, since logger depends on the dist info. if args.launcher == 'none': distributed = False