Skip to content
Snippets Groups Projects
Unverified Commit 8f70f9df authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

bug fix for testing visualization (#747)

parent 58c415a0
No related branches found
No related tags found
No related merge requests found
......@@ -63,8 +63,10 @@ def collect_results(result_part, size, tmpdir=None):
if tmpdir is None:
MAX_LEN = 512
# 32 is whitespace
dir_tensor = torch.full(
(MAX_LEN, ), 32, dtype=torch.uint8, device='cuda')
dir_tensor = torch.full((MAX_LEN, ),
32,
dtype=torch.uint8,
device='cuda')
if rank == 0:
tmpdir = tempfile.mkdtemp()
tmpdir = torch.tensor(
......@@ -152,7 +154,13 @@ def main():
# build the model and load checkpoint
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
load_checkpoint(model, args.checkpoint, map_location='cpu')
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
# old versions did not save class info in checkpoints, this walkaround is
# for backward compatibility
if 'CLASSES' in checkpoint['meta']:
model.CLASSES = checkpoint['meta']['CLASSES']
else:
model.CLASSES = dataset.CLASSES
if not distributed:
model = MMDataParallel(model, device_ids=[0])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment