diff --git a/tools/train.py b/tools/train.py index 742bd1bbf03e6890285fe5d5a3ed8947f4927fb8..2111546a39d7f28ffa375029106e191f89166e24 100644 --- a/tools/train.py +++ b/tools/train.py @@ -34,7 +34,7 @@ def parse_losses(losses): return loss, log_vars -def batch_processor(model, data, train_mode, args=None): +def batch_processor(model, data, train_mode): losses = model(**data) loss, log_vars = parse_losses(losses) @@ -115,7 +115,7 @@ def main(): runner.resume(cfg.resume_from) elif cfg.load_from: runner.load_checkpoint(cfg.load_from) - runner.run(data_loaders, cfg.workflow, cfg.total_epochs, args=args) + runner.run(data_loaders, cfg.workflow, cfg.total_epochs) if __name__ == '__main__':