From f83c8638920fd710d12a1d904a2ee957c58adda3 Mon Sep 17 00:00:00 2001 From: Jon Crall <erotemic@gmail.com> Date: Tue, 8 Oct 2019 12:21:16 -0400 Subject: [PATCH] Add docs FCOS/Retina, OHEM, Scale, make_optimizer (#1505) --- mmdet/apis/train.py | 6 ++++++ mmdet/core/bbox/samplers/ohem_sampler.py | 6 ++++++ mmdet/models/anchor_heads/fcos_head.py | 16 ++++++++++++++++ mmdet/models/anchor_heads/retina_head.py | 20 ++++++++++++++++++++ mmdet/models/utils/scale.py | 3 +++ 5 files changed, 51 insertions(+) diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index 35ad34c..fee47d6 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -82,6 +82,12 @@ def build_optimizer(model, optimizer_cfg): Returns: torch.optim.Optimizer: The initialized optimizer. + + Example: + >>> model = torch.nn.modules.Conv1d(1, 1, 1) + >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, + >>> weight_decay=0.0001) + >>> optimizer = build_optimizer(model, optimizer_cfg) """ if hasattr(model, 'module'): model = model.module diff --git a/mmdet/core/bbox/samplers/ohem_sampler.py b/mmdet/core/bbox/samplers/ohem_sampler.py index 2500f31..3701d83 100644 --- a/mmdet/core/bbox/samplers/ohem_sampler.py +++ b/mmdet/core/bbox/samplers/ohem_sampler.py @@ -5,6 +5,12 @@ from .base_sampler import BaseSampler class OHEMSampler(BaseSampler): + """ + Online Hard Example Mining Sampler described in [1]_. + + References: + .. [1] https://arxiv.org/pdf/1604.03540.pdf + """ def __init__(self, num, diff --git a/mmdet/models/anchor_heads/fcos_head.py b/mmdet/models/anchor_heads/fcos_head.py index c01e4ea..a8c2cd4 100644 --- a/mmdet/models/anchor_heads/fcos_head.py +++ b/mmdet/models/anchor_heads/fcos_head.py @@ -12,6 +12,22 @@ INF = 1e8 @HEADS.register_module class FCOSHead(nn.Module): + """ + Fully Convolutional One-Stage Object Detection head from [1]_. + + The FCOS head does not use anchor boxes. 
Instead bounding boxes are + predicted at each pixel and a centerness measure is used to suppress + low-quality predictions. + + References: + .. [1] https://arxiv.org/abs/1904.01355 + + Example: + >>> self = FCOSHead(11, 7) + >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]] + >>> cls_score, bbox_pred, centerness = self.forward(feats) + >>> assert len(cls_score) == len(self.scales) + """ def __init__(self, num_classes, diff --git a/mmdet/models/anchor_heads/retina_head.py b/mmdet/models/anchor_heads/retina_head.py index 045db70..e3b8143 100644 --- a/mmdet/models/anchor_heads/retina_head.py +++ b/mmdet/models/anchor_heads/retina_head.py @@ -9,6 +9,26 @@ from .anchor_head import AnchorHead @HEADS.register_module class RetinaHead(AnchorHead): + """ + An anchor-based head used in [1]_. + + The head contains two subnetworks. The first classifies anchor boxes and + the second regresses deltas for the anchors. + + References: + .. [1] https://arxiv.org/pdf/1708.02002.pdf + + Example: + >>> import torch + >>> self = RetinaHead(11, 7) + >>> x = torch.rand(1, 7, 32, 32) + >>> cls_score, bbox_pred = self.forward_single(x) + >>> # Each anchor predicts a score for each class except background + >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors + >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors + >>> assert cls_per_anchor == (self.num_classes - 1) + >>> assert box_per_anchor == 4 + """ def __init__(self, num_classes, diff --git a/mmdet/models/utils/scale.py b/mmdet/models/utils/scale.py index 68c37cd..2461af8 100644 --- a/mmdet/models/utils/scale.py +++ b/mmdet/models/utils/scale.py @@ -3,6 +3,9 @@ import torch.nn as nn class Scale(nn.Module): + """ + A learnable scale parameter + """ def __init__(self, scale=1.0): super(Scale, self).__init__() -- GitLab