Skip to content
Snippets Groups Projects
Commit bc94b212 authored by Demetris Marnerides's avatar Demetris Marnerides Committed by Kai Chen
Browse files

Allowing validation dataset for computing validation loss (#1093)

* Allowing validation dataset for computing validation loss

* added validation dataset

* fixing datasets errors

* Fixing linting errors
parent 043efc31
No related branches found
No related tags found
No related merge requests found
......@@ -136,12 +136,11 @@ def build_optimizer(model, optimizer_cfg):
def _dist_train(model, dataset, cfg, validate=False):
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
data_loaders = [
build_dataloader(
dataset,
cfg.data.imgs_per_gpu,
cfg.data.workers_per_gpu,
dist=True)
ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True)
for ds in dataset
]
# put model on gpus
model = MMDistributedDataParallel(model.cuda())
......@@ -189,13 +188,14 @@ def _dist_train(model, dataset, cfg, validate=False):
def _non_dist_train(model, dataset, cfg, validate=False):
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
data_loaders = [
build_dataloader(
dataset,
ds,
cfg.data.imgs_per_gpu,
cfg.data.workers_per_gpu,
cfg.gpus,
dist=False)
dist=False) for ds in dataset
]
# put model on gpus
model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
......
......@@ -83,19 +83,21 @@ def main():
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
train_dataset = build_dataset(cfg.data.train)
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
datasets.append(build_dataset(cfg.data.val))
if cfg.checkpoint_config is not None:
# save mmdet version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmdet_version=__version__,
config=cfg.text,
CLASSES=train_dataset.CLASSES)
CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience
model.CLASSES = train_dataset.CLASSES
model.CLASSES = datasets[0].CLASSES
train_detector(
model,
train_dataset,
datasets,
cfg,
distributed=distributed,
validate=args.validate,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment