From d88ad8fc32a6cce46c72619a16b85226360a7ba2 Mon Sep 17 00:00:00 2001
From: tyomj <32788266+tyomj@users.noreply.github.com>
Date: Sat, 5 Oct 2019 09:57:13 +0300
Subject: [PATCH] Albumentations augs wrapper (#1354)

* Albumentations wrapper

* 2 single quote format

* version >= 0.3.2
---
 configs/albu_example/mask_rcnn_r50_fpn_1x.py | 248 +++++++++++++++++++
 mmdet/datasets/pipelines/__init__.py         |   4 +-
 mmdet/datasets/pipelines/transforms.py       | 161 +++++++++++-
 requirements.txt                             |   3 +-
 4 files changed, 410 insertions(+), 6 deletions(-)
 create mode 100644 configs/albu_example/mask_rcnn_r50_fpn_1x.py

diff --git a/configs/albu_example/mask_rcnn_r50_fpn_1x.py b/configs/albu_example/mask_rcnn_r50_fpn_1x.py
new file mode 100644
index 0000000..9638fac
--- /dev/null
+++ b/configs/albu_example/mask_rcnn_r50_fpn_1x.py
@@ -0,0 +1,248 @@
+# model settings
+model = dict(
+    type='MaskRCNN',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=81,
+        loss_mask=dict(
+            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        mask_size=28,
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+albu_train_transforms = [
+    dict(
+        type='ShiftScaleRotate',
+        shift_limit=0.0625,
+        scale_limit=0.0,
+        rotate_limit=0,
+        interpolation=1,
+        p=0.5),
+    dict(
+        type='RandomBrightnessContrast',
+        brightness_limit=[0.1, 0.3],
+        contrast_limit=[0.1, 0.3],
+        p=0.2),
+    dict(
+        type='OneOf',
+        transforms=[
+            dict(
+                type='RGBShift',
+                r_shift_limit=10,
+                g_shift_limit=10,
+                b_shift_limit=10,
+                p=1.0),
+            dict(
+                type='HueSaturationValue',
+                hue_shift_limit=20,
+                sat_shift_limit=30,
+                val_shift_limit=20,
+                p=1.0)
+        ],
+        p=0.1),
+    dict(type='JpegCompression', quality_lower=85, quality_upper=95, p=0.2),
+    dict(type='ChannelShuffle', p=0.1),
+    dict(
+        type='OneOf',
+        transforms=[
+            dict(type='Blur', blur_limit=3, p=1.0),
+            dict(type='MedianBlur', blur_limit=3, p=1.0)
+        ],
+        p=0.1),
+]
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+    dict(type='Pad', size_divisor=32),
+    dict(
+        type='Albu',
+        transforms=albu_train_transforms,
+        bbox_params=dict(
+            type='BboxParams',
+            format='pascal_voc',
+            label_fields=['gt_labels'],
+            min_visibility=0.0,
+            filter_lost_elements=True),
+        keymap={
+            'img': 'image',
+            'gt_masks': 'masks',
+            'gt_bboxes': 'bboxes'
+        },
+        update_pad_shape=False,
+        skip_img_without_anno=True),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='DefaultFormatBundle'),
+    dict(
+        type='Collect',
+        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'],
+        meta_keys=('filename', 'ori_shape', 'img_shape', 'img_norm_cfg',
+                   'pad_shape', 'scale_factor'))
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+evaluation = dict(interval=1)
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/mask_rcnn_r50_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet/datasets/pipelines/__init__.py b/mmdet/datasets/pipelines/__init__.py
index dd5919e..ae55b25 100644
--- a/mmdet/datasets/pipelines/__init__.py
+++ b/mmdet/datasets/pipelines/__init__.py
@@ -3,7 +3,7 @@ from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
                         Transpose, to_tensor)
 from .loading import LoadAnnotations, LoadImageFromFile, LoadProposals
 from .test_aug import MultiScaleFlipAug
-from .transforms import (Expand, MinIoURandomCrop, Normalize, Pad,
+from .transforms import (Albu, Expand, MinIoURandomCrop, Normalize, Pad,
                          PhotoMetricDistortion, RandomCrop, RandomFlip, Resize,
                          SegResizeFlipPadRescale)
 
@@ -12,5 +12,5 @@ __all__ = [
     'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
     'LoadProposals', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad',
     'RandomCrop', 'Normalize', 'SegResizeFlipPadRescale', 'MinIoURandomCrop',
-    'Expand', 'PhotoMetricDistortion'
+    'Expand', 'PhotoMetricDistortion', 'Albu'
 ]
diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py
index 60b86e4..c758c91 100644
--- a/mmdet/datasets/pipelines/transforms.py
+++ b/mmdet/datasets/pipelines/transforms.py
@@ -1,5 +1,9 @@
+import inspect
+
+import albumentations
 import mmcv
 import numpy as np
+from albumentations import Compose
 from imagecorruptions import corrupt
 from numpy import random
 
@@ -596,9 +600,8 @@ class MinIoURandomCrop(object):
 
                 # center of boxes should inside the crop img
                 center = (boxes[:, :2] + boxes[:, 2:]) / 2
-                mask = (center[:, 0] > patch[0]) * (
-                    center[:, 1] > patch[1]) * (center[:, 0] < patch[2]) * (
-                        center[:, 1] < patch[3])
+                mask = ((center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) *
+                        (center[:, 0] < patch[2]) * (center[:, 1] < patch[3]))
                 if not mask.any():
                     continue
                 boxes = boxes[mask]
@@ -651,3 +654,155 @@ class Corrupt(object):
         repr_str += '(corruption={}, severity={})'.format(
             self.corruption, self.severity)
         return repr_str
+
+
+@PIPELINES.register_module
+class Albu(object):
+
+    def __init__(self,
+                 transforms,
+                 bbox_params=None,
+                 keymap=None,
+                 update_pad_shape=False,
+                 skip_img_without_anno=False):
+        """
+        Adds custom transformations from Albumentations lib.
+        Please, visit `https://albumentations.readthedocs.io`
+        to get more information.
+
+        transforms (list): list of albu transformations
+        bbox_params (dict): bbox_params for albumentation `Compose`
+        keymap (dict): contains {'input key':'albumentation-style key'}
+        skip_img_without_anno (bool): whether to skip the image
+                                      if no ann left after aug
+        """
+
+        self.transforms = transforms
+        self.filter_lost_elements = False
+        self.update_pad_shape = update_pad_shape
+        self.skip_img_without_anno = skip_img_without_anno
+
+        # A simple workaround to remove masks without boxes
+        if (isinstance(bbox_params, dict) and 'label_fields' in bbox_params
+                and 'filter_lost_elements' in bbox_params):
+            self.filter_lost_elements = True
+            self.origin_label_fields = bbox_params['label_fields']
+            bbox_params['label_fields'] = ['idx_mapper']
+            del bbox_params['filter_lost_elements']
+
+        self.bbox_params = (
+            self.albu_builder(bbox_params) if bbox_params else None)
+        self.aug = Compose([self.albu_builder(t) for t in self.transforms],
+                           bbox_params=self.bbox_params)
+
+        if not keymap:
+            self.keymap_to_albu = {
+                'img': 'image',
+                'gt_masks': 'masks',
+                'gt_bboxes': 'bboxes'
+            }
+        else:
+            self.keymap_to_albu = keymap
+        self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
+
+    def albu_builder(self, cfg):
+        """Import a module from albumentations.
+        Inherits some of `build_from_cfg` logic.
+
+        Args:
+            cfg (dict): Config dict. It should at least contain the key "type".
+        Returns:
+            obj: The constructed object.
+        """
+        assert isinstance(cfg, dict) and "type" in cfg
+        args = cfg.copy()
+
+        obj_type = args.pop("type")
+        if mmcv.is_str(obj_type):
+            obj_cls = getattr(albumentations, obj_type)
+        elif inspect.isclass(obj_type):
+            obj_cls = obj_type
+        else:
+            raise TypeError(
+                'type must be a str or valid type, but got {}'.format(
+                    type(obj_type)))
+
+        if 'transforms' in args:
+            args['transforms'] = [
+                self.albu_builder(transform)
+                for transform in args['transforms']
+            ]
+
+        return obj_cls(**args)
+
+    @staticmethod
+    def mapper(d, keymap):
+        """
+        Dictionary mapper.
+        Renames keys according to keymap provided.
+
+        Args:
+            d (dict): old dict
+            keymap (dict): {'old_key':'new_key'}
+        Returns:
+            dict: new dict.
+        """
+        updated_dict = {}
+        for k, v in zip(d.keys(), d.values()):
+            new_k = keymap.get(k, k)
+            updated_dict[new_k] = d[k]
+        return updated_dict
+
+    def __call__(self, results):
+        # dict to albumentations format
+        results = self.mapper(results, self.keymap_to_albu)
+
+        if 'bboxes' in results:
+            # to list of boxes
+            if isinstance(results['bboxes'], np.ndarray):
+                results['bboxes'] = [x for x in results['bboxes']]
+            # add pseudo-field for filtration
+            if self.filter_lost_elements:
+                results['idx_mapper'] = np.arange(len(results['bboxes']))
+
+        results = self.aug(**results)
+
+        if 'bboxes' in results:
+            if isinstance(results['bboxes'], list):
+                results['bboxes'] = np.array(
+                    results['bboxes'], dtype=np.float32)
+
+            # filter label_fields
+            if self.filter_lost_elements:
+
+                results['idx_mapper'] = np.arange(len(results['bboxes']))
+
+                for label in self.origin_label_fields:
+                    results[label] = np.array(
+                        [results[label][i] for i in results['idx_mapper']])
+                if 'masks' in results:
+                    results['masks'] = [
+                        results['masks'][i] for i in results['idx_mapper']
+                    ]
+
+                if (not len(results['idx_mapper'])
+                        and self.skip_img_without_anno):
+                    return None
+
+        if 'gt_labels' in results:
+            if isinstance(results['gt_labels'], list):
+                results['gt_labels'] = np.array(results['gt_labels'])
+
+        # back to the original format
+        results = self.mapper(results, self.keymap_back)
+
+        # update final shape
+        if self.update_pad_shape:
+            results['pad_shape'] = results['img'].shape
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += '(transformations={})'.format(self.transformations)
+        return repr_str
diff --git a/requirements.txt b/requirements.txt
index f39f962..8a68f41 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,5 @@ terminaltables
 pycocotools
 torch>=1.1
 torchvision
-imagecorruptions
\ No newline at end of file
+imagecorruptions
+albumentations>=0.3.2
\ No newline at end of file
-- 
GitLab