diff --git a/mmdet/datasets/pipelines/formating.py b/mmdet/datasets/pipelines/formating.py index 83385ab942a4946834ee345c9be713d9552ff5dc..e14dd0a97ba0e645cb54c30c2483f27e12ad3b37 100644 --- a/mmdet/datasets/pipelines/formating.py +++ b/mmdet/datasets/pipelines/formating.py @@ -52,7 +52,10 @@ class ImageToTensor(object): def __call__(self, results): for key in self.keys: - results[key] = to_tensor(results[key].transpose(2, 0, 1)) + img = results[key] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + results[key] = to_tensor(img.transpose(2, 0, 1)) return results def __repr__(self): @@ -115,7 +118,10 @@ class DefaultFormatBundle(object): def __call__(self, results): if 'img' in results: - img = np.ascontiguousarray(results['img'].transpose(2, 0, 1)) + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) results['img'] = DC(to_tensor(img), stack=True) for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']: if key not in results: diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py index f4aa6de9f1ea4c404df093c70ead0765f7d48b9c..190773b158d7bc1f62b37f481e15b0d99e42b235 100644 --- a/mmdet/datasets/pipelines/loading.py +++ b/mmdet/datasets/pipelines/loading.py @@ -10,8 +10,9 @@ from ..registry import PIPELINES @PIPELINES.register_module class LoadImageFromFile(object): - def __init__(self, to_float32=False): + def __init__(self, to_float32=False, color_type='color'): self.to_float32 = to_float32 + self.color_type = color_type def __call__(self, results): if results['img_prefix'] is not None: @@ -19,7 +20,7 @@ class LoadImageFromFile(object): results['img_info']['filename']) else: filename = results['img_info']['filename'] - img = mmcv.imread(filename) + img = mmcv.imread(filename, self.color_type) if self.to_float32: img = img.astype(np.float32) results['filename'] = filename @@ -29,8 +30,8 @@ class LoadImageFromFile(object): return results def __repr__(self): - return self.__class__.__name__ + '(to_float32={})'.format( - self.to_float32) + return '{} (to_float32={}, color_type={})'.format( + self.__class__.__name__, self.to_float32, self.color_type) @PIPELINES.register_module diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index a9516d5f432e3932dd5a6a033ca8f725323b4a3f..8a5c72803829abae8a864e58cadd74a44f973d79 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -364,7 +364,7 @@ class RandomCrop(object): crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] # crop the image - img = img[crop_y1:crop_y2, crop_x1:crop_x2, :] + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] img_shape = img.shape results['img'] = img results['img_shape'] = img_shape diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py index 83944c88008fdc5a68e1ceda39b50bf4a08b81f6..0fdc0aadead17bc79c8fceaa852d085fcb44eedb 100644 --- a/mmdet/models/anchor_heads/anchor_head.py +++ b/mmdet/models/anchor_heads/anchor_head.py @@ -127,7 +127,7 @@ class AnchorHead(nn.Module): for i in range(num_levels): anchor_stride = self.anchor_strides[i] feat_h, feat_w = featmap_sizes[i] - h, w, _ = img_meta['pad_shape'] + h, w = img_meta['pad_shape'][:2] valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h) valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w) flags = self.anchor_generators[i].valid_flags( diff --git a/mmdet/models/anchor_heads/guided_anchor_head.py b/mmdet/models/anchor_heads/guided_anchor_head.py index 271a0dc74f38f7b132870bc3ed9b2aeef8dfddd8..9fdf4f664e59bd4e3651ce22b2a6203886bd1dca 100644 --- a/mmdet/models/anchor_heads/guided_anchor_head.py +++ b/mmdet/models/anchor_heads/guided_anchor_head.py @@ -246,7 +246,7 @@ class GuidedAnchorHead(AnchorHead): approxs = multi_level_approxs[i] anchor_stride = self.anchor_strides[i] feat_h, feat_w = featmap_sizes[i] - h, w, _ = img_meta['pad_shape'] + h, w = img_meta['pad_shape'][:2] valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h) valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w) flags = self.approx_generators[i].valid_flags( diff --git a/mmdet/models/anchor_heads/reppoints_head.py b/mmdet/models/anchor_heads/reppoints_head.py index 1ce7abd16f917c1ca07dcf8eb78bc3633eb75704..b3214f35708713ee0a9884bdb218087413e3e792 100644 --- a/mmdet/models/anchor_heads/reppoints_head.py +++ b/mmdet/models/anchor_heads/reppoints_head.py @@ -320,7 +320,7 @@ class RepPointsHead(nn.Module): for i in range(num_levels): point_stride = self.point_strides[i] feat_h, feat_w = featmap_sizes[i] - h, w, _ = img_meta['pad_shape'] + h, w = img_meta['pad_shape'][:2] valid_feat_h = min(int(np.ceil(h / point_stride)), feat_h) valid_feat_w = min(int(np.ceil(w / point_stride)), feat_w) flags = self.point_generators[i].valid_flags(