From 5eb6339e0b6f413c74a3dfd5e7f53449474723fc Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Tue, 26 Nov 2024 10:30:34 -0800
Subject: [PATCH] Enable 8-bit (#1254)

Enable 8-bit (#1254)

Summary:
Pull Request resolved: https://github.com/pytorch/ao/pull/1254

Enables 8-bit kernel in operators and tests

Reviewed By: digantdesai

Differential Revision: D65688410
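Added context, not part of the original patch: the asserts changed in the diffs below lift the supported weight bit-width cap from nbit <= 7 to nbit <= 8, i.e. weights can now use the full signed int8 range. A minimal illustrative sketch of the representable range per bit-width (plain Python, not torchao's internal _quantize):

# Illustrative only: the range of an nbit signed quantized value.
def qrange(nbit: int):
    qmin = -(1 << (nbit - 1))
    qmax = (1 << (nbit - 1)) - 1
    return qmin, qmax

for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
    print(nbit, qrange(nbit))  # nbit=8 -> (-128, 127), the full int8 range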
---
 .../_linear_8bit_act_xbit_weight_layout.py    |  4 +--
 .../embedding_xbit/op_embedding_xbit_aten.cpp |  3 ++
 .../op_embedding_xbit_executorch.cpp          |  1 +
 .../op_linear_8bit_act_xbit_weight_aten.cpp   |  3 ++
 .../w8s.cpp                                   | 29 +++++++++++++++++++
 .../w8sz.cpp                                  | 29 +++++++++++++++++++
 torchao/experimental/quant_api.py             | 10 +++----
 .../tests/test_embedding_xbit_quantizer.py    |  2 +-
 ...t_linear_8bit_act_xbit_weight_quantizer.py |  2 +-
 ...dynamic_activation_intx_weight_subclass.py |  2 +-
 10 files changed, 75 insertions(+), 10 deletions(-)
 create mode 100644 torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8s.cpp
 create mode 100644 torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8sz.cpp

diff --git a/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py b/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py
index 0e6c73343f..97e6380f92 100644
--- a/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py
+++ b/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py
@@ -65,7 +65,7 @@ def __init__(
         group_size: int,
         target: str,
     ):
-        assert nbit <= 7
+        assert nbit <= 8
         self.nbit = nbit
         self.group_size = group_size
         self.target = target_from_str(target)
@@ -182,7 +182,7 @@ def from_plain(

         # Fallback
         assert layout.target == Target.FALLBACK
-        packed_weight = int_data.to(torch.int8)
+        packed_weight = int_data.to(torch.int32)
         return cls(packed_weight, scale, zero_point, layout)

     def _apply_fn_to_data(self, fn):
diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp
index dfb61eb928..1b019609a6 100644
--- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp
+++ b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp
@@ -36,6 +36,7 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) {
   DEFINE_OP(5);
   DEFINE_OP(6);
   DEFINE_OP(7);
+  DEFINE_OP(8);
 }

 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -46,6 +47,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) {
   DEFINE_CPU_IMPL(5);
   DEFINE_CPU_IMPL(6);
   DEFINE_CPU_IMPL(7);
+  DEFINE_CPU_IMPL(8);
 }

 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
@@ -56,4 +58,5 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) {
   DEFINE_META_IMPL(5);
   DEFINE_META_IMPL(6);
   DEFINE_META_IMPL(7);
+  DEFINE_META_IMPL(8);
 }
diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp
index 1b79a5e035..f99a575cfe 100644
--- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp
+++ b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp
@@ -37,3 +37,4 @@ DEFINE_OP(4);
 DEFINE_OP(5);
 DEFINE_OP(6);
 DEFINE_OP(7);
+DEFINE_OP(8);
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
index f69e51e4c9..24d4008969 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
@@ -68,6 +68,7 @@ TORCH_LIBRARY(torchao, m) {
   DEFINE_OP(5);
   DEFINE_OP(6);
   DEFINE_OP(7);
+  DEFINE_OP(8);
 }

 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -78,6 +79,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) {
   DEFINE_CPU_IMPL(5);
   DEFINE_CPU_IMPL(6);
   DEFINE_CPU_IMPL(7);
+  DEFINE_CPU_IMPL(8);
 }

 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
@@ -88,4 +90,5 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) {
   DEFINE_META_IMPL(5);
   DEFINE_META_IMPL(6);
   DEFINE_META_IMPL(7);
+  DEFINE_META_IMPL(8);
 }
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8s.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8s.cpp
new file mode 100644
index 0000000000..5257611d97
--- /dev/null
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8s.cpp
@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to only allow one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
+// file is needed for each variant.
+
+#include
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_8bit0zp_weight.out", _op_out);
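Added context, not part of the patch: w8s.cpp above registers the scale-only ("0zp") 8-bit variant, and w8sz.cpp below registers the scale-plus-zero-point variant. The sketch below shows how these registered names line up with the wzp_suffix logic changed in quant_api.py further down; generalizing the pattern to other bit-widths is an inference from the DEFINE_OP macros, not something this patch spells out.

# Sketch of the op-name convention implied by w8s.cpp / w8sz.cpp and by
# quant_api.py's `wzp_suffix = "" if has_weight_zeros else "0zp"`.
def et_linear_op_name(weight_nbit: int, has_weight_zeros: bool) -> str:
    wzp_suffix = "" if has_weight_zeros else "0zp"
    return f"_linear_8bit_act_{weight_nbit}bit{wzp_suffix}_weight.out"

assert et_linear_op_name(8, has_weight_zeros=False) == "_linear_8bit_act_8bit0zp_weight.out"  # w8s.cpp
assert et_linear_op_name(8, has_weight_zeros=True) == "_linear_8bit_act_8bit_weight.out"      # w8sz.cpp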
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8sz.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8sz.cpp
new file mode 100644
index 0000000000..e26da69d67
--- /dev/null
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8sz.cpp
@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to only allow one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
+// file is needed for each variant.
+
+#include
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_8bit_weight.out", _op_out);
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index fc4fb341fd..be72a59aab 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -202,7 +202,7 @@ def forward(self, x):

 def _maybe_get_quantized_linear_native(nbit, has_weight_zeros):
     try:
-        if nbit in [1, 2, 3, 4, 5, 6, 7]:
+        if nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             wzp_suffix = "" if has_weight_zeros else "0zp"
             return _Int8DynActIntxWeightQuantizedLinearNative(
                 pack_weight_op=getattr(
@@ -234,7 +234,7 @@ def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}):
     has_weight_zeros = kwargs["has_weight_zeros"]

     assert not isinstance(module, nn.Linear)
-    assert nbit >= 1 and nbit <= 7
+    assert nbit >= 1 and nbit <= 8

     for name, child in module.named_children():
         if not isinstance(child, nn.Linear):
@@ -370,9 +370,9 @@ def quantize_and_pack_weights(self, weights, group_size):
         weight_qvals, weight_scales, weight_zeros = _quantize(
             weights, self.group_size, self.nbit, has_weight_zeros=True
         )
-        self.weight_qvals = weight_qvals.to(torch.int8)
+        self.weight_qvals = weight_qvals.to(torch.int32)
         self.weight_scales = weight_scales
-        self.weight_zeros = weight_zeros.to(torch.int8)
+        self.weight_zeros = weight_zeros.to(torch.int32)

     def forward(self, x):
         shape = x.shape
@@ -398,7 +398,7 @@ def _replace_embedding_with_quantized_embedding(module: nn.Module, kwargs={}):
     nbit = kwargs["nbit"]

     assert not isinstance(module, nn.Embedding)
-    assert nbit >= 1 and nbit <= 7
+    assert nbit >= 1 and nbit <= 8

     for name, child in module.named_children():
         if not isinstance(child, nn.Embedding):
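Added context, not part of the patch: the dtype changes above (weight_qvals, weight_zeros, and the fallback packed_weight now held as int32 instead of int8) read like overflow protection once full 8-bit quantized values and zero points are combined, though the PR does not say so explicitly. A quick illustration of the int8 wrap-around that motivates this guess:

# Inference from the diff, not an explanation given in the PR: at nbit=8,
# (qval - zero_point) can reach 255, which wraps around in int8 arithmetic.
import torch

qvals = torch.tensor([127], dtype=torch.int8)
zeros = torch.tensor([-128], dtype=torch.int8)
print(qvals - zeros)                                  # tensor([-1], dtype=torch.int8): wrapped
print(qvals.to(torch.int32) - zeros.to(torch.int32))  # tensor([255], dtype=torch.int32): correct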
diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py
index a6ba04b439..0eccf33cdb 100644
--- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py
+++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py
@@ -65,7 +65,7 @@ def test_accuracy(self):
         model = torch.nn.Sequential(*[torch.nn.Embedding(num_embeddings, embedding_dim)])
         indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)

-        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+        for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             print(f"Testing nbit={nbit}")
             quantized_model = copy.deepcopy(model)
             quantizer = IntxWeightEmbeddingQuantizer(
diff --git a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
index 5d2828d9bc..aeb19555d7 100644
--- a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
+++ b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
@@ -67,7 +67,7 @@ def test_accuracy(self):
         activations = torch.randn(2, 3, m, k, dtype=torch.float32)
         model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])

-        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+        for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             for has_weight_zeros in [True, False]:
                 print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
                 quantized_model = copy.deepcopy(model)
diff --git a/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py b/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py
index 44e63386ce..d9035cbe3f 100644
--- a/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py
+++ b/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py
@@ -70,7 +70,7 @@ def test_accuracy(self):
         activations = torch.randn(m, k, dtype=torch.float32)
         model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])

-        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+        for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             for has_weight_zeros in [True, False]:
                 print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
                 quantized_model = copy.deepcopy(model)
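Added context, not part of the patch: the widened test loops above now exercise nbit=8 end to end. For readers who want to try the new bit-width outside the test suite, a rough usage sketch follows. The import path, the constructor keyword arguments, and the quantize method name are assumptions based on the test context visible in this diff, not verbatim from the tests.

# Hypothetical usage sketch; import path, constructor kwargs, and
# quantizer.quantize(...) are assumed names, not taken verbatim from this patch.
import copy
import torch
from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer  # path assumed

num_embeddings, embedding_dim, group_size = 128, 256, 64
model = torch.nn.Sequential(torch.nn.Embedding(num_embeddings, embedding_dim))
indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)

for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:  # 8 is newly supported by this PR
    quantized_model = copy.deepcopy(model)
    quantizer = IntxWeightEmbeddingQuantizer(  # kwargs below are assumed names
        device="cpu",
        precision=torch.float32,
        bitwidth=nbit,
        groupsize=group_size,
    )
    quantized_model = quantizer.quantize(quantized_model)  # method name assumed
    out = quantized_model(indices)  # the real tests compare this against a reference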