Commit

FLOPs (floating-point operations) module for ConvNets (#1808)
* sort of intermediate point of implementing FLOPs, but might lose the machine

* more work on flops - update computations to match ac/dc

* implement analyzer in train script

* fix tests

* undo accidental change

* disable flops analyzer before running eval

* don't restart hooks

* make style changes

* more make quality

* style fix

* add correction for nonsquare kernels

* fix test and formatting

---------

Co-authored-by: Alexandre Marques <[email protected]>
ohaijen and anmarques authored Mar 20, 2024
1 parent 5107108 commit a06a1a6
Showing 3 changed files with 267 additions and 142 deletions.
240 changes: 107 additions & 133 deletions src/sparseml/pytorch/optim/analyzer_module.py
@@ -13,10 +13,11 @@
# limitations under the License.

"""
Code related to monitoring, analyzing, and reporting info for Modules in PyTorch.
Records things like FLOPS, input and output shapes, kernel shapes, etc.
Code for overall sparsity and forward FLOPs (floating-point operations)
estimation for neural networks.
"""

import numbers
from typing import List, Tuple, Union

import numpy
@@ -66,16 +67,29 @@ class ModuleAnalyzer(object):
:param module: the module to analyze
:param enabled: True to enable the hooks for analyzing and actively track,
False to disable and not track
:param ignore_zero: whether zeros should be excluded from FLOPs (standard
when estimating 'theoretical' FLOPs in sparse networks)
:param multiply_adds: Whether total flops includes the cost of summing the
multiplications together
"""

def __init__(self, module: Module, enabled: bool = False):
def __init__(
self,
module: Module,
enabled: bool = False,
ignore_zero=True,
multiply_adds=True,
):
super(ModuleAnalyzer, self).__init__()
self._module = module
self._hooks = None # type: List[RemovableHandle]
self._forward_called = False
self._enabled = False
self._call_count = -1
self.enabled = enabled
self._ignore_zero = ignore_zero
self._multiply_adds = multiply_adds

def __del__(self):
self._delete_hooks()
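
As a rough usage sketch of the new interface (the model, input shape, and layer sizes below are placeholders, not part of this change): attach the analyzer, run a forward pass so the hooks fire, then read the per-layer descriptions of the prunable layers.

import torch
from torch import nn

from sparseml.pytorch.optim.analyzer_module import ModuleAnalyzer

# placeholder model; any torch Module works
model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),
)
analyzer = ModuleAnalyzer(model, enabled=True, ignore_zero=True, multiply_adds=True)

model(torch.randn(1, 3, 32, 32))  # the forward pass triggers the analyzer hooks

for desc in analyzer.ks_layer_descs():  # prunable (conv/linear) layers only
    print(desc.name, desc.flops, desc.total_flops)
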
@@ -135,7 +149,7 @@ def ks_layer_descs(self) -> List[AnalyzedLayerDesc]:
"""
descs = []

for (name, _) in get_prunable_layers(self._module):
for name, _ in get_prunable_layers(self._module):
desc = self.layer_desc(name)

if desc is None:
@@ -212,10 +226,15 @@ def _forward_pre_hook(
):
self._call_count += 1

if mod._analyzed_layer_desc is not None:
return

mod._analyzed_layer_desc = AnalyzedLayerDesc(
name=mod._analyzed_layer_name,
type_=mod.__class__.__name__,
execution_order=self._call_count,
flops=0,
total_flops=0,
)
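
Since the pre-hook now creates the description only once and the forward hooks below accumulate into it, `flops` ends up holding the cost of the most recent forward pass while `total_flops` grows with every pass; continuing the sketch above:

model(torch.randn(1, 3, 32, 32))     # a second pass through the same layers
desc = analyzer.layer_desc("0")      # first layer of the placeholder nn.Sequential
print(desc.flops, desc.total_flops)  # total_flops is now twice flops for this layer

analyzer.enabled = False  # detach the hooks before eval, as the commit message notes
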

def _init_forward_hook(
@@ -258,35 +277,35 @@ def _conv_hook(
):
desc, inp, out = self._init_forward_hook(mod, inp, out)

params = (
{"weight": mod.weight}
if mod.bias is None
else {"weight": mod.weight, "bias": mod.bias}
desc.params = mod.weight.data.numel() + (
mod.bias.data.numel() if mod.bias is not None else 0
)
prunable_params = {"weight": mod.weight}
desc.prunable_params = mod.weight.data.numel()
desc.zeroed_params = desc.prunable_params - mod.weight.data.count_nonzero()

desc.params = sum([val.numel() for val in params.values()])
desc.prunable_params = sum([val.numel() for val in prunable_params.values()])
desc.zeroed_params = sum(
[(val == 0).sum().item() for val in prunable_params.values()]
batch_size, input_channels, input_height, input_width = inp[0].size()
_, output_channels, output_height, output_width = out[0].size()

bias_ops = 1 if mod.bias is not None else 0

num_weight_params = (
(mod.weight.data != 0.0).float().sum()
if self._ignore_zero
else mod.weight.data.nelement()
)
desc.params_dims = {
key: tuple(s for s in val.shape) for key, val in params.items()
}
desc.prunable_params_dims = {
key: tuple(s for s in val.shape) for key, val in prunable_params.items()
}
desc.stride = mod.stride

mult_per_out_pix = float(numpy.prod(mod.kernel_size)) * mod.in_channels
add_per_out_pix = 1 if mod.bias is not None else 0
out_pix = float(numpy.prod(out[0].shape[1:]))

# total flops counts the cost of summing the
# multiplications together as well
# most implementations and papers do not include this cost
desc.flops = (mult_per_out_pix + add_per_out_pix) * out_pix
desc.total_flops = (mult_per_out_pix * 2 + add_per_out_pix) * out_pix

flops = (
(
num_weight_params * (2 if self._multiply_adds else 1)
+ bias_ops * output_channels
)
* output_height
* output_width
* batch_size
)

desc.flops = flops
desc.total_flops += desc.flops
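
As a worked instance of the formula above, with hypothetical shapes (a dense 3x3 convolution, 64 input and 128 output channels, bias, a 56x56 output map, batch size 1, multiply_adds=True):

num_weight_params = 3 * 3 * 64 * 128                     # 73,728 nonzero weights
flops = (num_weight_params * 2 + 1 * 128) * 56 * 56 * 1
# = (147,456 + 128) * 3,136 = 462,823,424, roughly 0.46 GFLOPs per forward pass;
# with ignore_zero=True, pruning 90% of the weights cuts the dominant term to a tenth
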

def _linear_hook(
self,
@@ -296,34 +315,24 @@ def _linear_hook(
):
desc, inp, out = self._init_forward_hook(mod, inp, out)

params = (
{"weight": mod.weight}
if mod.bias is None
else {"weight": mod.weight, "bias": mod.bias}
desc.params = mod.weight.data.numel() + (
mod.bias.data.numel() if mod.bias is not None else 0
)
prunable_params = {"weight": mod.weight}
desc.prunable_params = mod.weight.data.numel()
desc.zeroed_params = desc.prunable_params - mod.weight.data.count_nonzero()

desc.params = sum([val.numel() for val in params.values()])
desc.prunable_params = sum([val.numel() for val in prunable_params.values()])
desc.zeroed_params = sum(
[(val == 0).sum().item() for val in prunable_params.values()]
batch_size = inp[0].size(0) if inp[0].dim() == 2 else 1

num_weight_params = (
(mod.weight.data != 0.0).float().sum()
if self._ignore_zero
else mod.weight.data.nelement()
)
desc.params_dims = {
key: tuple(s for s in val.shape) for key, val in params.items()
}
desc.prunable_params_dims = {
key: tuple(s for s in val.shape) for key, val in prunable_params.items()
}

mult_per_out_pix = mod.in_features
add_per_out_pix = 1 if mod.bias is not None else 0
out_pix = float(numpy.prod(out[0].shape[1:]))

# total flops counts the cost of summing the
# multiplications together as well
# most implementations and papers do not include this cost
desc.flops = (mult_per_out_pix + add_per_out_pix) * out_pix
desc.total_flops = (mult_per_out_pix * 2 + add_per_out_pix) * out_pix
weight_ops = num_weight_params * (2 if self._multiply_adds else 1)
bias_ops = mod.bias.nelement() if mod.bias is not None else 0

desc.flops = batch_size * (weight_ops + bias_ops)
desc.total_flops += desc.flops
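
The same accounting for a fully connected layer, again with hypothetical sizes (a dense 2048-to-1000 linear layer with bias, batch size 1, multiply_adds=True):

num_weight_params = 2048 * 1000             # 2,048,000 nonzero weights
flops = 1 * (num_weight_params * 2 + 1000)  # 4,097,000, about 4.1 MFLOPs per pass
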

def _bn_hook(
self,
@@ -333,28 +342,14 @@ def _bn_hook(
):
desc, inp, out = self._init_forward_hook(mod, inp, out)

params = (
{"weight": mod.weight}
if mod.bias is None
else {"weight": mod.weight, "bias": mod.bias}
desc.params = mod.weight.data.numel() + (
mod.bias.data.numel() if mod.bias is not None else 0
)
prunable_params = {}
desc.prunable_params = mod.weight.data.numel()
desc.zeroed_params = desc.prunable_params - mod.weight.data.count_nonzero()

desc.params = sum([val.numel() for val in params.values()])
desc.prunable_params = sum([val.numel() for val in prunable_params.values()])
desc.zeroed_params = sum(
[(val == 0).sum().item() for val in prunable_params.values()]
)
desc.params_dims = {
key: tuple(s for s in val.shape) for key, val in params.items()
}
desc.prunable_params_dims = {
key: tuple(s for s in val.shape) for key, val in prunable_params.items()
}

# 4 elementwise operations on the output space, just need to add all of them up
desc.flops = 4 * float(numpy.prod(out[0].shape[1:]))
desc.total_flops = desc.flops
desc.flops = 2 * float(inp[0].nelement())
desc.total_flops += desc.flops
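
The batch-norm count is now two element-wise operations per input element (e.g. a scale and a shift) rather than four per output pixel; for a hypothetical (1, 64, 56, 56) input:

flops = 2 * (1 * 64 * 56 * 56)  # 401,408 element-wise operations
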

def _pool_hook(
self,
@@ -365,26 +360,21 @@ def _pool_hook(
desc, inp, out = self._init_forward_hook(mod, inp, out)

params = {key: val for key, val in mod.named_parameters()}
prunable_params = {}

desc.params = sum([val.numel() for val in params.values()])
desc.prunable_params = sum([val.numel() for val in prunable_params.values()])
desc.zeroed_params = sum(
[(val == 0).sum().item() for val in prunable_params.values()]
)
desc.params_dims = {
key: tuple(s for s in val.shape) for key, val in params.items()
}
desc.prunable_params_dims = {
key: tuple(s for s in val.shape) for key, val in prunable_params.items()
}
desc.stride = mod.stride
desc.prunable_params = 0
desc.zeroed_params = 0

batch_size, input_channels, input_height, input_width = inp[0].size()
batch_size, output_channels, output_height, output_width = out[0].size()

flops_per_out_pix = float(numpy.prod(mod.kernel_size) + 1)
out_pix = float(numpy.prod(out[0].shape[1:]))
if isinstance(mod.kernel_size, numbers.Number) or mod.kernel_size.dim() == 1:
kernel_ops = mod.kernel_size * mod.kernel_size
else:
kernel_ops = numpy.prod(mod.kernel_size)
flops = kernel_ops * output_channels * output_height * output_width * batch_size

desc.flops = flops_per_out_pix * out_pix
desc.total_flops = desc.flops
desc.flops = flops
desc.total_flops += desc.flops
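
For a hypothetical 2x2 max pool whose output has shape (1, 64, 28, 28), the count above works out to:

kernel_ops = 2 * 2                     # a scalar kernel_size is squared
flops = kernel_ops * 64 * 28 * 28 * 1  # 200,704 operations per forward pass
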

def _adaptive_pool_hook(
self,
@@ -395,20 +385,10 @@ def _adaptive_pool_hook(
desc, inp, out = self._init_forward_hook(mod, inp, out)

params = {key: val for key, val in mod.named_parameters()}
prunable_params = {}

desc.params = sum([val.numel() for val in params.values()])
desc.prunable_params = sum([val.numel() for val in prunable_params.values()])
desc.zeroed_params = sum(
[(val == 0).sum().item() for val in prunable_params.values()]
)
desc.params_dims = {
key: tuple(s for s in val.shape) for key, val in params.items()
}
desc.prunable_params_dims = {
key: tuple(s for s in val.shape) for key, val in prunable_params.items()
}
desc.stride = 1
desc.prunable_params = 0
desc.zeroed_params = 0

stride = tuple(
inp[0].shape[i] // out[0].shape[i] for i in range(2, len(inp[0].shape))
@@ -417,11 +397,14 @@ def _adaptive_pool_hook(
inp[0].shape[i] - (out[0].shape[i] - 1) * stride[i - 2]
for i in range(2, len(inp[0].shape))
)
flops_per_out_pix = float(numpy.prod(kernel_size))
out_pix = float(numpy.prod(out[0].shape[1:]))
kernel_ops = numpy.prod(kernel_size)

batch_size, output_channels, output_height, output_width = out[0].size()

desc.flops = flops_per_out_pix * out_pix
desc.total_flops = desc.flops
flops = kernel_ops * output_channels * output_height * output_width * batch_size

desc.flops = flops
desc.total_flops += desc.flops
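
Here the stride and effective kernel are inferred from the input and output shapes; for a hypothetical AdaptiveAvgPool2d(1) over a (1, 512, 7, 7) input producing (1, 512, 1, 1):

stride = (7 // 1, 7 // 1)                         # (7, 7)
kernel_size = (7 - (1 - 1) * 7, 7 - (1 - 1) * 7)  # (7, 7)
flops = 7 * 7 * 512 * 1 * 1 * 1                   # 25,088 operations per forward pass
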

def _activation_hook(
self,
@@ -432,24 +415,24 @@ def _activation_hook(
desc, inp, out = self._init_forward_hook(mod, inp, out)

params = {key: val for key, val in mod.named_parameters()}
prunable_params = {}

desc.params = sum([val.numel() for val in params.values()])
desc.prunable_params = sum([val.numel() for val in prunable_params.values()])
desc.zeroed_params = sum(
[(val == 0).sum().item() for val in prunable_params.values()]
)
desc.params_dims = {
key: tuple(s for s in val.shape) for key, val in params.items()
}
desc.prunable_params_dims = {
key: tuple(s for s in val.shape) for key, val in prunable_params.items()
}
desc.prunable_params = 0
desc.zeroed_params = 0

# making assumption that flops spent is one per element
# (so swish is counted the same activation ReLU)
desc.flops = float(numpy.prod(out[0].shape[1:]))
desc.total_flops = desc.flops
# FIXME (can't really be fixed). Some standard architectures,
# such as a standard ResNet, use the same activation (ReLU) object
# for all of the places that it appears in the net, which works
# fine because it's stateless. But it makes it hard to count per-
# batch forward FLOPs correctly, since a single forward pass
# through the network is actually multiple passes through the
# activation. So the per-batch FLOPs are undercounted (slightly,
# since activations are very few FLOPs in general), but total
# (cumulative) FLOPs are counted correctly.
desc.flops = float(inp[0].nelement())
desc.total_flops += desc.flops
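
To make the FIXME concrete, a hypothetical block that reuses one activation object fires this hook once per call site, so `flops` reflects only the last call while `total_flops` picks up all of them:

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = nn.ReLU()  # a single object, applied twice per forward pass

    def forward(self, x):
        return self.act(self.act(x))

With the analyzer attached, each forward pass through such a block adds twice the input's element count to total_flops while flops holds a single count, which is the undercount described in the comment above.
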

def _softmax_hook(
self,
@@ -460,25 +443,16 @@ def _softmax_hook(
desc, inp, out = self._init_forward_hook(mod, inp, out)

params = {key: val for key, val in mod.named_parameters()}
prunable_params = {}

desc.params = sum([val.numel() for val in params.values()])
desc.prunable_params = sum([val.numel() for val in prunable_params.values()])
desc.zeroed_params = sum(
[(val == 0).sum().item() for val in prunable_params.values()]
)
desc.params_dims = {
key: tuple(s for s in val.shape) for key, val in params.items()
}
desc.prunable_params_dims = {
key: tuple(s for s in val.shape) for key, val in prunable_params.items()
}
desc.prunable_params = 0
desc.zeroed_params = 0

flops_per_channel = (
2 if len(out[0].shape) < 3 else float(numpy.prod(out[0].shape[2:]))
)
desc.flops = flops_per_channel * out[0].shape[1]
desc.total_flops = desc.flops
desc.total_flops += desc.flops
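
For a hypothetical softmax over classifier logits of shape (batch, 1000), the output has fewer than three dimensions, so the count reduces to:

flops_per_channel = 2             # len(out_shape) < 3
flops = flops_per_channel * 1000  # 2,000 operations, independent of batch size
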

@staticmethod
def _mod_desc(mod: Module) -> AnalyzedLayerDesc: