diff --git a/mmdet/core/__init__.py b/mmdet/core/__init__.py index 1eb03f76acdfaff65c96e5316e3b4898b7a7af6a..e1495428889df9ffb57331b0bbe17b2f9de24185 100644 --- a/mmdet/core/__init__.py +++ b/mmdet/core/__init__.py @@ -1,6 +1,7 @@ from .rpn_ops import * from .bbox_ops import * from .mask_ops import * +from .targets import * from .losses import * from .eval import * from .parallel import * diff --git a/mmdet/core/targets/__init__.py b/mmdet/core/targets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f32c296c48744edef287a62bafa88d6180caa48a --- /dev/null +++ b/mmdet/core/targets/__init__.py @@ -0,0 +1 @@ +from .retina_target import retina_target diff --git a/mmdet/core/targets/retina_target.py b/mmdet/core/targets/retina_target.py new file mode 100644 index 0000000000000000000000000000000000000000..1b42736e2f61979ddb932f3b31d8c4c53fcffa60 --- /dev/null +++ b/mmdet/core/targets/retina_target.py @@ -0,0 +1,168 @@ +import torch + +from ..bbox_ops import bbox_assign, bbox2delta +from ..utils import multi_apply + + +def retina_target(anchor_list, valid_flag_list, gt_bboxes_list, gt_labels_list, + img_metas, target_means, target_stds, cls_out_channels, cfg): + """Compute regression and classification targets for anchors. + + Args: + anchor_list (list[list]): Multi level anchors of each image. + valid_flag_list (list[list]): Multi level valid flags of each image. + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. + img_metas (list[dict]): Meta info of each image. + target_means (Iterable): Mean value of regression targets. + target_stds (Iterable): Std value of regression targets. + cfg (dict): RPN train configs. + + Returns: + tuple + """ + num_imgs = len(img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = torch.cat(anchor_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights, + pos_inds_list, neg_inds_list) = multi_apply( + retina_target_single, + anchor_list, + valid_flag_list, + gt_bboxes_list, + gt_labels_list, + img_metas, + target_means=target_means, + target_stds=target_stds, + cls_out_channels=cls_out_channels, + cfg=cfg) + # no valid anchors + if any([labels is None for labels in all_labels]): + return None + # sampled anchors of all images + num_pos_samples = sum([ + max(pos_inds.numel(), 1) + for pos_inds, neg_inds in zip(pos_inds_list, neg_inds_list) + ]) + # split targets to a list w.r.t. multiple levels + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors) + bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors) + return (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, num_pos_samples) + + +def images_to_levels(target, num_level_anchors): + """Convert targets by image to targets by feature level. + + [target_img0, target_img1] -> [target_level0, target_level1, ...] 
+ """ + target = torch.stack(target, 0) + level_targets = [] + start = 0 + for n in num_level_anchors: + end = start + n + level_targets.append(target[:, start:end].squeeze(0)) + start = end + return level_targets + + +def retina_target_single(flat_anchors, valid_flags, gt_bboxes, gt_labels, + img_meta, target_means, target_stds, cls_out_channels, + cfg): + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + cfg.allowed_border) + if not inside_flags.any(): + return (None, ) * 6 + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign( + anchors, + gt_bboxes, + pos_iou_thr=cfg.pos_iou_thr, + neg_iou_thr=cfg.neg_iou_thr, + min_pos_iou=cfg.min_pos_iou) + pos_inds = torch.nonzero(assigned_gt_inds > 0) + neg_inds = torch.nonzero(assigned_gt_inds == 0) + + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + labels = torch.zeros_like(assigned_gt_inds) + label_weights = torch.zeros_like(assigned_gt_inds, dtype=anchors.dtype) + + if len(pos_inds) > 0: + pos_inds = pos_inds.squeeze(1).unique() + pos_anchors = anchors[pos_inds, :] + pos_gt_bbox = gt_bboxes[assigned_gt_inds[pos_inds] - 1, :] + pos_bbox_targets = bbox2delta(pos_anchors, pos_gt_bbox, target_means, + target_stds) + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - 1] + if cfg.pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = cfg.pos_weight + if len(neg_inds) > 0: + neg_inds = neg_inds.squeeze(1).unique() + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + num_total_anchors = flat_anchors.size(0) + labels = unmap(labels, num_total_anchors, inside_flags) + label_weights = unmap(label_weights, num_total_anchors, inside_flags) + labels, label_weights = expand_binary_labels(labels, label_weights, + cls_out_channels) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds) + + +def expand_binary_labels(labels, label_weights, cls_out_channels): + bin_labels = labels.new_full( + (labels.size(0), cls_out_channels), 0, dtype=torch.float32) + inds = torch.nonzero(labels >= 1).squeeze() + if inds.numel() > 0: + bin_labels[inds, labels[inds] - 1] = 1 + bin_label_weights = label_weights.view(-1, 1).expand( + label_weights.size(0), cls_out_channels) + return bin_labels, bin_label_weights + + +def anchor_inside_flags(flat_anchors, valid_flags, img_shape, + allowed_border=0): + img_h, img_w = img_shape[:2] + if allowed_border >= 0: + inside_flags = valid_flags & \ + (flat_anchors[:, 0] >= -allowed_border) & \ + (flat_anchors[:, 1] >= -allowed_border) & \ + (flat_anchors[:, 2] < img_w + allowed_border) & \ + (flat_anchors[:, 3] < img_h + allowed_border) + else: + inside_flags = valid_flags + return inside_flags + + +def unmap(data, count, inds, fill=0): + """ Unmap a subset of item (data) back to the original set of items (of + size count) """ + if data.dim() == 1: + ret = data.new_full((count, ), fill) + ret[inds] = data + else: + new_size = (count, ) + data.size()[1:] + ret = data.new_full(new_size, fill) + ret[inds, :] = data + return ret diff --git a/mmdet/models/builder.py b/mmdet/models/builder.py index bdf0ac3d16f9aadb194f944b3f7c4dd1a741e8cd..340a2a33a1e4e96e6faf004a15368de6b5e7f08c 
100644 --- a/mmdet/models/builder.py +++ b/mmdet/models/builder.py @@ -2,11 +2,12 @@ from mmcv.runner import obj_from_dict from torch import nn from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads, - mask_heads, detectors) + mask_heads, single_stage_heads, detectors) __all__ = [ 'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor', - 'build_bbox_head', 'build_mask_head', 'build_detector' + 'build_bbox_head', 'build_mask_head', 'build_single_stage_head', + 'build_detector' ] @@ -47,5 +48,9 @@ def build_mask_head(cfg): return build(cfg, mask_heads) +def build_single_stage_head(cfg): + return build(cfg, single_stage_heads) + + def build_detector(cfg, train_cfg=None, test_cfg=None): return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index b8914c1e5d3c834a1373b2a2e8360183a41de4da..d08bf99a223726f1bbb5f872a50e74f8e9d5ae1a 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -2,5 +2,6 @@ from .base import BaseDetector from .rpn import RPN from .faster_rcnn import FasterRCNN from .mask_rcnn import MaskRCNN +from .retina_net import RetinaNet -__all__ = ['BaseDetector', 'RPN', 'FasterRCNN', 'MaskRCNN'] +__all__ = ['BaseDetector', 'RPN', 'FasterRCNN', 'MaskRCNN', 'RetinaNet'] diff --git a/mmdet/models/detectors/retina_net.py b/mmdet/models/detectors/retina_net.py new file mode 100644 index 0000000000000000000000000000000000000000..da7ede8682836496b84423f983e3da6c79209d92 --- /dev/null +++ b/mmdet/models/detectors/retina_net.py @@ -0,0 +1,14 @@ +from .single_stage import SingleStageDetector + + +class RetinaNet(SingleStageDetector): + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(RetinaNet, self).__init__(backbone, neck, bbox_head, train_cfg, + test_cfg, pretrained) diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..95075ed2fffccc449d0a7ecdcbe88cc167871a8d --- /dev/null +++ b/mmdet/models/detectors/single_stage.py @@ -0,0 +1,62 @@ +import torch.nn as nn + +from .base import BaseDetector +from .. 
import builder +from mmdet.core import bbox2result + + +class SingleStageDetector(BaseDetector): + + def __init__(self, + backbone, + neck=None, + bbox_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(SingleStageDetector, self).__init__() + self.backbone = builder.build_backbone(backbone) + if neck is not None: + self.neck = builder.build_neck(neck) + self.bbox_head = builder.build_single_stage_head(bbox_head) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.init_weights(pretrained=pretrained) + + def init_weights(self, pretrained=None): + super(SingleStageDetector, self).init_weights(pretrained) + self.backbone.init_weights(pretrained=pretrained) + if self.with_neck: + if isinstance(self.neck, nn.Sequential): + for m in self.neck: + m.init_weights() + else: + self.neck.init_weights() + self.bbox_head.init_weights() + + def extract_feat(self, img): + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def forward_train(self, img, img_metas, gt_bboxes, gt_labels): + x = self.extract_feat(img) + outs = self.bbox_head(x) + loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg) + losses = self.bbox_head.loss(*loss_inputs) + return losses + + def simple_test(self, img, img_meta, rescale=False): + x = self.extract_feat(img) + outs = self.bbox_head(x) + bbox_inputs = outs + (img_meta, self.test_cfg, rescale) + bbox_list = self.bbox_head.get_det_bboxes(*bbox_inputs) + bbox_results = [ + bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) + for det_bboxes, det_labels in bbox_list + ] + return bbox_results[0] + + def aug_test(self, imgs, img_metas, rescale=False): + raise NotImplementedError diff --git a/mmdet/models/single_stage_heads/__init__.py b/mmdet/models/single_stage_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..265b0626bf1f4362426d71454870af0bb7395729 --- /dev/null +++ b/mmdet/models/single_stage_heads/__init__.py @@ -0,0 +1,3 @@ +from .retina_head import RetinaHead + +__all__ = ['RetinaHead'] diff --git a/mmdet/models/single_stage_heads/retina_head.py b/mmdet/models/single_stage_heads/retina_head.py new file mode 100644 index 0000000000000000000000000000000000000000..55e5e94d85cf2551576953754cb2e419f672b00e --- /dev/null +++ b/mmdet/models/single_stage_heads/retina_head.py @@ -0,0 +1,264 @@ +from __future__ import division + +import numpy as np +import torch +import torch.nn as nn + +from mmdet.core import (AnchorGenerator, multi_apply, delta2bbox, + weighted_smoothl1, weighted_sigmoid_focal_loss, + multiclass_nms, retina_target) +from ..utils import normal_init, bias_init_with_prob + + +class RetinaHead(nn.Module): + """Head of RetinaNet. + + / cls_convs - retina_cls (3x3 conv) + input - + \ reg_convs - retina_reg (3x3 conv) + + Args: + in_channels (int): Number of channels in the input feature map. + num_classes (int): Class number (including background). + stacked_convs (int): Number of convolutional layers added for cls and + reg branch. + feat_channels (int): Number of channels for the RPN feature map. 
+ """ + + def __init__(self, + in_channels, + num_classes, + stacked_convs=4, + feat_channels=256, + scales_per_octave=3, + anchor_scale=4, + anchor_ratios=[1.0, 2.0, 0.5], + anchor_strides=[8, 16, 32, 64, 128], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]): + super(RetinaHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + self.scales_per_octave = scales_per_octave + self.anchor_scale = anchor_scale + self.anchor_strides = anchor_strides + self.anchor_ratios = anchor_ratios + self.target_means = target_means + self.target_stds = target_stds + self.anchor_generators = [] + for anchor_stride in self.anchor_strides: + octave_scales = np.array([ + 2**(octave / float(scales_per_octave)) + for octave in range(scales_per_octave) + ]) + octave_scales = octave_scales * anchor_scale + self.anchor_generators.append( + AnchorGenerator(anchor_stride, octave_scales, anchor_ratios)) + self.relu = nn.ReLU(inplace=True) + self.num_anchors = int( + len(self.anchor_ratios) * self.scales_per_octave) + self.cls_out_channels = self.num_classes - 1 + self.bbox_pred_dim = 4 + + self.stacked_convs = stacked_convs + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = in_channels if i == 0 else feat_channels + self.cls_convs.append( + nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1)) + self.reg_convs.append( + nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1)) + self.retina_cls = nn.Conv2d( + feat_channels, + self.num_anchors * self.cls_out_channels, + 3, + stride=1, + padding=1) + self.retina_reg = nn.Conv2d( + feat_channels, + self.num_anchors * self.bbox_pred_dim, + 3, + stride=1, + padding=1) + self.debug_imgs = None + + def init_weights(self): + for m in self.cls_convs: + normal_init(m, std=0.01) + for m in self.reg_convs: + normal_init(m, std=0.01) + bias_cls = bias_init_with_prob(0.01) + normal_init(self.retina_cls, std=0.01, bias=bias_cls) + normal_init(self.retina_reg, std=0.01) + + def forward_single(self, x): + cls_feat = x + reg_feat = x + for cls_conv in self.cls_convs: + cls_feat = self.relu(cls_conv(cls_feat)) + for reg_conv in self.reg_convs: + reg_feat = self.relu(reg_conv(reg_feat)) + cls_score = self.retina_cls(cls_feat) + bbox_pred = self.retina_reg(reg_feat) + return cls_score, bbox_pred + + def forward(self, feats): + return multi_apply(self.forward_single, feats) + + def get_anchors(self, featmap_sizes, img_metas): + """Get anchors according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + img_metas (list[dict]): Image meta info. 
+ + Returns: + tuple: anchors of each image, valid flags of each image + """ + num_imgs = len(img_metas) + num_levels = len(featmap_sizes) + + # since feature map sizes of all images are the same, we only compute + # anchors for one time + multi_level_anchors = [] + for i in range(num_levels): + anchors = self.anchor_generators[i].grid_anchors( + featmap_sizes[i], self.anchor_strides[i]) + multi_level_anchors.append(anchors) + anchor_list = [multi_level_anchors for _ in range(num_imgs)] + + # for each image, we compute valid flags of multi level anchors + valid_flag_list = [] + for img_id, img_meta in enumerate(img_metas): + multi_level_flags = [] + for i in range(num_levels): + anchor_stride = self.anchor_strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w, _ = img_meta['pad_shape'] + valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h) + valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w) + flags = self.anchor_generators[i].valid_flags( + (feat_h, feat_w), (valid_feat_h, valid_feat_w)) + multi_level_flags.append(flags) + valid_flag_list.append(multi_level_flags) + + return anchor_list, valid_flag_list + + def loss_single(self, cls_score, bbox_pred, labels, label_weights, + bbox_targets, bbox_weights, num_pos_samples, cfg): + # classification loss + labels = labels.contiguous().view(-1, self.cls_out_channels) + label_weights = label_weights.contiguous().view( + -1, self.cls_out_channels) + cls_score = cls_score.permute(0, 2, 3, 1).contiguous().view( + -1, self.cls_out_channels) + loss_cls = weighted_sigmoid_focal_loss( + cls_score, + labels, + label_weights, + cfg.gamma, + cfg.alpha, + avg_factor=num_pos_samples) + # regression loss + bbox_targets = bbox_targets.contiguous().view(-1, 4) + bbox_weights = bbox_weights.contiguous().view(-1, 4) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).contiguous().view(-1, 4) + loss_reg = weighted_smoothl1( + bbox_pred, + bbox_targets, + bbox_weights, + beta=cfg.smoothl1_beta, + avg_factor=num_pos_samples) + return loss_cls, loss_reg + + def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_shapes, + cfg): + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == len(self.anchor_generators) + + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, img_shapes) + cls_reg_targets = retina_target( + anchor_list, valid_flag_list, gt_bboxes, gt_labels, img_shapes, + self.target_means, self.target_stds, self.cls_out_channels, cfg) + if cls_reg_targets is None: + return None + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + num_pos_samples) = cls_reg_targets + + losses_cls, losses_reg = multi_apply( + self.loss_single, + cls_scores, + bbox_preds, + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + num_pos_samples=num_pos_samples, + cfg=cfg) + return dict(loss_cls=losses_cls, loss_reg=losses_reg) + + def get_det_bboxes(self, + cls_scores, + bbox_preds, + img_metas, + cfg, + rescale=False): + assert len(cls_scores) == len(bbox_preds) + num_levels = len(cls_scores) + + mlvl_anchors = [ + self.anchor_generators[i].grid_anchors(cls_scores[i].size()[-2:], + self.anchor_strides[i]) + for i in range(num_levels) + ] + + result_list = [] + for img_id in range(len(img_metas)): + cls_score_list = [ + cls_scores[i][img_id].detach() for i in range(num_levels) + ] + bbox_pred_list = [ + bbox_preds[i][img_id].detach() for i in range(num_levels) + ] + img_shape = img_metas[img_id]['img_shape'] + scale_factor = img_metas[img_id]['scale_factor'] + 
+            results = self._get_det_bboxes_single(
+                cls_score_list, bbox_pred_list, mlvl_anchors, img_shape,
+                scale_factor, cfg, rescale)
+            result_list.append(results)
+        return result_list
+
+    def _get_det_bboxes_single(self,
+                               cls_scores,
+                               bbox_preds,
+                               mlvl_anchors,
+                               img_shape,
+                               scale_factor,
+                               cfg,
+                               rescale=False):
+        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
+        mlvl_proposals = []
+        mlvl_scores = []
+        for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
+                                                 mlvl_anchors):
+            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+            cls_score = cls_score.permute(1, 2, 0).contiguous().view(
+                -1, self.cls_out_channels)
+            scores = cls_score.sigmoid()
+            bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
+            proposals = delta2bbox(anchors, bbox_pred, self.target_means,
+                                   self.target_stds, img_shape)
+            mlvl_proposals.append(proposals)
+            mlvl_scores.append(scores)
+        mlvl_proposals = torch.cat(mlvl_proposals)
+        if rescale:
+            mlvl_proposals /= scale_factor
+        mlvl_scores = torch.cat(mlvl_scores)
+        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
+        det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
+                                                cfg.score_thr, cfg.nms_thr,
+                                                cfg.max_per_img)
+        return det_bboxes, det_labels
diff --git a/mmdet/models/utils/weight_init.py b/mmdet/models/utils/weight_init.py
index 2e9b13b4fbc17d6d1986da876108c1a813190c2d..17d49880fd867e9a8727776ef5a9fcbd6b01599f 100644
--- a/mmdet/models/utils/weight_init.py
+++ b/mmdet/models/utils/weight_init.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch.nn as nn
 
 
@@ -37,3 +38,9 @@ def kaiming_init(module,
             module.weight, mode=mode, nonlinearity=nonlinearity)
     if hasattr(module, 'bias'):
         nn.init.constant_(module.bias, bias)
+
+
+def bias_init_with_prob(prior_prob):
+    """Initialize conv/fc bias value according to a given probability."""
+    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+    return bias_init
diff --git a/tools/configs/r50_retinanet_1x.py b/tools/configs/r50_retinanet_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..39abcaa34dcde69dbcf23fe88b2da21330bd4fb6
--- /dev/null
+++ b/tools/configs/r50_retinanet_1x.py
@@ -0,0 +1,117 @@
+# model settings
+model = dict(
+    type='RetinaNet',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='resnet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs=True,
+        num_outs=5),
+    bbox_head=dict(
+        type='RetinaHead',
+        num_classes=81,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        scales_per_octave=3,
+        anchor_scale=4,
+        anchor_ratios=[1.0, 2.0, 0.5],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0]))
+# training and testing settings
+train_cfg = dict(
+    pos_iou_thr=0.5,
+    neg_iou_thr=0.4,
+    min_pos_iou=0.4,
+    smoothl1_beta=0.11,
+    gamma=2.0,
+    alpha=0.25,
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    nms_thr=0.5,
+    min_bbox_size=0,
+    score_thr=0.05,
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = '../data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=False,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=False,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=20,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+device_ids = range(8)
+dist_params = dict(backend='gloo')
+log_level = 'INFO'
+work_dir = './work_dirs/fpn_retinanet_r50_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
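
With the config above, each feature-map location gets scales_per_octave * len(anchor_ratios) = 9 anchors, and the sigmoid-based head predicts num_classes - 1 = 80 scores per anchor. A small worked example of the resulting output channels (plain Python, values taken from r50_retinanet_1x.py):

# Output-channel arithmetic for RetinaHead under tools/configs/r50_retinanet_1x.py.
scales_per_octave = 3
anchor_ratios = [1.0, 2.0, 0.5]
num_classes = 81  # 80 COCO classes + background

num_anchors = scales_per_octave * len(anchor_ratios)  # 9 anchors per location
cls_out_channels = num_classes - 1                     # 80; background is handled implicitly by sigmoid
print(num_anchors * cls_out_channels)  # 720 channels out of retina_cls
print(num_anchors * 4)                 # 36 channels out of retina_reg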
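
The new bias_init_with_prob helper solves sigmoid(b) = prior_prob for b, so every anchor starts with a small foreground score and the focal loss is not dominated by the huge number of negatives in the first iterations. A standalone check of that property (numpy only; the variable names are illustrative):

import numpy as np


def bias_init_with_prob(prior_prob):
    # Same formula as the helper added to mmdet/models/utils/weight_init.py.
    return float(-np.log((1 - prior_prob) / prior_prob))


b = bias_init_with_prob(0.01)
print(b)                         # about -4.595
print(1.0 / (1.0 + np.exp(-b)))  # about 0.01, i.e. the requested prior probability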
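
Because classification is sigmoid-based, retina_target_single converts integer labels into per-class binary targets via expand_binary_labels: label k in 1..C becomes a 1 in column k - 1, and background (label 0) stays all-zero. A minimal standalone illustration with made-up tensors:

import torch


def expand_binary_labels(labels, label_weights, cls_out_channels):
    # Same logic as in mmdet/core/targets/retina_target.py.
    bin_labels = labels.new_full(
        (labels.size(0), cls_out_channels), 0, dtype=torch.float32)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), cls_out_channels)
    return bin_labels, bin_label_weights


labels = torch.tensor([0, 2, 3])         # one background anchor, two positives
weights = torch.tensor([1.0, 1.0, 1.0])
bin_labels, bin_weights = expand_binary_labels(labels, weights, cls_out_channels=4)
print(bin_labels)
# tensor([[0., 0., 0., 0.],
#         [0., 1., 0., 0.],
#         [0., 0., 1., 0.]])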
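
Finally, a rough smoke test for wiring the new config, registry entry, and head together. This is only a sketch: it assumes mmcv's Config.fromfile is available in this environment, drops the pretrained checkpoint so nothing is downloaded, and uses a dummy image just to sanity-check output shapes.

import torch
from mmcv import Config

from mmdet.models.builder import build_detector

cfg = Config.fromfile('tools/configs/r50_retinanet_1x.py')
cfg.model['pretrained'] = None  # skip the modelzoo download for this quick check

model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
model.eval()

# One fake 800x1216 image through ResNet-50 + FPN + RetinaHead.
img = torch.randn(1, 3, 800, 1216)
with torch.no_grad():
    cls_scores, bbox_preds = model.bbox_head(model.extract_feat(img))
for cls_score, bbox_pred in zip(cls_scores, bbox_preds):
    print(cls_score.shape, bbox_pred.shape)
# The stride-8 level should come out as (1, 720, 100, 152) and (1, 36, 100, 152).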