Better Bfloat16 support #777

Merged (3 commits) on Dec 22, 2023
18 changes: 10 additions & 8 deletions src/brevitas/core/stats/stats_op.py
@@ -12,6 +12,8 @@
from brevitas import config
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_int
# Use a custom implementation of kthvalue as a workaround for (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue

from .stats_wrapper import SCALAR_SHAPE

@@ -64,15 +66,15 @@ def forward(self, x: Tensor):
if self.stats_reduce_dim is None:
# k is 1-indexed, so round away from zero
k = int(math.floor(.01 * self.q * x.numel() + 0.5))
result = x.abs().view(-1).kthvalue(k).values
result = kthvalue(x.abs().view(-1), k)[0]
else:
# assuming x is two dimensional, get the other dimension
assert len(x.size()) == 2, "Only 2-dim input is supported."
other_dim = abs(self.stats_reduce_dim - 1)
dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1)
# k is 1-indexed, so round away from zero
k = int(math.floor(.01 * self.q * dim_slice.numel() + 0.5))
result = x.abs().kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
result = kthvalue(x.abs(), k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
return result


@@ -97,15 +99,15 @@ def forward(self, x: Tensor) -> Tensor:
if self.stats_reduce_dim is None:
# k is 1-indexed, so round away from zero
k = int(math.ceil(.01 * self.q * x.numel()))
result = x.view(-1).kthvalue(k).values
result = kthvalue(x.view(-1), k)[0]
else:
# assuming x is two dimensional, get the other dimension
assert len(x.size()) == 2, "Only 2-dim input is supported."
other_dim = abs(self.stats_reduce_dim - 1)
dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1)
# k is 1-indexed, so round away from zero
k = int(math.ceil(.01 * self.q * dim_slice.numel()))
result = x.kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
result = kthvalue(x, k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
result = torch.clamp(result, max=self.zero())
return result

@@ -134,8 +136,8 @@ def forward(self, x: Tensor) -> Tensor:
low_k = int(math.ceil(.01 * self.low_q * x.numel()))
# k is 1-indexed, so round away from zero
high_k = int(math.floor(.01 * self.high_q * x.numel() + 0.5))
low_result = x.view(-1).kthvalue(low_k).values
high_result = x.view(-1).kthvalue(high_k).values
low_result = kthvalue(x.view(-1), low_k)[0]
high_result = kthvalue(x.view(-1), high_k)[0]
else:
# assuming x is two dimensional, get the other dimension
assert len(x.size()) == 2, "Only 2-dim input is supported."
@@ -144,8 +146,8 @@ def forward(self, x: Tensor) -> Tensor:
low_k = int(math.ceil(.01 * self.low_q * dim_slice.numel()))
# k is 1-indexed, so round away from zero
high_k = int(math.floor(.01 * self.high_q * dim_slice.numel() + 0.5))
low_result = x.kthvalue(low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
high_result = x.kthvalue(high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
low_result = kthvalue(x, low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
high_result = kthvalue(x, high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
# We need to make sure the lower bound is not positive to align with zero-point statistics
low_result = torch.clamp(low_result, max=self.zero())
interval = high_result - low_result
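
These hunks only reroute the existing percentile math through the new wrapper. As a minimal standalone sketch of the per-tensor path (assuming q is a percentile expressed on a 0-100 scale, as in AbsPercentile):

import math

import torch

from brevitas.utils.torch_utils import kthvalue


def abs_percentile(x: torch.Tensor, q: float) -> torch.Tensor:
    # k is 1-indexed, so round away from zero
    k = int(math.floor(.01 * q * x.numel() + 0.5))
    # kthvalue returns a (values, indices) tuple, mirroring torch.kthvalue
    return kthvalue(x.abs().view(-1), k)[0]


# The wrapper makes this work for bfloat16 inputs on CPU and GPU alike
abs_percentile(torch.randn(64, 64).to(torch.bfloat16), q=99.9)
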
28 changes: 23 additions & 5 deletions src/brevitas/quant_tensor/__init__.py
@@ -15,6 +15,7 @@
from .torch_handler import QUANT_TENSOR_FN_HANDLER

IS_VALID_ATOL = 2e-1
BFLOAT16_IS_VALID_ATOL = 0.5


class QuantTensorBase(NamedTuple):
@@ -104,8 +105,15 @@ def is_not_none(self):

@property
def _pre_round_int_value(self):
int_value = self.value / self.scale
int_value = int_value + self.zero_point
value = self.value
scale = self.scale
zero_point = self.zero_point
if self.scale.dtype == torch.bfloat16:
value = self.value.type(torch.float32)
scale = self.scale.type(torch.float32)
zero_point = self.zero_point.type(torch.float32)
int_value = value / scale
int_value = int_value + zero_point
return int_value

@property
@@ -114,8 +122,9 @@ def is_valid(self):
with torch.no_grad():
pre_round_int_value = self._pre_round_int_value
rounded_int_value = torch.round(pre_round_int_value)
is_int = torch.isclose(
pre_round_int_value, rounded_int_value, atol=IS_VALID_ATOL).all()
max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value))
atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL
is_int = max_abs_diff < atol
if self.bit_width >= 2:
if self.signed:
is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all()
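
The looser BFLOAT16_IS_VALID_ATOL is needed because a dequantized tensor stored in bfloat16 keeps only 8 significand bits, so mapping it back to integer space can land a few tenths away from the nearest code even after the float32 upcast above. A quick standalone illustration (the scale value is arbitrary, not taken from the PR):

import torch

# Dequantized 8-bit codes stored in bfloat16, then mapped back to integer space
scale = 0.0158
int_ref = torch.arange(-128, 128, dtype=torch.float32)
value_bf16 = (int_ref * scale).to(torch.bfloat16)
pre_round = value_bf16.to(torch.float32) / scale
max_err = (pre_round - int_ref).abs().max()
# Roughly 0.4 here: above the float32 atol of 2e-1, below the new 0.5 threshold
print(max_err)
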
@@ -176,7 +185,12 @@ def int(self, float_datatype=False):
if self.is_valid:
int_value = round_ste(self._pre_round_int_value)
if float_datatype:
return int_value
# Values at 8-bit and lower can be represented exactly with float16 and bfloat16;
# otherwise (e.g. Int16 bias), we upcast to float32
if self.bit_width <= 8.:
return int_value.type(self.scale.dtype)
else:
return int_value.type(torch.float32)
else:
if self.bit_width <= 8. and self.signed_t.item():
return int_value.to(torch.int8)
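
A quick check of the comment above (not part of the PR): every 8-bit integer code survives a round trip through bfloat16, while 16-bit codes do not, which is why wider bit widths fall back to float32.

import torch

codes8 = torch.arange(-128, 128, dtype=torch.float32)
codes16 = torch.arange(-2 ** 15, 2 ** 15, dtype=torch.float32)
# 8-bit codes are preserved exactly by the bfloat16 round trip
print(torch.equal(codes8.to(torch.bfloat16).to(torch.float32), codes8))    # True
# 16-bit codes (e.g. an Int16 bias) are not, hence the float32 upcast
print(torch.equal(codes16.to(torch.bfloat16).to(torch.float32), codes16))  # False
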
@@ -301,6 +315,8 @@ def cat(tensors, dim, out=None):

def __neg__(self):
neg_value = (-self.int(float_datatype=True) - self.zero_point) * self.scale
# In case the dtype of self.int() differs from that of the scale
neg_value = neg_value.type(self.scale.dtype)
if self.signed:
return QuantTensor(
value=neg_value,
@@ -432,6 +448,8 @@ def __truediv__(self, other):
def __abs__(self):
if self.signed:
abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale
# In case the dtype of self.int() differs from that of the scale
abs_value = abs_value.type(self.scale.dtype)
return QuantTensor(
value=abs_value,
scale=self.scale,
34 changes: 34 additions & 0 deletions src/brevitas/utils/torch_utils.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
from typing import Optional, Tuple

import torch
from torch.nn import Sequential
@@ -46,3 +47,36 @@ def torch_partial_deepcopy(model):
memo[id(p)] = copy.copy(p) # Shallow copy of parameters
model_copy = copy.deepcopy(model, memo)
return model_copy


def kthvalue(
x: torch.Tensor,
k: int,
dim: Optional[int] = None,
keepdim: bool = False,
out: Optional[Tuple[torch.Tensor, torch.LongTensor]] = None
) -> Tuple[torch.Tensor, torch.LongTensor]:
# As of torch 2.1, there is no kthvalue implementation:
# - on CPU for float16
# - on GPU for bfloat16
# In these cases we cast to float32 and then back to the original dtype
dtype = x.dtype
device = str(x.device)

# We do not support out as an output buffer, since we cannot control its dtype
if out is not None:
raise RuntimeError("out argument for kthvalue not supported")

if (dtype == torch.float16 and 'cpu' in device) or \
(dtype == torch.bfloat16 and 'cuda' in device):
x = x.type(torch.float32)

# PyTorch specifies None as the default for `dim`, but it breaks if we explicitly pass None
if dim is not None:
x, indices = torch.kthvalue(x, k, dim=dim, keepdim=keepdim)
else:
x, indices = torch.kthvalue(x, k, keepdim=keepdim)

if x.dtype != dtype:
x = x.type(dtype)
return (x, indices)
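
For reference, a small usage sketch of the wrapper on one of the unsupported combinations (float16 on CPU); the upcast and the cast back are transparent to the caller:

import torch

from brevitas.utils.torch_utils import kthvalue

# torch.kthvalue has no float16 CPU kernel, so the wrapper upcasts to float32
x = torch.randn(1000).to(torch.float16)
values, indices = kthvalue(x, k=10)
# The result comes back in the original dtype, alongside the usual indices
assert values.dtype == torch.float16
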
Changes to the ImageNet classification PTQ evaluation example script
@@ -72,6 +72,8 @@
metavar='ARCH',
choices=model_names,
help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
parser.add_argument(
'--dtype', default='float', choices=['float', 'bfloat16'], help='Data type to use')
parser.add_argument(
'--target-backend',
default='fx',
@@ -215,6 +217,7 @@
default=None,
type=int,
help='Accumulator Bit Width for GPFA2Q (default: None)')
parser.add_argument('--onnx-opset-version', default=None, type=int, help='ONNX opset version')
add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
@@ -226,11 +229,11 @@

def main():
args = parser.parse_args()
dtype = getattr(torch, args.dtype)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if args.act_quant_calibration_type == 'stats':
act_quant_calib_config = str(args.act_quant_percentile) + 'stats'
else:
@@ -312,14 +315,15 @@ def main():

# Get the model from torchvision
model = get_torchvision_model(args.model_name)
model = model.to(dtype)

# Preprocess the model for quantization
if args.target_backend == 'flexml':
# flexml requires static shapes, pass a representative input in
img_shape = model_config['center_crop_shape']
model = preprocess_for_flexml_quantize(
model,
torch.ones(1, 3, img_shape, img_shape),
torch.ones(1, 3, img_shape, img_shape, dtype=dtype),
equalize_iters=args.graph_eq_iterations,
equalize_merge_bias=args.graph_eq_merge_bias,
merge_bn=not args.calibrate_bn)
@@ -339,6 +343,7 @@ def main():
# Define the quantized model
quant_model = quantize_model(
model,
dtype=dtype,
backend=args.target_backend,
scale_factor_type=args.scale_factor_type,
bias_bit_width=args.bias_bit_width,
@@ -405,7 +410,7 @@ def main():

# Validate the quant_model on the validation dataloader
print("Starting validation:")
validate(val_loader, quant_model)
validate(val_loader, quant_model, stable=dtype != torch.bfloat16)

if args.export_onnx_qcdq or args.export_torch_qcdq:
# Generate reference input tensor to drive the export process
@@ -418,7 +423,7 @@ def main():
export_name = os.path.join(args.export_dir, config)
if args.export_onnx_qcdq:
export_name = export_name + '.onnx'
export_onnx_qcdq(model, ref_input, export_name)
export_onnx_qcdq(model, ref_input, export_name, opset_version=args.onnx_opset_version)
if args.export_torch_qcdq:
export_name = export_name + '.pt'
export_torch_qcdq(model, ref_input, export_name)
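
Outside of the argument parsing, the new dtype plumbing amounts to resolving the --dtype string via getattr and casting both the model and any reference inputs. A standalone sketch, using torchvision's resnet18 directly and an assumed 224x224 input instead of the example's own helpers:

import torch
from torchvision.models import resnet18

# '--dtype' is either 'float' or 'bfloat16'; both resolve via getattr on torch
dtype = getattr(torch, 'bfloat16')
model = resnet18().to(dtype)
ref_input = torch.ones(1, 3, 224, 224, dtype=dtype)
# The float model runs end to end in bfloat16 before quantization and validation
out = model(ref_input)
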
4 changes: 2 additions & 2 deletions src/brevitas_examples/imagenet_classification/utils.py
@@ -61,7 +61,7 @@ def accuracy(output, target, topk=(1,), stable=False):
return res


def validate(val_loader, model):
def validate(val_loader, model, stable=True):
"""
Run validation on the desired dataset
"""
@@ -82,7 +82,7 @@ def print_accuracy(top1, prefix=''):

output = model(images)
# measure accuracy
acc1, = accuracy(output, target, stable=True)
acc1, = accuracy(output, target, stable=stable)
top1.update(acc1[0], images.size(0))

print_accuracy(top1, 'Total:')
6 changes: 4 additions & 2 deletions tests/brevitas/core/test_stats.py
@@ -8,6 +8,8 @@
from brevitas.core.stats import AbsPercentile
from brevitas.core.stats import NegativePercentileOrZero
from brevitas.core.stats import PercentileInterval
# Use a custom implementation of kthvalue as a workaround for (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue


def test_abs_percentile_per_tensor():
@@ -35,10 +37,10 @@ def compute_percentile(self, x, low_q=None, high_q=None):
low_p, high_p = None, None
if low_q is not None:
k = int(math.ceil(.01 * low_q * x.numel()))
low_p = x.view(-1).kthvalue(k).values
low_p = kthvalue(x.view(-1), k=k)[0]
if high_q is not None:
k = int(math.floor(.01 * high_q * x.numel() + 0.5))
high_p = x.view(-1).kthvalue(k).values
high_p = kthvalue(x.view(-1), k=k)[0]
return low_p, high_p

def test_negative_percentile(self):
4 changes: 3 additions & 1 deletion tests/brevitas/graph/test_calibration.py
@@ -12,6 +12,8 @@
from brevitas.graph.calibrate import calibration_mode
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFixedPoint
# Use a custom implementation of kthvalue as a workaround for (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue
from tests.brevitas.hyp_helper import float_tensor_random_size_st

IN_CH = 8
@@ -21,7 +23,7 @@

def compute_quantile(x, q):
k = int(math.floor(.01 * q * x.numel() + 0.5))
result = x.abs().view(-1).kthvalue(k).values
result = kthvalue(x.abs().view(-1), k=k)[0]
return result