diff --git a/tools/test.py b/tools/test.py index 9719f45e3f881a7632bbd7f99187efe3181b35a3..ada5e607d2ca8bc02681c2d999be913e8328d5bd 100644 --- a/tools/test.py +++ b/tools/test.py @@ -63,8 +63,10 @@ def collect_results(result_part, size, tmpdir=None): if tmpdir is None: MAX_LEN = 512 # 32 is whitespace - dir_tensor = torch.full( - (MAX_LEN, ), 32, dtype=torch.uint8, device='cuda') + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') if rank == 0: tmpdir = tempfile.mkdtemp() tmpdir = torch.tensor( @@ -152,7 +154,13 @@ def main(): # build the model and load checkpoint model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) - load_checkpoint(model, args.checkpoint, map_location='cpu') + checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') + # old versions did not save class info in checkpoints, this workaround is + # for backward compatibility + if 'CLASSES' in checkpoint['meta']: + model.CLASSES = checkpoint['meta']['CLASSES'] + else: + model.CLASSES = dataset.CLASSES if not distributed: model = MMDataParallel(model, device_ids=[0])