From 8f70f9df028a4e67057c8da46aefd80f362b1738 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Mon, 3 Jun 2019 21:33:28 +0800
Subject: [PATCH] bug fix for testing visualization (#747)

---
 tools/test.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/tools/test.py b/tools/test.py
index 9719f45..ada5e60 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -63,8 +63,10 @@ def collect_results(result_part, size, tmpdir=None):
     if tmpdir is None:
         MAX_LEN = 512
         # 32 is whitespace
-        dir_tensor = torch.full(
-            (MAX_LEN, ), 32, dtype=torch.uint8, device='cuda')
+        dir_tensor = torch.full((MAX_LEN, ),
+                                32,
+                                dtype=torch.uint8,
+                                device='cuda')
         if rank == 0:
             tmpdir = tempfile.mkdtemp()
             tmpdir = torch.tensor(
@@ -152,7 +154,13 @@ def main():
 
     # build the model and load checkpoint
     model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
-    load_checkpoint(model, args.checkpoint, map_location='cpu')
+    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
+    # old versions did not save class info in checkpoints, this walkaround is
+    # for backward compatibility
+    if 'CLASSES' in checkpoint['meta']:
+        model.CLASSES = checkpoint['meta']['CLASSES']
+    else:
+        model.CLASSES = dataset.CLASSES
 
     if not distributed:
         model = MMDataParallel(model, device_ids=[0])
-- 
GitLab