From 1608fdfc615053541f34ee5207c5c9783464184e Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 14 Nov 2024 10:57:36 -0800 Subject: [PATCH] metal lowbit kernels: check contiguity of scales and zeros Summary: Check contiguity of scales and zero points tensors Polish generation of zero points tensor in kernel/op tests Differential Revision: D65957327 --- torchao/experimental/kernels/mps/test/test_lowbit.mm | 3 ++- torchao/experimental/ops/mps/register.mm | 2 ++ torchao/experimental/ops/mps/test/test_lowbit.py | 8 ++++---- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/torchao/experimental/kernels/mps/test/test_lowbit.mm b/torchao/experimental/kernels/mps/test/test_lowbit.mm index 398af237ae..2d86223034 100644 --- a/torchao/experimental/kernels/mps/test/test_lowbit.mm +++ b/torchao/experimental/kernels/mps/test/test_lowbit.mm @@ -101,7 +101,8 @@ void init() { int32_t ceil_K_group_size = (K + qGroupSize - 1) / qGroupSize; for (int idx = 0; idx < N * ceil_K_group_size; ++idx) { s_ptr[idx] = (idx + 1.0) / N; - z_ptr[idx] = int_distrib(generator); + auto zp = int_distrib(generator); + z_ptr[idx] = -zp * s_ptr[idx]; } for (int idx = 0; idx < M * N; ++idx) { c_ptr[idx] = -1.0; diff --git a/torchao/experimental/ops/mps/register.mm b/torchao/experimental/ops/mps/register.mm index a53f55d3d8..44946a30f0 100644 --- a/torchao/experimental/ops/mps/register.mm +++ b/torchao/experimental/ops/mps/register.mm @@ -58,6 +58,7 @@ void check_linear_mps_args( ": expect S to be 2d tensor with shape [:, ", N, "]"); + TORCH_CHECK(S.is_contiguous(), __func__, " : expect S to be contiguous."); TORCH_CHECK( Z.dim() == 2 && Z.size(1) == N, @@ -65,6 +66,7 @@ void check_linear_mps_args( ": expect Z to be 2d tensor with shape [:, ", N, "]"); + TORCH_CHECK(Z.is_contiguous(), __func__, " : expect Z to be contiguous."); } template diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py index f2d9d9c175..797c5dac29 100644 --- a/torchao/experimental/ops/mps/test/test_lowbit.py +++ b/torchao/experimental/ops/mps/test/test_lowbit.py @@ -46,18 +46,18 @@ class TestLowBitQuantWeightsLinear(unittest.TestCase): ] def _init_tensors(self, group_size, M, K, N, nbit, device="mps"): - max_abs = 1 << (nbit - 1) ceil_K_group_size = (K + group_size - 1) // group_size - A = 2 * torch.rand(M, K, dtype=torch.float32, device=device) - 1 - W = torch.randint(0, 2 * max_abs, (N, K), dtype=torch.uint8, device=device) + A = torch.rand(M, K, dtype=torch.float32, device=device) + W = torch.randint(0, 1 << nbit, (N, K), dtype=torch.uint8, device=device) S = torch.rand(ceil_K_group_size, N, dtype=torch.float32, device=device) + 0.01 Z = torch.randint( 0, - 2 * max_abs, + 1 << nbit, (ceil_K_group_size, N), dtype=torch.float32, device=device, ) + Z = -Z * S return A, W, S, Z def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit):