diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 0138a3ffacc0cdcc6fcd4fad8af52c14f2bb74dd..4dc4a37124fd28799547280ebbdf6c1577888526 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -162,6 +162,103 @@ pytorch [launch utility](https://pytorch.org/docs/stable/distributed_deprecated. Usually it is slow if you do not have high speed networking like infiniband. +## Useful tools + +### Analyze logs + +You can plot loss/mAP curves given a training log file. Run `pip install seaborn` first to install the dependency. + + + +```shell +python tools/analyze_logs.py plot_curve [--keys ${KEYS}] [--title ${TITLE}] [--legend ${LEGEND}] [--backend ${BACKEND}] [--style ${STYLE}] [--out ${OUT_FILE}] +``` + +Examples: + +- Plot the classification loss of some run. + +```shell +python tools/analyze_logs.py plot_curve log.json --keys loss_cls --legend loss_cls +``` + +- Plot the classification and regression loss of some run, and save the figure to a pdf. + +```shell +python tools/analyze_logs.py plot_curve log.json --keys loss_cls loss_reg --out losses.pdf +``` + +- Compare the bbox mAP of two runs in the same figure. + +```shell +python tools/analyze_logs.py plot_curve log1.json log2.json --keys bbox_mAP --legend run1 run2 +``` + +You can also compute the average training speed. + +```shell +python tools/analyze_logs.py cal_train_time ${CONFIG_FILE} [--include-outliers] +``` + +The output is expected to be like the following. + +``` +-----Analyze train time of work_dirs/some_exp/20190611_192040.log.json----- +slowest epoch 11, average time is 1.2024 +fastest epoch 1, average time is 1.1909 +time std over epochs is 0.0028 +average iter time: 1.1959 s/iter + +``` + +### Get the FLOPs and params (experimental) + +We provide a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch) to compute the FLOPs and params of a given model. + +```shell +python tools/get_flops.py ${CONFIG_FILE} [--shape ${INPUT_SHAPE}] +``` + +You will get the result like this. + +``` +============================== +Input shape: (3, 1280, 800) +Flops: 239.32 GMac +Params: 37.74 M +============================== +``` + +**Note**: This tool is still experimental and we do not guarantee that the number is correct. You may well use the result for simple comparisons, but double check it before you adopt it in technical reports or papers. + +(1) FLOPs are related to the input shape while parameters are not. The default input shape is (1, 3, 1280, 800). +(2) Some operators are not counted into FLOPs like GN and custom operators. +You can add support for new operators by modifying [`mmdet/utils/flops_counter.py`](mmdet/utils/flops_counter.py). +(3) The FLOPs of two-stage detectors is dependent on the number of proposals. + +### Publish a model + +Before you upload a model to AWS, you may want to +(1) convert model weights to CPU tensors, (2) delete the optimizer states and +(3) compute the hash of the checkpoint file and append the hash id to the filename. + +```shell +python tools/publish_model.py ${INPUT_FILENAME} ${OUTPUT_FILENAME} +``` + +E.g., + +```shell +python tools/publish_model.py work_dirs/faster_rcnn/latest.pth faster_rcnn_r50_fpn_1x_20190801.pth +``` + +The final output filename will be `faster_rcnn_r50_fpn_1x_20190801-{hash id}.pth`. + +### Test the robustness of detectors + +Please refer to [ROBUSTNESS_BENCHMARKING.md](ROBUSTNESS_BENCHMARKING.md). + + ## How-to ### Use my own datasets diff --git a/demo/loss_curve.png b/demo/loss_curve.png new file mode 100644 index 0000000000000000000000000000000000000000..02425551174d57ae6fecd51be7960acad84c934c Binary files /dev/null and b/demo/loss_curve.png differ diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py index 7ecbdff74c2923aab74ce5a511dc7f6457a5df86..bd878eb40c0bcfd126cf8e3a62a6ef8a1bf86cf5 100644 --- a/mmdet/models/detectors/cascade_rcnn.py +++ b/mmdet/models/detectors/cascade_rcnn.py @@ -117,6 +117,37 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): x = self.neck(x) return x + def forward_dummy(self, img): + outs = () + # backbone + x = self.extract_feat(img) + # rpn + if self.with_rpn: + rpn_outs = self.rpn_head(x) + outs = outs + (rpn_outs, ) + proposals = torch.randn(1000, 4).cuda() + # bbox heads + rois = bbox2roi([proposals]) + if self.with_bbox: + for i in range(self.num_stages): + bbox_feats = self.bbox_roi_extractor[i]( + x[:self.bbox_roi_extractor[i].num_inputs], rois) + if self.with_shared_head: + bbox_feats = self.shared_head(bbox_feats) + cls_score, bbox_pred = self.bbox_head[i](bbox_feats) + outs = outs + (cls_score, bbox_pred) + # mask heads + if self.with_mask: + mask_rois = rois[:100] + for i in range(self.num_stages): + mask_feats = self.mask_roi_extractor[i]( + x[:self.mask_roi_extractor[i].num_inputs], mask_rois) + if self.with_shared_head: + mask_feats = self.shared_head(mask_feats) + mask_pred = self.mask_head[i](mask_feats) + outs = outs + (mask_pred, ) + return outs + def forward_train(self, img, img_meta, diff --git a/mmdet/models/detectors/double_head_rcnn.py b/mmdet/models/detectors/double_head_rcnn.py index 08f998b8aeba2f3c098332a8d169031c04f2a970..7a783353f1eba4ee551a4e9c4368a3584dd09aa0 100644 --- a/mmdet/models/detectors/double_head_rcnn.py +++ b/mmdet/models/detectors/double_head_rcnn.py @@ -12,6 +12,30 @@ class DoubleHeadRCNN(TwoStageDetector): super().__init__(**kwargs) self.reg_roi_scale_factor = reg_roi_scale_factor + def forward_dummy(self, img): + outs = () + # backbone + x = self.extract_feat(img) + # rpn + if self.with_rpn: + rpn_outs = self.rpn_head(x) + outs = outs + (rpn_outs, ) + proposals = torch.randn(1000, 4).cuda() + # bbox head + rois = bbox2roi([proposals]) + bbox_cls_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], rois) + bbox_reg_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], + rois, + roi_scale_factor=self.reg_roi_scale_factor) + if self.with_shared_head: + bbox_cls_feats = self.shared_head(bbox_cls_feats) + bbox_reg_feats = self.shared_head(bbox_reg_feats) + cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats) + outs += (cls_score, bbox_pred) + return outs + def forward_train(self, img, img_meta, diff --git a/mmdet/models/detectors/grid_rcnn.py b/mmdet/models/detectors/grid_rcnn.py index 2c321646151ce915bedd28f958d92bdc1f15196b..853242c166c07347ea13e62ad12b9a8f542f6cd2 100644 --- a/mmdet/models/detectors/grid_rcnn.py +++ b/mmdet/models/detectors/grid_rcnn.py @@ -80,6 +80,31 @@ class GridRCNN(TwoStageDetector): sampling_result.pos_bboxes = new_bboxes return sampling_results + def forward_dummy(self, img): + outs = () + # backbone + x = self.extract_feat(img) + # rpn + if self.with_rpn: + rpn_outs = self.rpn_head(x) + outs = outs + (rpn_outs, ) + proposals = torch.randn(1000, 4).cuda() + # bbox head + rois = bbox2roi([proposals]) + bbox_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], rois) + if self.with_shared_head: + bbox_feats = self.shared_head(bbox_feats) + cls_score, bbox_pred = self.bbox_head(bbox_feats) + # grid head + grid_rois = rois[:100] + grid_feats = self.grid_roi_extractor( + x[:self.grid_roi_extractor.num_inputs], grid_rois) + if self.with_shared_head: + grid_feats = self.shared_head(grid_feats) + grid_pred = self.grid_head(grid_feats) + return rpn_outs, cls_score, bbox_pred, grid_pred + def forward_train(self, img, img_meta, diff --git a/mmdet/models/detectors/htc.py b/mmdet/models/detectors/htc.py index 7135fe1d38d1c1d73952569079b097aea7f821ab..d0a70246d1c0f08961a8988bfa7521be7c622ce4 100644 --- a/mmdet/models/detectors/htc.py +++ b/mmdet/models/detectors/htc.py @@ -153,6 +153,46 @@ class HybridTaskCascade(CascadeRCNN): mask_pred = mask_head(mask_feats) return mask_pred + def forward_dummy(self, img): + outs = () + # backbone + x = self.extract_feat(img) + # rpn + if self.with_rpn: + rpn_outs = self.rpn_head(x) + outs = outs + (rpn_outs, ) + proposals = torch.randn(1000, 4).cuda() + # semantic head + if self.with_semantic: + _, semantic_feat = self.semantic_head(x) + else: + semantic_feat = None + # bbox heads + rois = bbox2roi([proposals]) + for i in range(self.num_stages): + cls_score, bbox_pred = self._bbox_forward_test( + i, x, rois, semantic_feat=semantic_feat) + outs = outs + (cls_score, bbox_pred) + # mask heads + if self.with_mask: + mask_rois = rois[:100] + mask_roi_extractor = self.mask_roi_extractor[-1] + mask_feats = mask_roi_extractor( + x[:len(mask_roi_extractor.featmap_strides)], mask_rois) + if self.with_semantic and 'mask' in self.semantic_fusion: + mask_semantic_feat = self.semantic_roi_extractor( + [semantic_feat], mask_rois) + mask_feats += mask_semantic_feat + last_feat = None + for i in range(self.num_stages): + mask_head = self.mask_head[i] + if self.mask_info_flow: + mask_pred, last_feat = mask_head(mask_feats, last_feat) + else: + mask_pred = mask_head(mask_feats) + outs = outs + (mask_pred, ) + return outs + def forward_train(self, img, img_meta, diff --git a/mmdet/models/detectors/mask_scoring_rcnn.py b/mmdet/models/detectors/mask_scoring_rcnn.py index b035f53e1683bf11f85f318b6d94c6d4730332e6..9c16ab14adb5bb5dc32e0dbcf32db47dd28b79bc 100644 --- a/mmdet/models/detectors/mask_scoring_rcnn.py +++ b/mmdet/models/detectors/mask_scoring_rcnn.py @@ -42,6 +42,9 @@ class MaskScoringRCNN(TwoStageDetector): self.mask_iou_head = builder.build_head(mask_iou_head) self.mask_iou_head.init_weights() + def forward_dummy(self, img): + raise NotImplementedError + # TODO: refactor forward_train in two stage to reduce code redundancy def forward_train(self, img, diff --git a/mmdet/models/detectors/rpn.py b/mmdet/models/detectors/rpn.py index 2f947fae474b5efe17152649f3ddcde6980d0078..c9de290fbfa46b3daeb8a062ed0f582b67dc147a 100644 --- a/mmdet/models/detectors/rpn.py +++ b/mmdet/models/detectors/rpn.py @@ -38,6 +38,11 @@ class RPN(BaseDetector, RPNTestMixin): x = self.neck(x) return x + def forward_dummy(self, img): + x = self.extract_feat(img) + rpn_outs = self.rpn_head(x) + return rpn_outs + def forward_train(self, img, img_meta, diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py index f7e0fa6f82d9fec45bf65301c902ed5f256f83d0..95873926239d741b135a9d644458115df2ada578 100644 --- a/mmdet/models/detectors/single_stage.py +++ b/mmdet/models/detectors/single_stage.py @@ -42,6 +42,11 @@ class SingleStageDetector(BaseDetector): x = self.neck(x) return x + def forward_dummy(self, img): + x = self.extract_feat(img) + outs = self.bbox_head(x) + return outs + def forward_train(self, img, img_metas, diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py index e1536b54b5a0ca3595890ce57bc335ec11ef7766..6ec1541b80d916b1d72941e3822fd39cc921e7c1 100644 --- a/mmdet/models/detectors/two_stage.py +++ b/mmdet/models/detectors/two_stage.py @@ -87,6 +87,35 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, x = self.neck(x) return x + def forward_dummy(self, img): + outs = () + # backbone + x = self.extract_feat(img) + # rpn + if self.with_rpn: + rpn_outs = self.rpn_head(x) + outs = outs + (rpn_outs, ) + proposals = torch.randn(1000, 4).cuda() + # bbox head + rois = bbox2roi([proposals]) + if self.with_bbox: + bbox_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], rois) + if self.with_shared_head: + bbox_feats = self.shared_head(bbox_feats) + cls_score, bbox_pred = self.bbox_head(bbox_feats) + outs = outs + (cls_score, bbox_pred) + # mask head + if self.with_mask: + mask_rois = rois[:100] + mask_feats = self.mask_roi_extractor( + x[:self.mask_roi_extractor.num_inputs], mask_rois) + if self.with_shared_head: + mask_feats = self.shared_head(mask_feats) + mask_pred = self.mask_head(mask_feats) + outs = outs + (mask_pred, ) + return outs + def forward_train(self, img, img_meta, diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py index c0a12443d1d45254733fc7f97cca08e11266d36c..f65e3b2fbcf1d87c6c184f4ef63e9ad40537a23e 100644 --- a/mmdet/utils/__init__.py +++ b/mmdet/utils/__init__.py @@ -1,3 +1,4 @@ +from .flops_counter import get_model_complexity_info from .registry import Registry, build_from_cfg -__all__ = ['Registry', 'build_from_cfg'] +__all__ = ['Registry', 'build_from_cfg', 'get_model_complexity_info'] diff --git a/mmdet/utils/flops_counter.py b/mmdet/utils/flops_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..3005a14fc02e2b97232c94c786489d1a9c936799 --- /dev/null +++ b/mmdet/utils/flops_counter.py @@ -0,0 +1,421 @@ +# Modified from flops-counter.pytorch by Vladislav Sovrasov +# original repo: https://github.com/sovrasov/flops-counter.pytorch + +# MIT License + +# Copyright (c) 2018 Vladislav Sovrasov + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import sys + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin +from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, + _AvgPoolNd, _MaxPoolNd) + +CONV_TYPES = (_ConvNd, ) +DECONV_TYPES = (_ConvTransposeMixin, ) +LINEAR_TYPES = (nn.Linear, ) +POOLING_TYPES = (_AvgPoolNd, _MaxPoolNd, _AdaptiveAvgPoolNd, + _AdaptiveMaxPoolNd) +RELU_TYPES = (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6) +BN_TYPES = (_BatchNorm, ) +UPSAMPLE_TYPES = (nn.Upsample, ) + +SUPPORTED_TYPES = ( + CONV_TYPES + DECONV_TYPES + LINEAR_TYPES + POOLING_TYPES + RELU_TYPES + + BN_TYPES + UPSAMPLE_TYPES) + + +def get_model_complexity_info(model, + input_res, + print_per_layer_stat=True, + as_strings=True, + input_constructor=None, + ost=sys.stdout): + assert type(input_res) is tuple + assert len(input_res) >= 2 + flops_model = add_flops_counting_methods(model) + flops_model.eval().start_flops_count() + if input_constructor: + input = input_constructor(input_res) + _ = flops_model(**input) + else: + batch = torch.ones(()).new_empty( + (1, *input_res), + dtype=next(flops_model.parameters()).dtype, + device=next(flops_model.parameters()).device) + flops_model(batch) + + if print_per_layer_stat: + print_model_with_flops(flops_model, ost=ost) + flops_count = flops_model.compute_average_flops_cost() + params_count = get_model_parameters_number(flops_model) + flops_model.stop_flops_count() + + if as_strings: + return flops_to_string(flops_count), params_to_string(params_count) + + return flops_count, params_count + + +def flops_to_string(flops, units='GMac', precision=2): + if units is None: + if flops // 10**9 > 0: + return str(round(flops / 10.**9, precision)) + ' GMac' + elif flops // 10**6 > 0: + return str(round(flops / 10.**6, precision)) + ' MMac' + elif flops // 10**3 > 0: + return str(round(flops / 10.**3, precision)) + ' KMac' + else: + return str(flops) + ' Mac' + else: + if units == 'GMac': + return str(round(flops / 10.**9, precision)) + ' ' + units + elif units == 'MMac': + return str(round(flops / 10.**6, precision)) + ' ' + units + elif units == 'KMac': + return str(round(flops / 10.**3, precision)) + ' ' + units + else: + return str(flops) + ' Mac' + + +def params_to_string(params_num): + if params_num // 10**6 > 0: + return str(round(params_num / 10**6, 2)) + ' M' + elif params_num // 10**3: + return str(round(params_num / 10**3, 2)) + ' k' + else: + return str(params_num) + + +def print_model_with_flops(model, units='GMac', precision=3, ost=sys.stdout): + total_flops = model.compute_average_flops_cost() + + def accumulate_flops(self): + if is_supported_instance(self): + return self.__flops__ / model.__batch_counter__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_flops() + return sum + + def flops_repr(self): + accumulated_flops_cost = self.accumulate_flops() + return ', '.join([ + flops_to_string( + accumulated_flops_cost, units=units, precision=precision), + '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), + self.original_extra_repr() + ]) + + def add_extra_repr(m): + m.accumulate_flops = accumulate_flops.__get__(m) + flops_extra_repr = flops_repr.__get__(m) + if m.extra_repr != flops_extra_repr: + m.original_extra_repr = m.extra_repr + m.extra_repr = flops_extra_repr + assert m.extra_repr != m.original_extra_repr + + def del_extra_repr(m): + if hasattr(m, 'original_extra_repr'): + m.extra_repr = m.original_extra_repr + del m.original_extra_repr + if hasattr(m, 'accumulate_flops'): + del m.accumulate_flops + + model.apply(add_extra_repr) + print(model, file=ost) + model.apply(del_extra_repr) + + +def get_model_parameters_number(model): + params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return params_num + + +def add_flops_counting_methods(net_main_module): + # adding additional methods to the existing module object, + # this is done this way so that each function has access to self object + net_main_module.start_flops_count = start_flops_count.__get__( + net_main_module) + net_main_module.stop_flops_count = stop_flops_count.__get__( + net_main_module) + net_main_module.reset_flops_count = reset_flops_count.__get__( + net_main_module) + net_main_module.compute_average_flops_cost = \ + compute_average_flops_cost.__get__(net_main_module) + + net_main_module.reset_flops_count() + + # Adding variables necessary for masked flops computation + net_main_module.apply(add_flops_mask_variable_or_reset) + + return net_main_module + + +def compute_average_flops_cost(self): + """ + A method that will be available after add_flops_counting_methods() is + called on a desired net object. + Returns current mean flops consumption per image. + """ + + batches_count = self.__batch_counter__ + flops_sum = 0 + for module in self.modules(): + if is_supported_instance(module): + flops_sum += module.__flops__ + + return flops_sum / batches_count + + +def start_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is + called on a desired net object. + Activates the computation of mean flops consumption per image. + Call it before you run the network. + """ + add_batch_counter_hook_function(self) + self.apply(add_flops_counter_hook_function) + + +def stop_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is + called on a desired net object. + Stops computing the mean flops consumption per image. + Call whenever you want to pause the computation. + """ + remove_batch_counter_hook_function(self) + self.apply(remove_flops_counter_hook_function) + + +def reset_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is + called on a desired net object. + Resets statistics computed so far. + """ + add_batch_counter_variables_or_reset(self) + self.apply(add_flops_counter_variable_or_reset) + + +def add_flops_mask(module, mask): + + def add_flops_mask_func(module): + if isinstance(module, torch.nn.Conv2d): + module.__mask__ = mask + + module.apply(add_flops_mask_func) + + +def remove_flops_mask(module): + module.apply(add_flops_mask_variable_or_reset) + + +def is_supported_instance(module): + if isinstance(module, SUPPORTED_TYPES): + return True + else: + return False + + +def empty_flops_counter_hook(module, input, output): + module.__flops__ += 0 + + +def upsample_flops_counter_hook(module, input, output): + output_size = output[0] + batch_size = output_size.shape[0] + output_elements_count = batch_size + for val in output_size.shape[1:]: + output_elements_count *= val + module.__flops__ += int(output_elements_count) + + +def relu_flops_counter_hook(module, input, output): + active_elements_count = output.numel() + module.__flops__ += int(active_elements_count) + + +def linear_flops_counter_hook(module, input, output): + input = input[0] + batch_size = input.shape[0] + module.__flops__ += int(batch_size * input.shape[1] * output.shape[1]) + + +def pool_flops_counter_hook(module, input, output): + input = input[0] + module.__flops__ += int(np.prod(input.shape)) + + +def bn_flops_counter_hook(module, input, output): + module.affine + input = input[0] + + batch_flops = np.prod(input.shape) + if module.affine: + batch_flops *= 2 + module.__flops__ += int(batch_flops) + + +def deconv_flops_counter_hook(conv_module, input, output): + # Can have multiple inputs, getting the first one + input = input[0] + + batch_size = input.shape[0] + input_height, input_width = input.shape[2:] + + kernel_height, kernel_width = conv_module.kernel_size + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = ( + kernel_height * kernel_width * in_channels * filters_per_channel) + + active_elements_count = batch_size * input_height * input_width + overall_conv_flops = conv_per_position_flops * active_elements_count + bias_flops = 0 + if conv_module.bias is not None: + output_height, output_width = output.shape[2:] + bias_flops = out_channels * batch_size * output_height * output_height + overall_flops = overall_conv_flops + bias_flops + + conv_module.__flops__ += int(overall_flops) + + +def conv_flops_counter_hook(conv_module, input, output): + # Can have multiple inputs, getting the first one + input = input[0] + + batch_size = input.shape[0] + output_dims = list(output.shape[2:]) + + kernel_dims = list(conv_module.kernel_size) + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = np.prod( + kernel_dims) * in_channels * filters_per_channel + + active_elements_count = batch_size * np.prod(output_dims) + + if conv_module.__mask__ is not None: + # (b, 1, h, w) + output_height, output_width = output.shape[2:] + flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, + output_width) + active_elements_count = flops_mask.sum() + + overall_conv_flops = conv_per_position_flops * active_elements_count + + bias_flops = 0 + + if conv_module.bias is not None: + + bias_flops = out_channels * active_elements_count + + overall_flops = overall_conv_flops + bias_flops + + conv_module.__flops__ += int(overall_flops) + + +def batch_counter_hook(module, input, output): + batch_size = 1 + if len(input) > 0: + # Can have multiple inputs, getting the first one + input = input[0] + batch_size = len(input) + else: + print('Warning! No positional inputs found for a module, ' + 'assuming batch size is 1.') + module.__batch_counter__ += batch_size + + +def add_batch_counter_variables_or_reset(module): + + module.__batch_counter__ = 0 + + +def add_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + return + + handle = module.register_forward_hook(batch_counter_hook) + module.__batch_counter_handle__ = handle + + +def remove_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + module.__batch_counter_handle__.remove() + del module.__batch_counter_handle__ + + +def add_flops_counter_variable_or_reset(module): + if is_supported_instance(module): + module.__flops__ = 0 + + +def add_flops_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + return + + if isinstance(module, CONV_TYPES): + handle = module.register_forward_hook(conv_flops_counter_hook) + elif isinstance(module, RELU_TYPES): + handle = module.register_forward_hook(relu_flops_counter_hook) + elif isinstance(module, LINEAR_TYPES): + handle = module.register_forward_hook(linear_flops_counter_hook) + elif isinstance(module, POOLING_TYPES): + handle = module.register_forward_hook(pool_flops_counter_hook) + elif isinstance(module, BN_TYPES): + handle = module.register_forward_hook(bn_flops_counter_hook) + elif isinstance(module, UPSAMPLE_TYPES): + handle = module.register_forward_hook(upsample_flops_counter_hook) + elif isinstance(module, DECONV_TYPES): + handle = module.register_forward_hook(deconv_flops_counter_hook) + else: + handle = module.register_forward_hook(empty_flops_counter_hook) + module.__flops_handle__ = handle + + +def remove_flops_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + module.__flops_handle__.remove() + del module.__flops_handle__ + + +# --- Masked flops counting +# Also being run in the initialization +def add_flops_mask_variable_or_reset(module): + if is_supported_instance(module): + module.__mask__ = None diff --git a/tools/get_flops.py b/tools/get_flops.py new file mode 100644 index 0000000000000000000000000000000000000000..e64bac6dc5ea2a0011991c6352e716971569a17d --- /dev/null +++ b/tools/get_flops.py @@ -0,0 +1,52 @@ +import argparse + +from mmcv import Config + +from mmdet.models import build_detector +from mmdet.utils import get_model_complexity_info + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a detector') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[1280, 800], + help='input image size') + args = parser.parse_args() + return args + + +def main(): + + args = parse_args() + + if len(args.shape) == 1: + input_shape = (3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = (3, ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + cfg = Config.fromfile(args.config) + model = build_detector( + cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda() + model.eval() + + if hasattr(model, 'forward_dummy'): + model.forward = model.forward_dummy + else: + raise NotImplementedError( + 'FLOPs counter is currently not currently supported with {}'. + format(model.__class__.__name__)) + + flops, params = get_model_complexity_info(model, input_shape) + split_line = '=' * 30 + print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( + split_line, input_shape, flops, params)) + + +if __name__ == '__main__': + main()