Skip to content
Snippets Groups Projects
Commit 7cee8797 authored by Jon Crall's avatar Jon Crall Committed by Kai Chen
Browse files

Add in_channels kwarg to backbones. (#1475)

* Add in_channels kwarg to backbones

* explicit import in HRNet doctest
parent 4d9a5f47
No related branches found
No related tags found
No related merge requests found
......@@ -201,6 +201,7 @@ class HRNet(nn.Module):
Args:
extra (dict): detailed configuration for each stage of HRNet.
in_channels (int): Number of input image channels. Normally 3.
conv_cfg (dict): dictionary to construct and config conv layer.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
......@@ -210,12 +211,52 @@ class HRNet(nn.Module):
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmdet.models import HRNet
>>> import torch
>>> extra = dict(
>>> stage1=dict(
>>> num_modules=1,
>>> num_branches=1,
>>> block='BOTTLENECK',
>>> num_blocks=(4, ),
>>> num_channels=(64, )),
>>> stage2=dict(
>>> num_modules=1,
>>> num_branches=2,
>>> block='BASIC',
>>> num_blocks=(4, 4),
>>> num_channels=(32, 64)),
>>> stage3=dict(
>>> num_modules=4,
>>> num_branches=3,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4),
>>> num_channels=(32, 64, 128)),
>>> stage4=dict(
>>> num_modules=3,
>>> num_branches=4,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4, 4),
>>> num_channels=(32, 64, 128, 256)))
>>> self = HRNet(extra, in_channels=1)
>>> self.eval()
>>> inputs = torch.rand(1, 1, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 32, 8, 8)
(1, 64, 4, 4)
(1, 128, 2, 2)
(1, 256, 1, 1)
"""
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
def __init__(self,
extra,
in_channels=3,
conv_cfg=None,
norm_cfg=dict(type='BN'),
norm_eval=True,
......@@ -235,7 +276,7 @@ class HRNet(nn.Module):
self.conv1 = build_conv_layer(
self.conv_cfg,
3,
in_channels,
64,
kernel_size=3,
stride=2,
......
......@@ -335,6 +335,7 @@ class ResNet(nn.Module):
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
......@@ -378,6 +379,7 @@ class ResNet(nn.Module):
def __init__(self,
depth,
in_channels=3,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
......@@ -426,7 +428,7 @@ class ResNet(nn.Module):
self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = 64
self._make_stem_layer()
self._make_stem_layer(in_channels)
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
......@@ -464,10 +466,10 @@ class ResNet(nn.Module):
def norm1(self):
return getattr(self, self.norm1_name)
def _make_stem_layer(self):
def _make_stem_layer(self, in_channels):
self.conv1 = build_conv_layer(
self.conv_cfg,
3,
in_channels,
64,
kernel_size=7,
stride=2,
......
......@@ -160,6 +160,7 @@ class ResNeXt(ResNet):
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
num_stages (int): Resnet stages, normally 4.
groups (int): Group of resnext.
base_width (int): Base width of resnext.
......
......@@ -11,6 +11,26 @@ from ..registry import BACKBONES
@BACKBONES.register_module
class SSDVGG(VGG):
"""VGG Backbone network for single-shot-detection
Args:
input_size (int): width and height of input, from {300, 512}.
depth (int): Depth of vgg, from {11, 13, 16, 19}.
out_indices (Sequence[int]): Output from which stages.
Example:
>>> self = SSDVGG(input_size=300, depth=11)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 300, 300)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 1024, 19, 19)
(1, 512, 10, 10)
(1, 256, 5, 5)
(1, 256, 3, 3)
(1, 256, 1, 1)
"""
extra_setting = {
300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128),
......@@ -24,6 +44,7 @@ class SSDVGG(VGG):
out_indices=(3, 4),
out_feature_indices=(22, 34),
l2_norm_scale=20.):
# TODO: in_channels for mmcv.VGG
super(SSDVGG, self).__init__(
depth,
with_last_pool=with_last_pool,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment