From 8f70f9df028a4e67057c8da46aefd80f362b1738 Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Mon, 3 Jun 2019 21:33:28 +0800 Subject: [PATCH] bug fix for testing visualization (#747) --- tools/test.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tools/test.py b/tools/test.py index 9719f45..ada5e60 100644 --- a/tools/test.py +++ b/tools/test.py @@ -63,8 +63,10 @@ def collect_results(result_part, size, tmpdir=None): if tmpdir is None: MAX_LEN = 512 # 32 is whitespace - dir_tensor = torch.full( - (MAX_LEN, ), 32, dtype=torch.uint8, device='cuda') + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') if rank == 0: tmpdir = tempfile.mkdtemp() tmpdir = torch.tensor( @@ -152,7 +154,13 @@ def main(): # build the model and load checkpoint model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) - load_checkpoint(model, args.checkpoint, map_location='cpu') + checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') + # old versions did not save class info in checkpoints, this walkaround is + # for backward compatibility + if 'CLASSES' in checkpoint['meta']: + model.CLASSES = checkpoint['meta']['CLASSES'] + else: + model.CLASSES = dataset.CLASSES if not distributed: model = MMDataParallel(model, device_ids=[0]) -- GitLab