diff --git a/.isort.cfg b/.isort.cfg index 2186a18b54db86232032e338498a2d723b2cc3ab..e790e3ee0ce0b21dda62f56fc5072b99ff05ab7c 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -3,6 +3,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmdet -known_third_party = Cython,albumentations,cv2,imagecorruptions,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision +known_third_party = Cython,albumentations,asynctest,cv2,imagecorruptions,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/mmdet/utils/flops_counter.py b/mmdet/utils/flops_counter.py index 5d9cdfce82b0e63c55eb08419a01c924127e9e2e..df2163fd7107e17489eea56204a713780ef59376 100644 --- a/mmdet/utils/flops_counter.py +++ b/mmdet/utils/flops_counter.py @@ -33,19 +33,6 @@ from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd) -CONV_TYPES = (_ConvNd, ) -DECONV_TYPES = (_ConvTransposeMixin, ) -LINEAR_TYPES = (nn.Linear, ) -POOLING_TYPES = (_AvgPoolNd, _MaxPoolNd, _AdaptiveAvgPoolNd, - _AdaptiveMaxPoolNd) -RELU_TYPES = (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6) -BN_TYPES = (_BatchNorm, ) -UPSAMPLE_TYPES = (nn.Upsample, ) - -SUPPORTED_TYPES = ( - CONV_TYPES + DECONV_TYPES + LINEAR_TYPES + POOLING_TYPES + RELU_TYPES + - BN_TYPES + UPSAMPLE_TYPES) - def get_model_complexity_info(model, input_res, @@ -249,10 +236,10 @@ def remove_flops_mask(module): def is_supported_instance(module): - if isinstance(module, SUPPORTED_TYPES): - return True - else: - return False + for mod in hook_mapping: + if issubclass(type(module), mod): + return True + return False def empty_flops_counter_hook(module, input, output): @@ -285,7 +272,6 @@ def pool_flops_counter_hook(module, input, output): def bn_flops_counter_hook(module, input, output): - module.affine input = input[0] batch_flops = np.prod(input.shape) @@ -294,6 +280,17 @@ def bn_flops_counter_hook(module, input, output): module.__flops__ += int(batch_flops) +def gn_flops_counter_hook(module, input, output): + elems = np.prod(input[0].shape) + # there is no precise FLOPs estimation of computing mean and variance, + # and we just set it 2 * elems: half muladds for computing + # means and half for computing vars + batch_flops = 3 * elems + if module.affine: + batch_flops += elems + module.__flops__ += int(batch_flops) + + def deconv_flops_counter_hook(conv_module, input, output): # Can have multiple inputs, getting the first one input = input[0] @@ -359,6 +356,32 @@ def conv_flops_counter_hook(conv_module, input, output): conv_module.__flops__ += int(overall_flops) +hook_mapping = { + # conv + _ConvNd: conv_flops_counter_hook, + # deconv + _ConvTransposeMixin: deconv_flops_counter_hook, + # fc + nn.Linear: linear_flops_counter_hook, + # pooling + _AvgPoolNd: pool_flops_counter_hook, + _MaxPoolNd: pool_flops_counter_hook, + _AdaptiveAvgPoolNd: pool_flops_counter_hook, + _AdaptiveMaxPoolNd: pool_flops_counter_hook, + # activation + nn.ReLU: relu_flops_counter_hook, + nn.PReLU: relu_flops_counter_hook, + nn.ELU: relu_flops_counter_hook, + nn.LeakyReLU: relu_flops_counter_hook, + nn.ReLU6: relu_flops_counter_hook, + # normalization + _BatchNorm: bn_flops_counter_hook, + nn.GroupNorm: gn_flops_counter_hook, + # upsample + nn.Upsample: upsample_flops_counter_hook, +} + + def batch_counter_hook(module, input, output): batch_size = 1 if len(input) > 0: @@ -372,7 +395,6 @@ def batch_counter_hook(module, input, output): def add_batch_counter_variables_or_reset(module): - module.__batch_counter__ = 0 @@ -400,22 +422,11 @@ def add_flops_counter_hook_function(module): if hasattr(module, '__flops_handle__'): return - if isinstance(module, CONV_TYPES): - handle = module.register_forward_hook(conv_flops_counter_hook) - elif isinstance(module, RELU_TYPES): - handle = module.register_forward_hook(relu_flops_counter_hook) - elif isinstance(module, LINEAR_TYPES): - handle = module.register_forward_hook(linear_flops_counter_hook) - elif isinstance(module, POOLING_TYPES): - handle = module.register_forward_hook(pool_flops_counter_hook) - elif isinstance(module, BN_TYPES): - handle = module.register_forward_hook(bn_flops_counter_hook) - elif isinstance(module, UPSAMPLE_TYPES): - handle = module.register_forward_hook(upsample_flops_counter_hook) - elif isinstance(module, DECONV_TYPES): - handle = module.register_forward_hook(deconv_flops_counter_hook) - else: - handle = module.register_forward_hook(empty_flops_counter_hook) + for mod_type, counter_hook in hook_mapping.items(): + if issubclass(type(module), mod_type): + handle = module.register_forward_hook(counter_hook) + break + module.__flops_handle__ = handle diff --git a/tools/get_flops.py b/tools/get_flops.py index e64bac6dc5ea2a0011991c6352e716971569a17d..6c9cb23400c0705289d5ca3bb3430bce6c884a2e 100644 --- a/tools/get_flops.py +++ b/tools/get_flops.py @@ -46,6 +46,9 @@ def main(): split_line = '=' * 30 print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( split_line, input_shape, flops, params)) + print('!!!Please be cautious if you use the results in papers. ' + 'You may need to check if all ops are supported and verify that the ' + 'flops computation is correct.') if __name__ == '__main__':