From 8c86f74ca01ec4c46997915e878b2a6ac518ff6d Mon Sep 17 00:00:00 2001
From: Kamran Melikov <melikovk@gmail.com>
Date: Tue, 21 Jan 2020 07:45:34 -0500
Subject: [PATCH] Non color images (#1976)

* First Draft

On branch non-color-images
Changes to be committed:
modified:   mmdet/datasets/pipelines/loading.py

* Add option to load non color images

Add 'color_type' parameter to LoadImageFromFile class
Change __repr__ method accordingly
Since non-color images may be two-dimensional, expand the image
dimensions if necessary in the DefaultFormatBundle and
ImageToTensor classes

Changes to be committed:
    modified:   mmdet/datasets/pipelines/formating.py
    modified:   mmdet/datasets/pipelines/loading.py

* Fix RandomCrop to work with grayscale images

Changes to be committed:
    modified:   mmdet/datasets/pipelines/transforms.py

* Modify retrieving w, h of padded image in anchor heads

This addresses problems with single-channel images, for which the
shape is a tuple with 2 values

Changes to be committed:
    modified:   mmdet/models/anchor_heads/anchor_head.py
    modified:   mmdet/models/anchor_heads/guided_anchor_head.py
    modified:   mmdet/models/anchor_heads/reppoints_head.py
---
 mmdet/datasets/pipelines/formating.py           | 10 ++++++++--
 mmdet/datasets/pipelines/loading.py             |  9 +++++----
 mmdet/datasets/pipelines/transforms.py          |  2 +-
 mmdet/models/anchor_heads/anchor_head.py        |  2 +-
 mmdet/models/anchor_heads/guided_anchor_head.py |  2 +-
 mmdet/models/anchor_heads/reppoints_head.py     |  2 +-
 6 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/mmdet/datasets/pipelines/formating.py b/mmdet/datasets/pipelines/formating.py
index 83385ab..e14dd0a 100644
--- a/mmdet/datasets/pipelines/formating.py
+++ b/mmdet/datasets/pipelines/formating.py
@@ -52,7 +52,10 @@ class ImageToTensor(object):
 
     def __call__(self, results):
         for key in self.keys:
-            results[key] = to_tensor(results[key].transpose(2, 0, 1))
+            img = results[key]
+            if len(img.shape) < 3:
+                img = np.expand_dims(img, -1)
+            results[key] = to_tensor(img.transpose(2, 0, 1))
         return results
 
     def __repr__(self):
@@ -115,7 +118,10 @@ class DefaultFormatBundle(object):
 
     def __call__(self, results):
         if 'img' in results:
-            img = np.ascontiguousarray(results['img'].transpose(2, 0, 1))
+            img = results['img']
+            if len(img.shape) < 3:
+                img = np.expand_dims(img, -1)
+            img = np.ascontiguousarray(img.transpose(2, 0, 1))
             results['img'] = DC(to_tensor(img), stack=True)
         for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
             if key not in results:
diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py
index f4aa6de..190773b 100644
--- a/mmdet/datasets/pipelines/loading.py
+++ b/mmdet/datasets/pipelines/loading.py
@@ -10,8 +10,9 @@ from ..registry import PIPELINES
 @PIPELINES.register_module
 class LoadImageFromFile(object):
 
-    def __init__(self, to_float32=False):
+    def __init__(self, to_float32=False, color_type='color'):
         self.to_float32 = to_float32
+        self.color_type = color_type
 
     def __call__(self, results):
         if results['img_prefix'] is not None:
@@ -19,7 +20,7 @@ class LoadImageFromFile(object):
                                 results['img_info']['filename'])
         else:
             filename = results['img_info']['filename']
-        img = mmcv.imread(filename)
+        img = mmcv.imread(filename, self.color_type)
         if self.to_float32:
             img = img.astype(np.float32)
         results['filename'] = filename
@@ -29,8 +30,8 @@ class LoadImageFromFile(object):
         return results
 
     def __repr__(self):
-        return self.__class__.__name__ + '(to_float32={})'.format(
-            self.to_float32)
+        return '{} (to_float32={}, color_type={})'.format(
+            self.__class__.__name__, self.to_float32, self.color_type)
 
 
 @PIPELINES.register_module
diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py
index a9516d5..8a5c728 100644
--- a/mmdet/datasets/pipelines/transforms.py
+++ b/mmdet/datasets/pipelines/transforms.py
@@ -364,7 +364,7 @@ class RandomCrop(object):
         crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
 
         # crop the image
-        img = img[crop_y1:crop_y2, crop_x1:crop_x2, :]
+        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
         img_shape = img.shape
         results['img'] = img
         results['img_shape'] = img_shape
diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py
index 83944c8..0fdc0aa 100644
--- a/mmdet/models/anchor_heads/anchor_head.py
+++ b/mmdet/models/anchor_heads/anchor_head.py
@@ -127,7 +127,7 @@ class AnchorHead(nn.Module):
             for i in range(num_levels):
                 anchor_stride = self.anchor_strides[i]
                 feat_h, feat_w = featmap_sizes[i]
-                h, w, _ = img_meta['pad_shape']
+                h, w = img_meta['pad_shape'][:2]
                 valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                 valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                 flags = self.anchor_generators[i].valid_flags(
diff --git a/mmdet/models/anchor_heads/guided_anchor_head.py b/mmdet/models/anchor_heads/guided_anchor_head.py
index 271a0dc..9fdf4f6 100644
--- a/mmdet/models/anchor_heads/guided_anchor_head.py
+++ b/mmdet/models/anchor_heads/guided_anchor_head.py
@@ -246,7 +246,7 @@ class GuidedAnchorHead(AnchorHead):
                 approxs = multi_level_approxs[i]
                 anchor_stride = self.anchor_strides[i]
                 feat_h, feat_w = featmap_sizes[i]
-                h, w, _ = img_meta['pad_shape']
+                h, w = img_meta['pad_shape'][:2]
                 valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                 valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                 flags = self.approx_generators[i].valid_flags(
diff --git a/mmdet/models/anchor_heads/reppoints_head.py b/mmdet/models/anchor_heads/reppoints_head.py
index 1ce7abd..b3214f3 100644
--- a/mmdet/models/anchor_heads/reppoints_head.py
+++ b/mmdet/models/anchor_heads/reppoints_head.py
@@ -320,7 +320,7 @@ class RepPointsHead(nn.Module):
             for i in range(num_levels):
                 point_stride = self.point_strides[i]
                 feat_h, feat_w = featmap_sizes[i]
-                h, w, _ = img_meta['pad_shape']
+                h, w = img_meta['pad_shape'][:2]
                 valid_feat_h = min(int(np.ceil(h / point_stride)), feat_h)
                 valid_feat_w = min(int(np.ceil(w / point_stride)), feat_w)
                 flags = self.point_generators[i].valid_flags(
-- 
GitLab