diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py index 222fc8319f1d2d07c4cfdfd7a246131b649489cc..c8c37edcd1f9e00f23a0d32bbd2a0f4c367de041 100644 --- a/mmdet/apis/inference.py +++ b/mmdet/apis/inference.py @@ -34,7 +34,7 @@ def init_detector(config, checkpoint=None, device='cuda:0'): if checkpoint is not None: checkpoint = load_checkpoint(model, checkpoint) if 'CLASSES' in checkpoint['meta']: - model.CLASSES = checkpoint['meta']['classes'] + model.CLASSES = checkpoint['meta']['CLASSES'] else: warnings.warn('Class names are not saved in the checkpoint\'s ' 'meta data, use COCO classes by default.') diff --git a/tools/train.py b/tools/train.py index 3e06d6f56cd6109f6788e84fd7284a852809a612..b8f21d11fa687ffd4e9c5aaf8b5d46b5142c826e 100644 --- a/tools/train.py +++ b/tools/train.py @@ -79,7 +79,7 @@ def main(): cfg.checkpoint_config.meta = dict( mmdet_version=__version__, config=cfg.text, - classes=train_dataset.CLASSES) + CLASSES=train_dataset.CLASSES) # add an attribute for visualization convenience model.CLASSES = train_dataset.CLASSES train_detector(