From a06a1a6e187a5ecb0c774ff336030c2b58452ab2 Mon Sep 17 00:00:00 2001
From: Jen Iofinova
Date: Wed, 20 Mar 2024 19:44:52 +0100
Subject: [PATCH] FLOPs (floating-point operations) module for ConvNets (#1808)

* sort of intermediate point of implementing FLOPs, but might lose the machine
* more work on flops - update computations to match ac/dc
* implement analyzer in train script
* fix tests
* undo accidental change
* disable flops analyzer before running eval
* don't restart hooks
* make style changes
* more make quality
* style fix
* add correction for nonsquare kernels
* fix test and formatting

---------

Co-authored-by: Alexandre Marques
---
 src/sparseml/pytorch/optim/analyzer_module.py | 240 ++++++++----------
 src/sparseml/pytorch/torchvision/train.py     |  26 ++
 .../pytorch/optim/test_analyzer_module.py     | 143 ++++++++++-
 3 files changed, 267 insertions(+), 142 deletions(-)

diff --git a/src/sparseml/pytorch/optim/analyzer_module.py b/src/sparseml/pytorch/optim/analyzer_module.py
index c6f07e85864..b26f240355f 100644
--- a/src/sparseml/pytorch/optim/analyzer_module.py
+++ b/src/sparseml/pytorch/optim/analyzer_module.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 """
-Code related to monitoring, analyzing, and reporting info for Modules in PyTorch.
-Records things like FLOPS, input and output shapes, kernel shapes, etc.
+Code for overall sparsity and forward FLOPs (floating-point operations)
+estimation for neural networks.
 """
 
+import numbers
 from typing import List, Tuple, Union
 
 import numpy
@@ -66,9 +67,20 @@ class ModuleAnalyzer(object):
     :param module: the module to analyze
     :param enabled: True to enable the hooks for analyzing and actively track,
         False to disable and not track
+    :param ignore_zero: whether zero-valued weights should be excluded from the
+        FLOPs count (standard when estimating 'theoretical' FLOPs for sparse
+        networks)
+    :param multiply_adds: whether total FLOPs should include the cost of summing
+        the multiplications together
     """
 
-    def __init__(self, module: Module, enabled: bool = False):
+    def __init__(
+        self,
+        module: Module,
+        enabled: bool = False,
+        ignore_zero=True,
+        multiply_adds=True,
+    ):
         super(ModuleAnalyzer, self).__init__()
         self._module = module
         self._hooks = None  # type: List[RemovableHandle]
@@ -76,6 +88,8 @@ def __init__(self, module: Module, enabled: bool = False):
         self._enabled = False
         self._call_count = -1
         self.enabled = enabled
+        self._ignore_zero = ignore_zero
+        self._multiply_adds = multiply_adds
 
     def __del__(self):
         self._delete_hooks()
@@ -135,7 +149,7 @@ def ks_layer_descs(self) -> List[AnalyzedLayerDesc]:
         """
         descs = []
 
-        for (name, _) in get_prunable_layers(self._module):
+        for name, _ in get_prunable_layers(self._module):
             desc = self.layer_desc(name)
 
             if desc is None:
@@ -212,10 +226,15 @@ def _forward_pre_hook(
     ):
         self._call_count += 1
 
+        if mod._analyzed_layer_desc is not None:
+            return
+
         mod._analyzed_layer_desc = AnalyzedLayerDesc(
             name=mod._analyzed_layer_name,
             type_=mod.__class__.__name__,
             execution_order=self._call_count,
+            flops=0,
+            total_flops=0,
         )
 
     def _init_forward_hook(
@@ -258,35 +277,35 @@ def _conv_hook(
     ):
         desc, inp, out = self._init_forward_hook(mod, inp, out)
 
-        params = (
-            {"weight": mod.weight}
-            if mod.bias is None
-            else {"weight": mod.weight, "bias": mod.bias}
+        desc.params = mod.weight.data.numel() + (
+            mod.bias.data.numel() if mod.bias is not None else 0
         )
-        prunable_params = {"weight": mod.weight}
+        desc.prunable_params = mod.weight.data.numel()
+        desc.zeroed_params = desc.prunable_params - mod.weight.data.count_nonzero()
-
desc.params = sum([val.numel() for val in params.values()]) - desc.prunable_params = sum([val.numel() for val in prunable_params.values()]) - desc.zeroed_params = sum( - [(val == 0).sum().item() for val in prunable_params.values()] + batch_size, input_channels, input_height, input_width = inp[0].size() + _, output_channels, output_height, output_width = out[0].size() + + bias_ops = 1 if mod.bias is not None else 0 + + num_weight_params = ( + (mod.weight.data != 0.0).float().sum() + if self._ignore_zero + else mod.weight.data.nelement() ) - desc.params_dims = { - key: tuple(s for s in val.shape) for key, val in params.items() - } - desc.prunable_params_dims = { - key: tuple(s for s in val.shape) for key, val in prunable_params.items() - } - desc.stride = mod.stride - - mult_per_out_pix = float(numpy.prod(mod.kernel_size)) * mod.in_channels - add_per_out_pix = 1 if mod.bias is not None else 0 - out_pix = float(numpy.prod(out[0].shape[1:])) - - # total flops counts the cost of summing the - # multiplications together as well - # most implementations and papers do not include this cost - desc.flops = (mult_per_out_pix + add_per_out_pix) * out_pix - desc.total_flops = (mult_per_out_pix * 2 + add_per_out_pix) * out_pix + + flops = ( + ( + num_weight_params * (2 if self._multiply_adds else 1) + + bias_ops * output_channels + ) + * output_height + * output_width + * batch_size + ) + + desc.flops = flops + desc.total_flops += desc.flops def _linear_hook( self, @@ -296,34 +315,24 @@ def _linear_hook( ): desc, inp, out = self._init_forward_hook(mod, inp, out) - params = ( - {"weight": mod.weight} - if mod.bias is None - else {"weight": mod.weight, "bias": mod.bias} + desc.params = mod.weight.data.numel() + ( + mod.bias.data.numel() if mod.bias is not None else 0 ) - prunable_params = {"weight": mod.weight} + desc.prunable_params = mod.weight.data.numel() + desc.zeroed_params = desc.prunable_params - mod.weight.data.count_nonzero() - desc.params = sum([val.numel() for val in params.values()]) - desc.prunable_params = sum([val.numel() for val in prunable_params.values()]) - desc.zeroed_params = sum( - [(val == 0).sum().item() for val in prunable_params.values()] + batch_size = inp[0].size(0) if inp[0].dim() == 2 else 1 + + num_weight_params = ( + (mod.weight.data != 0.0).float().sum() + if self._ignore_zero + else mod.weight.data.nelement() ) - desc.params_dims = { - key: tuple(s for s in val.shape) for key, val in params.items() - } - desc.prunable_params_dims = { - key: tuple(s for s in val.shape) for key, val in prunable_params.items() - } - - mult_per_out_pix = mod.in_features - add_per_out_pix = 1 if mod.bias is not None else 0 - out_pix = float(numpy.prod(out[0].shape[1:])) - - # total flops counts the cost of summing the - # multiplications together as well - # most implementations and papers do not include this cost - desc.flops = (mult_per_out_pix + add_per_out_pix) * out_pix - desc.total_flops = (mult_per_out_pix * 2 + add_per_out_pix) * out_pix + weight_ops = num_weight_params * (2 if self._multiply_adds else 1) + bias_ops = mod.bias.nelement() if mod.bias is not None else 0 + + desc.flops = batch_size * (weight_ops + bias_ops) + desc.total_flops += desc.flops def _bn_hook( self, @@ -333,28 +342,14 @@ def _bn_hook( ): desc, inp, out = self._init_forward_hook(mod, inp, out) - params = ( - {"weight": mod.weight} - if mod.bias is None - else {"weight": mod.weight, "bias": mod.bias} + desc.params = mod.weight.data.numel() + ( + mod.bias.data.numel() if mod.bias is not None else 0 ) - 
prunable_params = {} + desc.prunable_params = mod.weight.data.numel() + desc.zeroed_params = desc.prunable_params - mod.weight.data.count_nonzero() - desc.params = sum([val.numel() for val in params.values()]) - desc.prunable_params = sum([val.numel() for val in prunable_params.values()]) - desc.zeroed_params = sum( - [(val == 0).sum().item() for val in prunable_params.values()] - ) - desc.params_dims = { - key: tuple(s for s in val.shape) for key, val in params.items() - } - desc.prunable_params_dims = { - key: tuple(s for s in val.shape) for key, val in prunable_params.items() - } - - # 4 elementwise operations on the output space, just need to add all of them up - desc.flops = 4 * float(numpy.prod(out[0].shape[1:])) - desc.total_flops = desc.flops + desc.flops = 2 * float(inp[0].nelement()) + desc.total_flops += desc.flops def _pool_hook( self, @@ -365,26 +360,21 @@ def _pool_hook( desc, inp, out = self._init_forward_hook(mod, inp, out) params = {key: val for key, val in mod.named_parameters()} - prunable_params = {} - desc.params = sum([val.numel() for val in params.values()]) - desc.prunable_params = sum([val.numel() for val in prunable_params.values()]) - desc.zeroed_params = sum( - [(val == 0).sum().item() for val in prunable_params.values()] - ) - desc.params_dims = { - key: tuple(s for s in val.shape) for key, val in params.items() - } - desc.prunable_params_dims = { - key: tuple(s for s in val.shape) for key, val in prunable_params.items() - } - desc.stride = mod.stride + desc.prunable_params = 0 + desc.zeroed_params = 0 + + batch_size, input_channels, input_height, input_width = inp[0].size() + batch_size, output_channels, output_height, output_width = out[0].size() - flops_per_out_pix = float(numpy.prod(mod.kernel_size) + 1) - out_pix = float(numpy.prod(out[0].shape[1:])) + if isinstance(mod.kernel_size, numbers.Number) or mod.kernel_size.dim() == 1: + kernel_ops = mod.kernel_size * mod.kernel_size + else: + kernel_ops = numpy.prod(mod.kernel_size) + flops = kernel_ops * output_channels * output_height * output_width * batch_size - desc.flops = flops_per_out_pix * out_pix - desc.total_flops = desc.flops + desc.flops = flops + desc.total_flops += desc.flops def _adaptive_pool_hook( self, @@ -395,20 +385,10 @@ def _adaptive_pool_hook( desc, inp, out = self._init_forward_hook(mod, inp, out) params = {key: val for key, val in mod.named_parameters()} - prunable_params = {} desc.params = sum([val.numel() for val in params.values()]) - desc.prunable_params = sum([val.numel() for val in prunable_params.values()]) - desc.zeroed_params = sum( - [(val == 0).sum().item() for val in prunable_params.values()] - ) - desc.params_dims = { - key: tuple(s for s in val.shape) for key, val in params.items() - } - desc.prunable_params_dims = { - key: tuple(s for s in val.shape) for key, val in prunable_params.items() - } - desc.stride = 1 + desc.prunable_params = 0 + desc.zeroed_params = 0 stride = tuple( inp[0].shape[i] // out[0].shape[i] for i in range(2, len(inp[0].shape)) @@ -417,11 +397,14 @@ def _adaptive_pool_hook( inp[0].shape[i] - (out[0].shape[i] - 1) * stride[i - 2] for i in range(2, len(inp[0].shape)) ) - flops_per_out_pix = float(numpy.prod(kernel_size)) - out_pix = float(numpy.prod(out[0].shape[1:])) + kernel_ops = numpy.prod(kernel_size) + + batch_size, output_channels, output_height, output_width = out[0].size() - desc.flops = flops_per_out_pix * out_pix - desc.total_flops = desc.flops + flops = kernel_ops * output_channels * output_height * output_width * batch_size + + desc.flops = 
flops
+        desc.total_flops += desc.flops
 
     def _activation_hook(
         self,
         mod: Module,
         inp: Union[Tuple[Tensor, ...], Tensor],
         out: Union[Tuple[Tensor, ...], Tensor],
     ):
         desc, inp, out = self._init_forward_hook(mod, inp, out)
 
         params = {key: val for key, val in mod.named_parameters()}
-        prunable_params = {}
 
         desc.params = sum([val.numel() for val in params.values()])
-        desc.prunable_params = sum([val.numel() for val in prunable_params.values()])
-        desc.zeroed_params = sum(
-            [(val == 0).sum().item() for val in prunable_params.values()]
-        )
-        desc.params_dims = {
-            key: tuple(s for s in val.shape) for key, val in params.items()
-        }
-        desc.prunable_params_dims = {
-            key: tuple(s for s in val.shape) for key, val in prunable_params.items()
-        }
+        desc.prunable_params = 0
+        desc.zeroed_params = 0
 
         # making assumption that flops spent is one per element
         # (so swish is counted the same activation ReLU)
-        desc.flops = float(numpy.prod(out[0].shape[1:]))
-        desc.total_flops = desc.flops
+        # FIXME (can't really be fixed). Some standard architectures,
+        # such as a standard ResNet, use the same activation (ReLU) object
+        # for all of the places that it appears in the net, which works
+        # fine because it's stateless. But it makes it hard to count per-
+        # batch forward FLOPs correctly, since a single forward pass
+        # through the network is actually multiple passes through the
+        # activation. So the per-batch FLOPs are undercounted (slightly,
+        # since activations are very few FLOPs in general), but total
+        # (cumulative) FLOPs are counted correctly.
+        desc.flops = float(inp[0].nelement())
+        desc.total_flops += desc.flops
 
     def _softmax_hook(
         self,
         mod: Module,
         inp: Union[Tuple[Tensor, ...], Tensor],
         out: Union[Tuple[Tensor, ...], Tensor],
     ):
         desc, inp, out = self._init_forward_hook(mod, inp, out)
 
         params = {key: val for key, val in mod.named_parameters()}
-        prunable_params = {}
 
         desc.params = sum([val.numel() for val in params.values()])
-        desc.prunable_params = sum([val.numel() for val in prunable_params.values()])
-        desc.zeroed_params = sum(
-            [(val == 0).sum().item() for val in prunable_params.values()]
-        )
-        desc.params_dims = {
-            key: tuple(s for s in val.shape) for key, val in params.items()
-        }
-        desc.prunable_params_dims = {
-            key: tuple(s for s in val.shape) for key, val in prunable_params.items()
-        }
+        desc.prunable_params = 0
+        desc.zeroed_params = 0
 
         flops_per_channel = (
             2 if len(out[0].shape) < 3 else float(numpy.prod(out[0].shape[2:]))
         )
         desc.flops = flops_per_channel * out[0].shape[1]
-        desc.total_flops = desc.flops
+        desc.total_flops += desc.flops
 
     @staticmethod
     def _mod_desc(mod: Module) -> AnalyzedLayerDesc:
diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py
index 53c17dee795..70ac270299e 100644
--- a/src/sparseml/pytorch/torchvision/train.py
+++ b/src/sparseml/pytorch/torchvision/train.py
@@ -48,6 +48,7 @@
 from sparseml.optim.helpers import load_recipe_yaml_str
 from sparseml.pytorch.models.registry import ModelRegistry
 from sparseml.pytorch.optim import ScheduledModifierManager
+from sparseml.pytorch.optim.analyzer_module import ModuleAnalyzer
 from sparseml.pytorch.torchvision import presets, transforms, utils
 from sparseml.pytorch.torchvision.sampler import RASampler
 from sparseml.pytorch.utils.helpers import (
@@ -93,6 +94,7 @@ def train_one_epoch(
     manager=None,
     model_ema=None,
     scaler=None,
+    flops_analyzer=None,
 ) -> utils.MetricLogger:
     accum_steps = args.gradient_accum_steps
 
@@ -105,6 +107,13 @@ def train_one_epoch(
     metric_logger.add_meter("loss", utils.SmoothedValue(window_size=accum_steps))
     metric_logger.add_meter("acc1",
utils.SmoothedValue(window_size=accum_steps)) metric_logger.add_meter("acc5", utils.SmoothedValue(window_size=accum_steps)) + if flops_analyzer is not None: + metric_logger.add_meter( + "flops_per_epoch", utils.SmoothedValue(window_size=1, fmt="{value}") + ) + metric_logger.add_meter( + "total_flops", utils.SmoothedValue(window_size=1, fmt="{value}") + ) steps_accumulated = 0 num_optim_steps = 0 @@ -178,6 +187,10 @@ def train_one_epoch( metric_logger.meters["imgs_per_sec"].update( batch_size / (time.time() - start_time) ) + if flops_analyzer is not None: + layer_sparsity = ModuleAnalyzer._mod_desc(model) + metric_logger.meters["total_flops"].update(layer_sparsity.total_flops) + metric_logger.meters["flops_per_epoch"].update(layer_sparsity.flops) if args.eval_steps is not None and num_optim_steps % args.eval_steps == 0: eval_metrics = evaluate(model, criterion, data_loader_test, device) @@ -698,6 +711,9 @@ def data_loader_builder(**kwargs): start_time = time.time() max_epochs = manager.max_epochs if manager is not None else args.epochs + flops_analyzer = None + if args.track_flops: + flops_analyzer = ModuleAnalyzer(model, enabled=True) for epoch in range(args.start_epoch, max_epochs): if args.distributed: train_sampler.set_epoch(epoch) @@ -719,6 +735,7 @@ def data_loader_builder(**kwargs): manager=manager, model_ema=model_ema, scaler=scaler, + flops_analyzer=flops_analyzer, ) log_metrics("Train", train_metrics, epoch, steps_per_epoch) @@ -1281,6 +1298,15 @@ def new_func(*args, **kwargs): "Note: Will use ImageNet values if not specified." ), ) +@click.option( + "--track-flops", + is_flag=True, + default=False, + help=( + "If true, estimate FLOPs (floating point operations) of forward " + "passes during training." + ), +) @click.pass_context def cli(ctx, **kwargs): """ diff --git a/tests/sparseml/pytorch/optim/test_analyzer_module.py b/tests/sparseml/pytorch/optim/test_analyzer_module.py index 75539dc8e6d..d52f3680ba0 100644 --- a/tests/sparseml/pytorch/optim/test_analyzer_module.py +++ b/tests/sparseml/pytorch/optim/test_analyzer_module.py @@ -17,7 +17,8 @@ import pytest import torch -from torch.nn import Module +from torch.nn import Linear, Module +from torch.nn.modules.conv import _ConvNd from torchvision.models import resnet50 from sparseml.pytorch.optim import ModuleAnalyzer @@ -38,7 +39,7 @@ 2800, 2688, 0, - 2912, + 5600, 5600, ), ( @@ -48,7 +49,7 @@ 544, 512, 4, - 544, + 1056, 1056, ), ( @@ -68,7 +69,7 @@ 5418, 5360, 0, - 321780, + 632564, 632564, ), ( @@ -78,7 +79,7 @@ 4640, 4608, 4, - 227360, + 453152, 453152, ), ( @@ -86,10 +87,13 @@ (3, 224, 224), None, 25557032, - 25502912, + 25529472, 0, - 4140866536, - 8230050792, + 8208826344, + # RN50 has a single ReLU used multiple + # times, so total is not equal to single-pass + # FLOPs, even for a single step. + 8212112872, ), ], ) @@ -103,7 +107,127 @@ def test_analyzer( flops: int, total_flops: int, ): - analyzer = ModuleAnalyzer(model, enabled=True) + # Make sure we don't accidentally have 0 weights in a + # 'dense' model. In real life it's fine, but here it would + # throw off the expected result. 
+ def init_weights(m): + if isinstance(m, Linear) or isinstance(m, _ConvNd): + m.weight.data.fill_(0.01) + if m.bias is not None: + m.bias.data.fill_(0.01) + + model.apply(init_weights) + analyzer = ModuleAnalyzer(model, enabled=True, ignore_zero=True) + tens = torch.randn(1, *input_shape) + out = model(tens) + analyzer.enabled = False + out = model(tens) + assert len(out) + + desc = analyzer.layer_desc(name) + assert desc.params == params + assert desc.prunable_params == prunable_params + assert desc.zeroed_params == 0 + assert desc.execution_order == execution_order + assert desc.flops == flops + assert desc.total_flops == total_flops + + +@pytest.mark.parametrize( + "model,input_shape,name,params,prunable_params,zeroed_params,execution_order,flops,total_flops", # noqa: E501 + [ + ( + MLPNet(), + MLPNet.layer_descs()[0].input_size, + None, + 2800, + 2688, + 56, + 0, + 5488, + 5488, + ), + ( + MLPNet(), + MLPNet.layer_descs()[0].input_size, + MLPNet.layer_descs()[2].name, + 544, + 512, + 16, + 4, + 1024, + 1024, + ), + ( + MLPNet(), + MLPNet.layer_descs()[0].input_size, + MLPNet.layer_descs()[3].name, + 0, + 0, + 0, + 5, + 32, + 32, + ), + ( + ConvNet(), + ConvNet.layer_descs()[0].input_size, + None, + 5418, + 5360, + 203, + 0, + 607804, + 607804, + ), + ( + ConvNet(), + ConvNet.layer_descs()[0].input_size, + ConvNet.layer_descs()[2].name, + 4640, + 4608, + 144, + 4, + 439040, + 439040, + ), + ( + resnet50(), + (3, 224, 224), + None, + 25557032, + 25529472, + 54931, + 0, + 8165194368, + # RN50 has a single ReLU used multiple + # times, so total is not equal to single-pass + # FLOPs, even for a single step. + 8168480896, + ), + ], +) +def test_analyzer_sparse( + model: Module, + input_shape: Tuple[int], + name: str, + params: int, + prunable_params: int, + zeroed_params: int, + execution_order: int, + flops: int, + total_flops: int, +): + def init_weights(m): + if isinstance(m, Linear) or isinstance(m, _ConvNd): + m.weight.data.fill_(0.01) + # Set some weights to 0 + m.weight.data[0] = 0 + if m.bias is not None: + m.bias.data.fill_(0.01) + + model.apply(init_weights) + analyzer = ModuleAnalyzer(model, enabled=True, ignore_zero=True) tens = torch.randn(1, *input_shape) out = model(tens) analyzer.enabled = False @@ -113,6 +237,7 @@ def test_analyzer( desc = analyzer.layer_desc(name) assert desc.params == params assert desc.prunable_params == prunable_params + assert desc.zeroed_params == zeroed_params assert desc.execution_order == execution_order assert desc.flops == flops assert desc.total_flops == total_flops
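
Note (not part of the patch): a minimal usage sketch of the analyzer added above, based only on the API shown in this diff; the ResNet-50 model and 224x224 input are illustrative placeholders.

    import torch
    from torchvision.models import resnet50

    from sparseml.pytorch.optim import ModuleAnalyzer

    model = resnet50()
    analyzer = ModuleAnalyzer(
        model, enabled=True, ignore_zero=True, multiply_adds=True
    )

    # one forward pass populates the per-layer descriptors via the hooks
    model(torch.randn(1, 3, 224, 224))
    # disable the analyzer before any eval passes so they are not counted
    analyzer.enabled = False

    # passing None as the name returns the description for the whole model
    desc = analyzer.layer_desc(None)
    print(desc.flops, desc.total_flops)

In the torchvision train script, the same analyzer is enabled with the new --track-flops flag, which reports flops_per_epoch and total_flops through the metric logger.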