From f83c8638920fd710d12a1d904a2ee957c58adda3 Mon Sep 17 00:00:00 2001
From: Jon Crall <erotemic@gmail.com>
Date: Tue, 8 Oct 2019 12:21:16 -0400
Subject: [PATCH] Add docs FCOS/Retina, OHEM, Scale, make_optimizer (#1505)

---
 mmdet/apis/train.py                      |  6 ++++++
 mmdet/core/bbox/samplers/ohem_sampler.py |  6 ++++++
 mmdet/models/anchor_heads/fcos_head.py   | 16 ++++++++++++++++
 mmdet/models/anchor_heads/retina_head.py | 20 ++++++++++++++++++++
 mmdet/models/utils/scale.py              |  3 +++
 5 files changed, 51 insertions(+)

diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 35ad34c..fee47d6 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -82,6 +82,12 @@ def build_optimizer(model, optimizer_cfg):
 
     Returns:
         torch.optim.Optimizer: The initialized optimizer.
+
+    Example:
+        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
+        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
+        >>>                      weight_decay=0.0001)
+        >>> optimizer = build_optimizer(model, optimizer_cfg)
     """
     if hasattr(model, 'module'):
         model = model.module
diff --git a/mmdet/core/bbox/samplers/ohem_sampler.py b/mmdet/core/bbox/samplers/ohem_sampler.py
index 2500f31..3701d83 100644
--- a/mmdet/core/bbox/samplers/ohem_sampler.py
+++ b/mmdet/core/bbox/samplers/ohem_sampler.py
@@ -5,6 +5,12 @@ from .base_sampler import BaseSampler
 
 
 class OHEMSampler(BaseSampler):
+    """
+    Online Hard Example Mining Sampler described in [1]_.
+
+    References:
+        .. [1] https://arxiv.org/pdf/1604.03540.pdf
+    """
 
     def __init__(self,
                  num,
diff --git a/mmdet/models/anchor_heads/fcos_head.py b/mmdet/models/anchor_heads/fcos_head.py
index c01e4ea..a8c2cd4 100644
--- a/mmdet/models/anchor_heads/fcos_head.py
+++ b/mmdet/models/anchor_heads/fcos_head.py
@@ -12,6 +12,22 @@ INF = 1e8
 
 @HEADS.register_module
 class FCOSHead(nn.Module):
+    """
+    Fully Convolutional One-Stage Object Detection head from [1]_.
+
+    The FCOS head does not use anchor boxes. Instead bounding boxes are
+    predicted at each pixel and a centerness measure is used to supress
+    low-quality predictions.
+
+    References:
+        .. [1] https://arxiv.org/abs/1904.01355
+
+    Example:
+        >>> self = FCOSHead(11, 7)
+        >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
+        >>> cls_score, bbox_pred, centerness = self.forward(feats)
+        >>> assert len(cls_score) == len(self.scales)
+    """
 
     def __init__(self,
                  num_classes,
diff --git a/mmdet/models/anchor_heads/retina_head.py b/mmdet/models/anchor_heads/retina_head.py
index 045db70..e3b8143 100644
--- a/mmdet/models/anchor_heads/retina_head.py
+++ b/mmdet/models/anchor_heads/retina_head.py
@@ -9,6 +9,26 @@ from .anchor_head import AnchorHead
 
 @HEADS.register_module
 class RetinaHead(AnchorHead):
+    """
+    An anchor-based head used in [1]_.
+
+    The head contains two subnetworks. The first classifies anchor boxes and
+    the second regresses deltas for the anchors.
+
+    References:
+        .. [1]  https://arxiv.org/pdf/1708.02002.pdf
+
+    Example:
+        >>> import torch
+        >>> self = RetinaHead(11, 7)
+        >>> x = torch.rand(1, 7, 32, 32)
+        >>> cls_score, bbox_pred = self.forward_single(x)
+        >>> # Each anchor predicts a score for each class except background
+        >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
+        >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
+        >>> assert cls_per_anchor == (self.num_classes - 1)
+        >>> assert box_per_anchor == 4
+    """
 
     def __init__(self,
                  num_classes,
diff --git a/mmdet/models/utils/scale.py b/mmdet/models/utils/scale.py
index 68c37cd..2461af8 100644
--- a/mmdet/models/utils/scale.py
+++ b/mmdet/models/utils/scale.py
@@ -3,6 +3,9 @@ import torch.nn as nn
 
 
 class Scale(nn.Module):
+    """
+    A learnable scale parameter
+    """
 
     def __init__(self, scale=1.0):
         super(Scale, self).__init__()
-- 
GitLab