diff --git a/tools/train.py b/tools/train.py index 2aa681afa35dfe1ad956acc7cfce666344711511..3241174da739d44b9588f152e7fa6a064c10ce56 100644 --- a/tools/train.py +++ b/tools/train.py @@ -134,12 +134,6 @@ def main(): runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir, cfg.log_level) - if args.validate: - val_dataset = obj_from_dict(cfg.data.test, datasets) - data_loaders.append( - build_dataloader(val_dataset, cfg.data.imgs_per_gpu, - cfg.data.workers_per_gpu, cfg.gpus, dist)) - # register hooks optimizer_config = DistOptimizerHook( **cfg.optimizer_config) if dist else cfg.optimizer_config @@ -148,10 +142,11 @@ def main(): if dist: runner.register_hook(DistSamplerSeedHook()) # register eval hooks - if isinstance(model.module, RPN): - runner.register_hook(CocoDistEvalRecallHook(cfg.data.val)) - elif cfg.data.val.type == 'CocoDataset': - runner.register_hook(CocoDistEvalmAPHook(cfg.data.val)) + if args.validate: + if isinstance(model.module, RPN): + runner.register_hook(CocoDistEvalRecallHook(cfg.data.val)) + elif cfg.data.val.type == 'CocoDataset': + runner.register_hook(CocoDistEvalmAPHook(cfg.data.val)) if cfg.resume_from: runner.resume(cfg.resume_from)