From 4d9a5f47c27c60fb9f3b9a1866507afba3fcba38 Mon Sep 17 00:00:00 2001
From: Jon Crall <erotemic@gmail.com>
Date: Thu, 3 Oct 2019 01:26:00 -0400
Subject: [PATCH] Add two doctest examples for resnet and resnext (#1474)

* Add doctest example for resnet and resnext

* Explicit import in doctests

* Fix missing torch import in doctest
---
 mmdet/models/backbones/resnet.py  | 14 ++++++++++++++
 mmdet/models/backbones/resnext.py | 14 ++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index d87b736..2967dd9 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -352,6 +352,20 @@ class ResNet(nn.Module):
             memory while slowing down the training speed.
         zero_init_residual (bool): whether to use zero init for last norm layer
             in resblocks to let them behave as identity.
+
+    Example:
+        >>> from mmdet.models import ResNet
+        >>> import torch
+        >>> self = ResNet(depth=18)
+        >>> _ = self.eval()  # eval() returns self; discard to keep doctest output clean
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 64, 8, 8)
+        (1, 128, 4, 4)
+        (1, 256, 2, 2)
+        (1, 512, 1, 1)
     """
 
     arch_settings = {
diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py
index c5feaa4..be28976 100644
--- a/mmdet/models/backbones/resnext.py
+++ b/mmdet/models/backbones/resnext.py
@@ -179,6 +179,20 @@ class ResNeXt(ResNet):
             memory while slowing down the training speed.
         zero_init_residual (bool): whether to use zero init for last norm layer
             in resblocks to let them behave as identity.
+
+    Example:
+        >>> from mmdet.models import ResNeXt
+        >>> import torch
+        >>> self = ResNeXt(depth=50)
+        >>> _ = self.eval()  # eval() returns self; discard to keep doctest output clean
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 256, 8, 8)
+        (1, 512, 4, 4)
+        (1, 1024, 2, 2)
+        (1, 2048, 1, 1)
     """
 
     arch_settings = {
-- 
GitLab