diff --git a/configs/wider_face/README.md b/configs/wider_face/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f7392007a1ce6379aee4c5e4544111f8207fe823 --- /dev/null +++ b/configs/wider_face/README.md @@ -0,0 +1,32 @@ +## WIDER Face Dataset + +To use the WIDER Face dataset you need to download it +and extract to the `data/WIDERFace` folder. Annotation in the VOC format +can be found in this [repo](https://github.com/sovrasov/wider-face-pascal-voc-annotations.git). +You should move the annotation files from `WIDER_train_annotations` and `WIDER_val_annotations` folders +to the `Annotation` folders inside the corresponding directories `WIDER_train` and `WIDER_val`. +Also annotation lists `val.txt` and `train.txt` should be copied to `data/WIDERFace` from `WIDER_train_annotations` and `WIDER_val_annotations`. +The directory should be like this: + +``` +mmdetection +├── mmdet +├── tools +├── configs +├── data +│ ├── WIDERFace +│ │ ├── WIDER_train +│ | │ ├──0--Parade +│ | │ ├── ... +│ | │ ├── Annotations +│ │ ├── WIDER_val +│ | │ ├──0--Parade +│ | │ ├── ... +│ | │ ├── Annotations +│ │ ├── val.txt +│ │ ├── train.txt + +``` + +After that you can train the SSD300 on WIDER by launching training with the `ssd300_wider_face.py` config or +create your own config based on the presented one. diff --git a/configs/wider_face/ssd300_wider_face.py b/configs/wider_face/ssd300_wider_face.py new file mode 100644 index 0000000000000000000000000000000000000000..53cafc1ef2ea269e3a98208de3e6fefc03a94837 --- /dev/null +++ b/configs/wider_face/ssd300_wider_face.py @@ -0,0 +1,135 @@ +# model settings +input_size = 300 +model = dict( + type='SingleStageDetector', + pretrained='open-mmlab://vgg16_caffe', + backbone=dict( + type='SSDVGG', + input_size=input_size, + depth=16, + with_last_pool=False, + ceil_mode=True, + out_indices=(3, 4), + out_feature_indices=(22, 34), + l2_norm_scale=20), + neck=None, + bbox_head=dict( + type='SSDHead', + input_size=input_size, + in_channels=(512, 1024, 512, 256, 256, 256), + num_classes=2, + anchor_strides=(8, 16, 32, 64, 100, 300), + basesize_ratio_range=(0.15, 0.9), + anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]), + target_means=(.0, .0, .0, .0), + target_stds=(0.1, 0.1, 0.2, 0.2))) +cudnn_benchmark = True +train_cfg = dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0., + ignore_iof_thr=-1, + gt_max_assign_all=False), + smoothl1_beta=1., + allowed_border=-1, + pos_weight=-1, + neg_pos_ratio=3, + debug=False) +test_cfg = dict( + nms=dict(type='nms', iou_thr=0.45), + min_bbox_size=0, + score_thr=0.02, + max_per_img=200) +# model training and testing settings +# dataset settings +dataset_type = 'WIDERFaceDataset' +data_root = 'data/WIDERFace/' +img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[1, 1, 1], to_rgb=True) +data = dict( + imgs_per_gpu=60, + workers_per_gpu=2, + train=dict( + type='RepeatDataset', + times=2, + dataset=dict( + type=dataset_type, + ann_file=[ + data_root + 'train.txt', + ], + img_prefix=[data_root + 'WIDER_train/'], + img_scale=(300, 300), + min_size=17, # throw away very small faces to improve training, + # because 300x300 is too low resolution to detect them + img_norm_cfg=img_norm_cfg, + size_divisor=None, + flip_ratio=0.5, + with_mask=False, + with_crowd=False, + with_label=True, + test_mode=False, + extra_aug=dict( + photo_metric_distortion=dict( + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18), + expand=dict( + mean=img_norm_cfg['mean'], + to_rgb=img_norm_cfg['to_rgb'], + ratio_range=(1, 4)), + random_crop=dict( + min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3)), + resize_keep_ratio=False)), + val=dict( + type=dataset_type, + ann_file=data_root + '/val.txt', + img_prefix=data_root + 'WIDER_val/', + img_scale=(300, 300), + img_norm_cfg=img_norm_cfg, + size_divisor=None, + flip_ratio=0, + with_mask=False, + with_label=False, + test_mode=True, + resize_keep_ratio=False), + test=dict( + type=dataset_type, + ann_file=data_root + '/val.txt', + img_prefix=data_root + 'WIDER_val/', + img_scale=(300, 300), + img_norm_cfg=img_norm_cfg, + size_divisor=None, + flip_ratio=0, + with_mask=False, + with_label=False, + test_mode=True, + resize_keep_ratio=False)) +# optimizer +optimizer = dict(type='SGD', lr=1e-3, momentum=0.9, weight_decay=5e-4) +optimizer_config = dict() +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=1000, + warmup_ratio=1.0 / 3, + step=[16, 20]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=1, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 24 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/ssd300_wider' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/mmdet/core/evaluation/class_names.py b/mmdet/core/evaluation/class_names.py index a8d0a3b9b2c4acf1ade04cafc16f48fd0eb743d7..87fb2399b2b5a852a34da384f7d52e610a41e0f9 100644 --- a/mmdet/core/evaluation/class_names.py +++ b/mmdet/core/evaluation/class_names.py @@ -1,6 +1,10 @@ import mmcv +def wider_face_classes(): + return ['face'] + + def voc_classes(): return [ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', @@ -82,7 +86,8 @@ dataset_aliases = { 'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'], 'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'], 'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'], - 'coco': ['coco', 'mscoco', 'ms_coco'] + 'coco': ['coco', 'mscoco', 'ms_coco'], + 'wider_face': ['WIDERFaceDataset', 'wider_face', 'WDIERFace'] } diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 572b0faa4ab11a0c0f0b72253c517e39299381c0..2c1612daff14ba8de444b45ec54a9f4d67f3be00 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -2,6 +2,7 @@ from .custom import CustomDataset from .xml_style import XMLDataset from .coco import CocoDataset from .voc import VOCDataset +from .wider_face import WIDERFaceDataset from .loader import GroupSampler, DistributedGroupSampler, build_dataloader from .utils import to_tensor, random_scale, show_ann, get_dataset from .concat_dataset import ConcatDataset @@ -12,5 +13,5 @@ __all__ = [ 'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler', 'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', 'get_dataset', 'ConcatDataset', 'RepeatDataset', - 'ExtraAugmentation' + 'ExtraAugmentation', 'WIDERFaceDataset' ] diff --git a/mmdet/datasets/wider_face.py b/mmdet/datasets/wider_face.py new file mode 100644 index 0000000000000000000000000000000000000000..ad52a952704a2c0ca4ade0ad98f53bbfd637b0ba --- /dev/null +++ b/mmdet/datasets/wider_face.py @@ -0,0 +1,37 @@ +import os.path as osp +import xml.etree.ElementTree as ET + +import mmcv + +from .xml_style import XMLDataset + + +class WIDERFaceDataset(XMLDataset): + """ + Reader for the WIDER Face dataset in PASCAL VOC format. + Conversion scripts can be found in + https://github.com/sovrasov/wider-face-pascal-voc-annotations + """ + CLASSES = ('face',) + + def __init__(self, **kwargs): + super(WIDERFaceDataset, self).__init__(**kwargs) + + def load_annotations(self, ann_file): + img_infos = [] + img_ids = mmcv.list_from_file(ann_file) + for img_id in img_ids: + filename = '{}.jpg'.format(img_id) + xml_path = osp.join(self.img_prefix, 'Annotations', + '{}.xml'.format(img_id)) + tree = ET.parse(xml_path) + root = tree.getroot() + size = root.find('size') + width = int(size.find('width').text) + height = int(size.find('height').text) + folder = root.find('folder').text + img_infos.append( + dict(id=img_id, filename=osp.join(folder, filename), + width=width, height=height)) + + return img_infos diff --git a/mmdet/datasets/xml_style.py b/mmdet/datasets/xml_style.py index 40e3374e6dc17be5c31b238acb47e93fceb56a56..e0c6ac1b6f42b62b93a73a21ddd6f4135b103910 100644 --- a/mmdet/datasets/xml_style.py +++ b/mmdet/datasets/xml_style.py @@ -9,9 +9,10 @@ from .custom import CustomDataset class XMLDataset(CustomDataset): - def __init__(self, **kwargs): + def __init__(self, min_size=None, **kwargs): super(XMLDataset, self).__init__(**kwargs) self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)} + self.min_size = min_size def load_annotations(self, ann_file): img_infos = [] @@ -50,7 +51,14 @@ class XMLDataset(CustomDataset): int(bnd_box.find('xmax').text), int(bnd_box.find('ymax').text) ] - if difficult: + ignore = False + if self.min_size: + assert not self.test_mode + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + if w < self.min_size or h < self.min_size: + ignore = True + if difficult or ignore: bboxes_ignore.append(bbox) labels_ignore.append(label) else: