Skip to content

Commit

Permalink
add option to count MAC (default to True) instead of flops (facebookr…
Browse files Browse the repository at this point in the history
…esearch#77)

Summary:
Pull Request resolved: facebookresearch#77

MACs and FLOPs are different concepts but often misused.
This should make it more clear that FlopCounter is actually counting MACs.
Maybe we should even change the default but that's a different decision to make.

Differential Revision: D28859722

fbshipit-source-id: 794b1410a53c6e72c83a4abb2d1d660cd78517a9
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Jul 8, 2021
1 parent 166a030 commit 726a919
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 42 deletions.
63 changes: 48 additions & 15 deletions fvcore/nn/flop_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
# pyre-ignore-all-errors[2,33]

from collections import defaultdict
from copy import deepcopy
from typing import Any, Counter, DefaultDict, Dict, Optional, Tuple, Union

import torch.nn as nn
from torch import Tensor

from .jit_analysis import JitModelAnalysis
from .jit_analysis import JitModelAnalysis, Statistics
from .jit_handles import (
Handle,
addmm_flop_jit,
Expand All @@ -30,32 +31,41 @@
"aten::einsum": einsum_flop_jit,
"aten::matmul": matmul_flop_jit,
"aten::linear": linear_flop_jit,
# Flops for the following ops are just estimates as they are not very well
# defined and don't correlate with wall time very much. They shouldn't take
# a big portion of any model anyway.
# You might want to ignore BN flops due to inference-time fusion.
# Use `set_op_handle("aten::batch_norm", None)
"aten::batch_norm": batchnorm_flop_jit,
"aten::group_norm": norm_flop_counter(2),
"aten::layer_norm": norm_flop_counter(2),
"aten::instance_norm": norm_flop_counter(1),
"aten::upsample_nearest2d": elementwise_flop_counter(0, 1),
"aten::upsample_bilinear2d": elementwise_flop_counter(0, 4),
"aten::adaptive_avg_pool2d": elementwise_flop_counter(1, 0),
"aten::grid_sampler": elementwise_flop_counter(0, 4), # assume bilinear
"aten::upsample_bilinear2d": elementwise_flop_counter(0, 8),
"aten::adaptive_avg_pool2d": elementwise_flop_counter(2, 0),
"aten::grid_sampler": elementwise_flop_counter(0, 8), # assume bilinear
}


class FlopCountAnalysis(JitModelAnalysis):
"""
Provides access to per-submodule model flop count obtained by
tracing a model with pytorch's jit tracing functionality. By default,
comes with standard flop counters for a few common operators.
Note that:
Provides access to per-submodule flop count obtained by tracing a model
with pytorch's jit tracing functionality. By default, comes with standard
flop counters for a few common operators.
1. Flop is not a well-defined concept. We just produce our best estimate.
2. We count one fused multiply-add as one flop.
Flop represents floating point operations. Another common metric is MAC
(multiply-add count), which represents a multiply and an add operations.
We count MAC (multiply-add counts) by default, but this can be changed
by `set_use_mac(False)`. We just assume MAC is half of flops, which
is true for most expensive operators we care.
Note that flop/MAC is not a well-defined concept for many ops. We just produce
our best estimate.
Handles for additional operators may be added, or the default ones
overwritten, using the ``.set_op_handle(name, func)`` method.
See the method documentation for details.
The handler for each op should always calculate flops instead of MAC.
Flop counts can be obtained as:
Expand Down Expand Up @@ -111,6 +121,28 @@ def __init__(
) -> None:
super().__init__(model=model, inputs=inputs)
self.set_op_handle(**_DEFAULT_SUPPORTED_OPS)
self._use_mac = True # NOTE: maybe we'll want to change the default to False

def set_use_mac(self, enabled: bool) -> "FlopCountAnalysis":
"""
Decide whether to count MAC (multiply-add counts) rather than flops.
Default to True because this is the convention in many computer vision papers.
Unfortunately this concept is typically misused as flops.
To implement counting of MAC, we simply assume MAC is half of flops.
Although we note that this is not true for all ops.
"""
self._use_mac = enabled
return self

def _analyze(self) -> Statistics:
stats = super()._analyze()
if self._use_mac:
stats = deepcopy(stats)
for v in stats.counts.values():
for k in list(v.keys()):
v[k] = v[k] // 2
return stats

__init__.__doc__ = JitModelAnalysis.__init__.__doc__

Expand All @@ -121,8 +153,10 @@ def flop_count(
supported_ops: Optional[Dict[str, Handle]] = None,
) -> Tuple[DefaultDict[str, float], Counter[str]]:
"""
Given a model and an input to the model, compute the per-operator Gflops
of the given model.
Given a model and an input to the model, compute the per-operator GMACs
(10^9 multiply-adds) of the given model.
For more features and customized counting, please use :class:`FlopCountAnalysis`.
Args:
model (nn.Module): The model to compute flop counts.
Expand All @@ -131,12 +165,11 @@ def flop_count(
supported_ops (dict(str,Callable) or None) : provide additional
handlers for extra ops, or overwrite the existing handlers for
convolution and matmul and einsum. The key is operator name and the value
is a function that takes (inputs, outputs) of the op. We count
one Multiply-Add as one FLOP.
is a function that takes (inputs, outputs) of the op.
Returns:
tuple[defaultdict, Counter]: A dictionary that records the number of
gflops for each operation and a Counter that records the number of
GMACs for each operation and a Counter that records the number of
unsupported operations.
"""
if supported_ops is None:
Expand Down
26 changes: 14 additions & 12 deletions fvcore/nn/jit_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
batch_size, input_dim = input_shapes[0]
output_dim = input_shapes[1][1]
flops = batch_size * input_dim * output_dim
return flops
return flops * 2


def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
Expand All @@ -102,7 +102,7 @@ def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
# input_shapes[1]: [output_feature_dim, input_feature_dim]
assert input_shapes[0][-1] == input_shapes[1][-1]
flops = prod(input_shapes[0]) * input_shapes[1][0]
return flops
return flops * 2


def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
Expand All @@ -116,7 +116,7 @@ def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
n, c, t = input_shapes[0]
d = input_shapes[-1][-1]
flop = n * c * t * d
return flop
return flop * 2


def conv_flop_count(
Expand All @@ -137,7 +137,7 @@ def conv_flop_count(
out_size = prod(out_shape[2:])
kernel_size = prod(w_shape[2:])
flop = batch_size * out_size * Cout_dim * Cin_dim * kernel_size
return flop
return flop * 2


def conv_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]:
Expand Down Expand Up @@ -181,20 +181,19 @@ def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
n, c, t = input_shapes[0]
p = input_shapes[-1][-1]
flop = n * c * t * p
return flop
return flop * 2

elif equation == "abc,adc->adb":
n, t, g = input_shapes[0]
c = input_shapes[-1][1]
flop = n * t * g * c
return flop
return flop * 2
else:
np_arrs = [np.zeros(s) for s in input_shapes]
optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
for line in optim.split("\n"):
if "optimized flop" in line.lower():
# divided by 2 because we count MAC (multiply-add counted as one flop)
flop = float(np.floor(float(line.split(":")[-1]) / 2))
flop = float(line.split(":")[-1].strip())
return flop
raise NotImplementedError("Unsupported einsum operation.")

Expand All @@ -209,7 +208,7 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
assert len(input_shapes) == 2, input_shapes
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
flop = prod(input_shapes[0]) * input_shapes[-1][-1]
return flop
return flop * 2


def norm_flop_counter(affine_arg_index: int) -> Handle:
Expand All @@ -226,8 +225,11 @@ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
input_shape = get_shape(inputs[0])
has_affine = get_shape(inputs[affine_arg_index]) is not None
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = prod(input_shape) * (5 if has_affine else 4)
# 5 or 7 is just a rough estimate:
# 3 - compute E[x] and E[x^2]
# 2 - compute normalization
# 2 - compute affine
flop = prod(input_shape) * (7 if has_affine else 5)
return flop

return norm_flop_jit
Expand All @@ -240,7 +242,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
return norm_flop_counter(1)(inputs, outputs) # pyre-ignore
has_affine = get_shape(inputs[1]) is not None
input_shape = prod(get_shape(inputs[0]))
return input_shape * (2 if has_affine else 1)
return input_shape * (4 if has_affine else 2)


def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Handle:
Expand Down
22 changes: 11 additions & 11 deletions tests/test_flop_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def dummy_sigmoid_flop_jit(
custom_ops: Dict[str, Handle] = {"aten::sigmoid": dummy_sigmoid_flop_jit}
x = torch.rand(batch_size, input_dim)
flop_dict1, _ = flop_count(customNet, (x,), supported_ops=custom_ops)
flop_sigmoid = 10000 / 1e9
flop_sigmoid = 10000 / 1e9 / 2
self.assertEqual(
flop_dict1["sigmoid"],
flop_sigmoid,
Expand All @@ -211,7 +211,7 @@ def addmm_dummy_flop_jit(
"aten::{}".format(self.lin_op): addmm_dummy_flop_jit
}
flop_dict2, _ = flop_count(customNet, (x,), supported_ops=custom_ops2)
flop = 400000 / 1e9
flop = 400000 / 1e9 / 2
self.assertEqual(
flop_dict2[self.lin_op],
flop,
Expand Down Expand Up @@ -632,7 +632,7 @@ def test_batchnorm(self) -> None:
batch_2d = nn.BatchNorm2d(input_dim, affine=False)
x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y)
flop_dict, _ = flop_count(batch_2d, (x,))
gt_flop = 4 * batch_size * input_dim * spatial_dim_x * spatial_dim_y / 1e9
gt_flop = 2.5 * batch_size * input_dim * spatial_dim_x * spatial_dim_y / 1e9
gt_dict = defaultdict(float)
gt_dict["batch_norm"] = gt_flop
self.assertDictEqual(
Expand All @@ -651,7 +651,7 @@ def test_batchnorm(self) -> None:
)
flop_dict, _ = flop_count(batch_3d, (x,))
gt_flop = (
4
2.5
* batch_size
* input_dim
* spatial_dim_x
Expand Down Expand Up @@ -740,22 +740,22 @@ def test_batch_norm(self):
nodes = self._count_function(
F.batch_norm, (torch.rand(2, 2, 2, 2), vec, vec, vec, vec), op_name
)
self.assertEqual(counter(*nodes), 32)
self.assertEqual(counter(*nodes), 64)

nodes = self._count_function(
F.batch_norm,
(torch.rand(2, 2, 2, 2), vec, vec, None, None),
op_name,
)
self.assertEqual(counter(*nodes), 16)
self.assertEqual(counter(*nodes), 32)

nodes = self._count_function(
# training=True
F.batch_norm,
(torch.rand(2, 2, 2, 2), vec, vec, vec, vec, True),
op_name,
)
self.assertEqual(counter(*nodes), 80)
self.assertEqual(counter(*nodes), 112)

def test_group_norm(self):
op_name = "aten::group_norm"
Expand All @@ -765,12 +765,12 @@ def test_group_norm(self):
nodes = self._count_function(
F.group_norm, (torch.rand(2, 2, 2, 2), 2, vec, vec), op_name
)
self.assertEqual(counter(*nodes), 80)
self.assertEqual(counter(*nodes), 112)

nodes = self._count_function(
F.group_norm, (torch.rand(2, 2, 2, 2), 2, None, None), op_name
)
self.assertEqual(counter(*nodes), 64)
self.assertEqual(counter(*nodes), 80)

def test_upsample(self):
op_name = "aten::upsample_bilinear2d"
Expand All @@ -779,7 +779,7 @@ def test_upsample(self):
nodes = self._count_function(
F.interpolate, (torch.rand(2, 2, 2, 2), None, 2, "bilinear", False), op_name
)
self.assertEqual(counter(*nodes), 2 ** 4 * 4 * 4)
self.assertEqual(counter(*nodes), 2 ** 4 * 4 * 4 * 2)

def test_complicated_einsum(self):
op_name = "aten::einsum"
Expand All @@ -790,4 +790,4 @@ def test_complicated_einsum(self):
("nc,nchw->hw", torch.rand(3, 4), torch.rand(3, 4, 2, 3)),
op_name,
)
self.assertEqual(counter(*nodes), 72.0)
self.assertEqual(counter(*nodes), 145)
8 changes: 4 additions & 4 deletions tests/test_jit_model_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def test_changing_handles(self) -> None:
"aten::linear": linear_flop_jit,
} # type: Dict[str, Handle]

analyzer = JitModelAnalysis(model=model, inputs=inputs).set_op_handle(
analyzer = FlopCountAnalysis(model=model, inputs=inputs).set_op_handle(
**op_handles
)
analyzer.unsupported_ops_warnings(enabled=False)
Expand All @@ -638,7 +638,7 @@ def make_dummy_op(name: str, output: int) -> Handle:
def dummy_ops_handle(
inputs: List[Any], outputs: List[Any]
) -> typing.Counter[str]:
return Counter({name: output})
return Counter({name: output * 2})

return dummy_ops_handle

Expand Down Expand Up @@ -725,10 +725,10 @@ def test_copy(self) -> None:
non_forward_flops = new_model.fc_flops + new_model.submod.fc_flops

# Total is correct for new model and inputs
self.assertEqual(analyzer_new.total(), non_forward_flops * bs)
self.assertEqual(analyzer_new.total(), non_forward_flops * bs * 2)

# Original is unaffected
self.assertEqual(analyzer.total(), repeated_net_flops)
self.assertEqual(analyzer.total(), repeated_net_flops * 2)

# Settings match
self.assertEqual(
Expand Down

0 comments on commit 726a919

Please sign in to comment.