From ea7910e5c24523ea901aabe7945ce7ac0ffa1033 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= <115986737+alexsamardzic@users.noreply.github.com> Date: Tue, 21 Jan 2025 21:30:15 +0100 Subject: [PATCH 01/15] Refactor s8s4_linear_cutlass() (#1545) Refactor CUTLASS-based code so it could support operators other than W4A8 --- .../s8s4_linear_cutlass.cu | 489 ++++++++++-------- 1 file changed, 267 insertions(+), 222 deletions(-) diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu index 2daefb7773..411343f0da 100644 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu +++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu @@ -29,26 +29,35 @@ namespace torchao { #if defined(BUILD_S8S4_LINEAR_CUTLASS) template< - typename ElementA, - typename ElementAScale, - typename ElementB, - typename ElementBScale, - typename ElementC, - typename ElementAccumulator, - typename ElementEpilogue, - typename ElementOutput, typename ThreadblockShape, typename WarpShape, typename InstructionShape, int NumStages, - bool use_tensor_c> -void s8s4_linear_kernel_cutlass( + typename ElementA, + typename ElementB, + typename ElementAccumulator, + typename Operator, + typename ElementAScale, + typename ElementBScale, + typename ElementC, + typename UseTensorC, + typename ElementOutput> +void s8s4_linear_kernel_cutlass_sm8x( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, const at::Tensor& tensor_c, at::Tensor& tensor_d) { + using SmArch = cutlass::arch::Sm80; + using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::RowMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementEpilogue = float; + + using ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; + + constexpr auto NumEVTEpilogueStages = 1; const int m = tensor_a.size(0); const int n = tensor_b.size(0); @@ -56,13 +65,13 @@ void s8s4_linear_kernel_cutlass( constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; constexpr int AlignmentAScale = - 128 / cutlass::sizeof_bits::value; + 128 / cutlass::sizeof_bits::value; constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; constexpr int AlignmentBScale = - 128 / cutlass::sizeof_bits::value; + 128 / cutlass::sizeof_bits::value; constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; constexpr int AlignmentOutput = - 128 / cutlass::sizeof_bits::value; + 128 / cutlass::sizeof_bits::value; // Check for current CUTLASS limitations w.r.t. alignments. TORCH_CHECK(k % AlignmentA == 0, @@ -75,12 +84,6 @@ void s8s4_linear_kernel_cutlass( __func__, " : Number of columns of tensor C must be divisible ", "by ", AlignmentC); - using SmArch = cutlass::arch::Sm80; - using ThreadblockSwizzle = - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; - - constexpr auto NumEVTEpilogueStages = 1; - using TensorAScaleTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< ThreadblockShape, @@ -132,9 +135,9 @@ void s8s4_linear_kernel_cutlass( cutlass::epilogue::threadblock::VisitorRowBroadcast< TensorCTileThreadMap, ElementC, - cute::Stride>; + cute::Stride>; using TensorC = - std::conditional_t; + std::conditional_t; using TensorCArguments = typename TensorC::Arguments; using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute< @@ -178,7 +181,7 @@ void s8s4_linear_kernel_cutlass( typename cutlass::gemm::kernel::DefaultGemmWithVisitor< ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, - ElementC, LayoutC, AlignmentC, + ElementOutput, LayoutOutput, AlignmentOutput, ElementAccumulator, ElementEpilogue, cutlass::arch::OpClassTensorOp, @@ -189,7 +192,7 @@ void s8s4_linear_kernel_cutlass( EVTOutput, ThreadblockSwizzle, NumStages, - cutlass::arch::OpMultiplyAddMixedInputUpcast, + Operator, NumEVTEpilogueStages >::GemmKernel; @@ -210,7 +213,7 @@ void s8s4_linear_kernel_cutlass( }; TensorCArguments tensor_c_arguments{ [&]() -> TensorCArguments { - if constexpr (use_tensor_c) { + if constexpr (UseTensorC::value) { return {(ElementC*)tensor_c.data_ptr(), ElementC(0), {cute::_0{}, cute::_1{}, problem_size.n()}}; @@ -282,127 +285,193 @@ void s8s4_linear_kernel_cutlass( // Perform mixed datatypes GEMM operation. status = gemm_op.run(at::cuda::getCurrentCUDAStream()); CUTLASS_STATUS_CHECK(status); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); } -template< - typename ElementA, - typename ElementAScale, - typename ElementB, - typename ElementBScale, - typename ElementC, - typename ElementAccumulator, - typename ElementEpilogue, - typename ElementOutput, - bool use_tensor_c> -void -s8s4_linear_cutlass_dispatch_shapes( +template +static void select_config( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm8x = dprops->major == 8; + + if (is_sm8x) { + if constexpr (std::is_same::value && + std::is_same::value) { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + // A minimal heuristic to improve performance for small number + // of inputs cases. + if (tensor_a.size(0) <= 16) { + using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; + constexpr auto NumStages = 6; + s8s4_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, + ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else if (tensor_a.size(0) <= 32) { + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; + constexpr auto NumStages = 5; + s8s4_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, + ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + constexpr auto NumStages = 4; + s8s4_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, + ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } + return; + } + } + + TORCH_CHECK(false, + __func__, " : Operator not supported on SM", dprops->major, ".", + dprops->minor, " for given operands"); +} + +template +static void +dispatch_on_tensor_a_and_tensor_b( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, const at::Tensor& tensor_c, at::Tensor& tensor_d) { - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; - - // A minimal heuristic to improve performance for small number of - // inputs cases. - if (tensor_a.size(0) <= 16) { - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; - constexpr auto NumStages = 6; - s8s4_linear_kernel_cutlass< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, - ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else if (tensor_a.size(0) <= 32) { - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; - constexpr auto NumStages = 5; - s8s4_linear_kernel_cutlass< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, - ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else { - using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; - constexpr auto NumStages = 4; - s8s4_linear_kernel_cutlass< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, - ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( + if (tensor_a.scalar_type() == at::ScalarType::Char) { + if (tensor_b.scalar_type() == at::ScalarType::Char) { + if (tensor_a.size(1) == 2 * tensor_b.size(1)) { + using ElementA = int8_t; + using ElementB = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast; + select_config< + ElementA, ElementB, ElementAccumulator, Operator, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + } + return; + } } + + TORCH_CHECK(false, + __func__, " : Operator not supported for combination of data ", + "types ", tensor_a.scalar_type(), " for first operand and ", + tensor_b.scalar_type(), " for second operand"); } -#endif -// Perform linear operation, using corresponding CUTLASS mixed -// data-types GEMM kernel, to given arguments: -// result = (input * input_scale) @ (weight * weight_scale).T + bias -// Notes: The "input_scale" tensor is expected to be a vector, of size -// equal to number of rows of "input" tensor. The "weight_scale" -// tensor is expected to be a vector, of size equal to number of rows -// of "weight" tensor. The "bias" tensor is expected to be a vector, -// of size equal to number of rows of "weight" tensor. -at::Tensor -s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, - const at::Tensor& weight, const at::Tensor& weight_scale, - const at::Tensor& bias) { -#if defined(BUILD_S8S4_LINEAR_CUTLASS) - // For now, only CC 8.x devices are supported. - const auto dprops = at::cuda::getCurrentDeviceProperties(); - const auto is_sm8x = dprops->major == 8; - TORCH_CHECK(is_sm8x, - __func__, " : Supported only on GPUs with compute capability " - "8.x"); - - // Validate datatypes of arguments. - TORCH_CHECK(input.dtype() == at::kChar, - __func__, " : The input datatype ", input.dtype(), - " not supported"); - TORCH_CHECK(input_scale.dtype() == at::kHalf || - input_scale.dtype() == at::kBFloat16, - __func__, " : The input scale datatype ", input_scale.dtype(), - " not supported"); - TORCH_CHECK(weight.dtype() == at::kChar, " : The weight datatype ", - weight.dtype(), " not supported"); - TORCH_CHECK(weight_scale.dtype() == input_scale.dtype(), - __func__, " : Expected weight scale datatype ", - input_scale.dtype(), ", got ", weight_scale.dtype()); - if (bias.numel() > 0) { - TORCH_CHECK(bias.dtype() == input_scale.dtype(), - __func__, " : Expected bias datatype ", input_scale.dtype(), - ", got ", bias.dtype()); +template +static void +dispatch_on_tensor_c( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + if (tensor_c.numel() == 0) { + using ElementC = ElementOutput; + using UseTensorC = std::false_type; + dispatch_on_tensor_a_and_tensor_b< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } + + using UseTensorC = std::true_type; + if (tensor_c.scalar_type() == at::ScalarType::Half) { + using ElementC = cutlass::half_t; + dispatch_on_tensor_a_and_tensor_b< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) { + using ElementC = cutlass::bfloat16_t; + dispatch_on_tensor_a_and_tensor_b< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; } + TORCH_CHECK(false, + __func__, " : Operator not supported for datatype ", + tensor_c.scalar_type(), " for addend"); +} + +static void +dispatch_on_tensor_a_scale_and_tensor_b_scale( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(), + __func__, " : Operator not supported for output datatype ", + tensor_d.scalar_type(), " as it's different from the first ", + " operand scale datatype ", tensor_a_scale.scalar_type()); + + if (tensor_a_scale.scalar_type() == at::ScalarType::Half && + tensor_b_scale.scalar_type() == at::ScalarType::Half) { + using ElementAScale = cutlass::half_t; + using ElementBScale = cutlass::half_t; + using ElementOutput = cutlass::half_t; + dispatch_on_tensor_c( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + return; + } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 && + tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) { + using ElementAScale = cutlass::bfloat16_t; + using ElementBScale = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + dispatch_on_tensor_c( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + return; + } + + TORCH_CHECK(false, + __func__, " : Operator not supported for combination of data ", + "types ", tensor_a_scale.scalar_type(), + " for first operand scale and ", tensor_b_scale.scalar_type(), + " for second operand scale"); +} + +void +check_inputs( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { // Validate layouts of arguments. - TORCH_CHECK(input.dim() >= 2, - __func__, " : Expected input argument to be 2D or " - "higher-dimensional tensor, got ", input.dim(), " dims"); - TORCH_CHECK(input.layout() == at::Layout::Strided, - __func__, " : Expected input argument to be strided, got layout ", - input.layout()); - TORCH_CHECK(input_scale.dim() == input.dim() - 1, - __func__, " : Expected input scale argument to be ", - input.dim() - 1, "D tensor, got ", input_scale.dim(), " dims"); - TORCH_CHECK(input_scale.layout() == at::Layout::Strided, - __func__, " : Expected input scale argument to be strided, got " - "layout ", input_scale.layout()); - TORCH_CHECK(weight.dim() == 2, - __func__, " : Expected weight argument to be 2D tensor, got ", - weight.dim(), " dims"); - TORCH_CHECK(weight.layout() == at::Layout::Strided, - __func__, - " : Expected weight argument to be strided, got layout ", - weight.layout()); - TORCH_CHECK(weight_scale.dim() == 1 || weight_scale.dim() == 2, - __func__, " : Expected weight scale argument to be 1D or 2D ", - "tensor, got ", weight_scale.dim(), " dims"); - TORCH_CHECK(weight_scale.layout() == at::Layout::Strided, - __func__, " : Expected weight scale argument to be strided, got " - "layout ", weight_scale.layout()); + TORCH_CHECK(xq.dim() >= 2, + __func__, " : Expected xq argument to be 2D or " + "higher-dimensional tensor, got ", xq.dim(), " dims"); + TORCH_CHECK(xq.layout() == at::Layout::Strided, + __func__, " : Expected xq argument to be strided, got layout ", + xq.layout()); + TORCH_CHECK(x_scale.dim() == xq.dim() - 1, + __func__, " : Expected xq scale argument to be ", xq.dim() - 1, + "D tensor, got ", x_scale.dim(), " dims"); + TORCH_CHECK(x_scale.layout() == at::Layout::Strided, + __func__, " : Expected xq scale argument to be strided, got " + "layout ", x_scale.layout()); + TORCH_CHECK(wq.dim() == 2, + __func__, " : Expected wq argument to be 2D tensor, got ", + wq.dim(), " dims"); + TORCH_CHECK(wq.layout() == at::Layout::Strided, + __func__, " : Expected wq argument to be strided, got layout ", + wq.layout()); + TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2, + __func__, " : Expected wq scale argument to be 1D or 2D tensor, ", + "got ", w_scale.dim(), " dims"); + TORCH_CHECK(w_scale.layout() == at::Layout::Strided, + __func__, " : Expected wq scale argument to be strided, got " + "layout ", w_scale.layout()); if (bias.numel() > 0) { TORCH_CHECK(bias.dim() == 1, __func__, " : Expected bias argument to be 1D tensor, got ", @@ -412,116 +481,92 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, "layout ", bias.layout()); } - // Squash the input tensor to 2D tensor. - const auto input_sizes = input.sizes().vec(); - const auto input_2d = input.reshape({-1, input_sizes.back()}); - const auto input_scale_sizes = input_scale.sizes().vec(); - const auto input_scale_1d = input_scale.reshape({-1}); - const auto weight_scale_1d = weight_scale.reshape({-1}); - // Validate sizes of arguments. - TORCH_CHECK(input_2d.size(1) == 2 * weight.size(1), - __func__, " : Expected input argument to have ", - 2 * weight.size(1), " columns, but got ", input_2d.size(1)); - for (auto i = 0; i < input_scale_sizes.size(); ++i) - TORCH_CHECK(input_scale_sizes[i] == input_sizes[i], - __func__, " : Expected input scale argument size at position ", - i, " to be ", input_sizes[i], ", but got ", - input_scale_sizes[i]); - TORCH_CHECK(weight_scale_1d.numel() == weight.size(0), - __func__, " : Expected weight scale argument to have ", - weight.size(0), " elements, got ", weight_scale_1d.numel(), - " elements"); + const auto xq_sizes = xq.sizes().vec(); + TORCH_CHECK(xq_sizes.back() == 2 * wq.size(1), + __func__, " : Expected xq argument to have ", 2 * wq.size(1), + " columns, but got ", xq_sizes.back()); + const auto x_scale_sizes = x_scale.sizes().vec(); + for (auto i = 0; i < x_scale_sizes.size(); ++i) + TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i], + __func__, " : Expected xq scale argument size at position ", + i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]); + TORCH_CHECK(w_scale.numel() == wq.size(0), + __func__, " : Expected wq scale argument to have ", wq.size(0), + " elements, got ", w_scale.numel(), " elements"); if (bias.numel() > 0) { - TORCH_CHECK(bias.numel() == weight.size(0), - __func__, " : Expected bias argument to have ", weight.size(0), + TORCH_CHECK(bias.numel() == wq.size(0), + __func__, " : Expected bias argument to have ", wq.size(0), " elements, got ", bias.numel(), " elements"); } // Validate strides of arguments. - const auto input_2d_strides = input_2d.strides(); - TORCH_CHECK(input_2d_strides[0] >= 1 && input_2d_strides[1] == 1, - __func__, " : Expected input argument in row-major layout"); - const auto input_scale_1d_strides = input_scale_1d.strides(); - TORCH_CHECK(input_scale_1d_strides[0] == 1, - __func__, " : Expected input scale argument to be contiguous"); - const auto weight_strides = weight.strides(); - TORCH_CHECK(weight_strides[0] >= 1 && weight_strides[1] == 1, - __func__, " : Expected weight argument in row-major layout"); - const auto weight_scale_1d_strides = weight_scale_1d.strides(); - TORCH_CHECK(weight_scale_1d_strides[0] == 1, - __func__, " : Expected weight scale argument to be contiguous"); + const auto xq_strides = xq.strides(); + TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1, + __func__, " : Expected xq argument in row-major layout"); + auto xq_stride_expected = xq_strides[xq_strides.size() - 2]; + for (int i = xq_strides.size() - 3; i >= 0; --i) { + xq_stride_expected *= xq_sizes[i + 1]; + TORCH_CHECK(xq_strides[i] == xq_stride_expected, + __func__, " : Expected xq argument in row-major layout"); + } + TORCH_CHECK(x_scale.is_contiguous(), + __func__, " : Expected xq scale argument to be contiguous"); + const auto wq_strides = wq.strides(); + TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1, + __func__, " : Expected wq argument in row-major layout"); + TORCH_CHECK(w_scale.is_contiguous(), + __func__, " : Expected wq scale argument to be contiguous"); if (bias.numel() > 0) { const auto bias_strides = bias.strides(); TORCH_CHECK(bias_strides[0] == 1, __func__, " : Expected bias argument to be contiguous"); } +} +#endif + +// Perform linear operation, using corresponding CUTLASS mixed +// data-types GEMM kernel, to given arguments: +// result = (xq * x_scale) @ (wq * w_scale).T + bias +// Notes: The "x_scale" tensor is expected to be a vector, of size +// equal to number of rows of "xq" tensor. The "w_scale" tensor is +// expected to be a vector, of size equal to number of rows of "wq" +// tensor. The "bias" tensor is expected to be a vector, of size equal +// to number of rows of "wq" tensor. +at::Tensor +s8s4_linear_cutlass( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { +#if defined(BUILD_S8S4_LINEAR_CUTLASS) + // Check inputs. + check_inputs(xq, x_scale, wq, w_scale, bias); + + // Squash the input tensors as appropriate. + const auto xq_sizes = xq.sizes().vec(); + const auto xq_2d = xq.reshape({-1, xq_sizes.back()}); + const auto x_scale_sizes = x_scale.sizes().vec(); + const auto x_scale_1d = x_scale.reshape({-1}); + const auto w_scale_1d = w_scale.reshape({-1}); // Introduce alias names for arguments, according to the CUTLASS // naming conventions. - const auto& tensor_a = input_2d; - const auto& tensor_a_scale = input_scale_1d; - const auto& tensor_b = weight; - const auto& tensor_b_scale = weight_scale_1d; + const auto& tensor_a = xq_2d; + const auto& tensor_a_scale = x_scale_1d; + const auto& tensor_b = wq; + const auto& tensor_b_scale = w_scale_1d; const auto& tensor_c = bias; // Create output tensor. at::Tensor tensor_d = tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)}); - using ElementA = int8_t; - using ElementB = cutlass::int4b_t; - using ElementAccumulator = int32_t; - AT_DISPATCH_SWITCH( - input_scale.scalar_type(), - "s8s4_linear_cutlass", - AT_DISPATCH_CASE( - at::ScalarType::Half, - [&]() { - using ElementAScale = cutlass::half_t; - using ElementBScale = cutlass::half_t; - using ElementC = cutlass::half_t; - using ElementEpilogue = float; - using ElementOutput = cutlass::half_t; - if (bias.numel() > 0) { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, true>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, false>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } - }) - AT_DISPATCH_CASE( - at::ScalarType::BFloat16, - [&]() { - using ElementAScale = cutlass::bfloat16_t; - using ElementBScale = cutlass::bfloat16_t; - using ElementC = cutlass::bfloat16_t; - using ElementEpilogue = float; - using ElementOutput = cutlass::bfloat16_t; - if (bias.numel() > 0) { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, true>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, false>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } - })); - - auto tensor_d_sizes = input_sizes; - tensor_d_sizes.back() = weight.size(0); + // Dispatch to appropriate kernel template. + dispatch_on_tensor_a_scale_and_tensor_b_scale( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + + // Reshape and return output tensor. + auto tensor_d_sizes = xq_sizes; + tensor_d_sizes.back() = wq.size(0); return tensor_d.reshape(tensor_d_sizes); #else TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); From 5d1444bdef6df15eb89c4c5716ede1c5f8677798 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 21 Jan 2025 15:22:03 -0800 Subject: [PATCH 02/15] Sparsity docs update (#1590) --- docs/source/api_ref_sparsity.rst | 6 +++--- torchao/sparsity/sparse_api.py | 32 ++++++++++++++++---------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/source/api_ref_sparsity.rst b/docs/source/api_ref_sparsity.rst index 8023d0bacc..33c652390d 100644 --- a/docs/source/api_ref_sparsity.rst +++ b/docs/source/api_ref_sparsity.rst @@ -12,7 +12,7 @@ torchao.sparsity WandaSparsifier PerChannelNormObserver - apply_sparse_semi_structured apply_fake_sparsity - - + sparsify_ + semi_sparse_weight + int8_dynamic_activation_int8_semi_sparse_weight diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 3dd7971525..eb31cba619 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -43,7 +43,7 @@ def sparsify_( apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, ) -> torch.nn.Module: - """Convert the weight of linear modules in the model with `apply_tensor_subclass` + """Convert the weight of linear modules in the model with `apply_tensor_subclass`. This function is essentially the same as quantize, put for sparsity subclasses. Currently, we support three options for sparsity: @@ -54,26 +54,26 @@ def sparsify_( Args: model (torch.nn.Module): input model apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance) - filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on - the weight of the module + filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module - Example:: - import torch - import torch.nn as nn - from torchao.sparsity import sparsify_ + **Example:** + :: + import torch + import torch.nn as nn + from torchao.sparsity import sparsify_ - def filter_fn(module: nn.Module, fqn: str) -> bool: - return isinstance(module, nn.Linear) + def filter_fn(module: nn.Module, fqn: str) -> bool: + return isinstance(module, nn.Linear) - m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) + m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - # for 2:4 sparsity - from torchao.sparse_api import semi_sparse_weight - m = sparsify_(m, semi_sparse_weight(), filter_fn) + # for 2:4 sparsity + from torchao.sparse_api import semi_sparse_weight + m = sparsify_(m, semi_sparse_weight(), filter_fn) - # for int8 dynamic quantization + 2:4 sparsity - from torchao.dtypes import SemiSparseLayout - m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn) + # for int8 dynamic quantization + 2:4 sparsity + from torchao.dtypes import SemiSparseLayout + m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn) """ _replace_with_custom_fn_if_matches_filter( model, From 166a35768a60964a2415be9823d800b24ed00cf3 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Wed, 22 Jan 2025 15:46:03 -0800 Subject: [PATCH 03/15] Sparsity getting started docs (#1592) --- docs/source/index.rst | 95 +---- docs/source/sparsity.rst | 731 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 744 insertions(+), 82 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index c008c80453..3bbcd203fd 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -3,80 +3,25 @@ Welcome to the torchao Documentation `torchao `__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README `__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on: -1. API Reference -2. Developer Contribution Guide -3. Tutorials +1. Getting Started +2. Developer Notes +3. API Reference +4. Tutorials -.. - .. grid:: 3 - - .. grid-item-card:: :octicon:`file-code;1em` - Getting Started - :img-top: _static/img/card-background.svg - :link: getting-started.html - :link-type: url - - Learn about how to get started with torchao - and ts application in your projects. - - .. grid-item-card:: :octicon:`file-code;1em` - Concepts - :img-top: _static/img/card-background.svg - :link: dtypes.html - :link-type: url - - Learn about the key torchao concepts such - as dtypes, quantization, sparsity, among others. - - .. grid-item-card:: :octicon:`file-code;1em` - API Reference - :img-top: _static/img/card-background.svg - :link: api_ref_intro.html - :link-type: url - - A comprehensive reference for the torchao - API and its functionalities. - - Tutorials - ~~~~~~~~~ - - Ready to experiment? Check out some of the - torchao tutorials. - - .. customcardstart:: - - .. customcarditem:: - :header: Template Tutorial - :card_description: A placeholder template for demo purposes - :image: _static/img/generic-pytorch-logo.png - :link: tutorials/template_tutorial.html - :tags: template - - .. customcardend:: - - -.. ---------------------------------------------------------------------- -.. Below is the toctree i.e. it defines the content of the left sidebar. -.. Each of the entry below corresponds to a file.rst in docs/source/. -.. ---------------------------------------------------------------------- - -.. - .. toctree:: - :glob: - :maxdepth: 1 - :caption: Getting Started - :hidden: +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Getting Started - overview - getting-started + getting-started + sparsity - .. toctree:: - :glob: - :maxdepth: 1 - :caption: Tutorials - :hidden: +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Developer Notes - tutorials/template_tutorial + contributor_guide .. toctree:: :glob: @@ -86,15 +31,6 @@ Welcome to the torchao Documentation api_ref_dtypes api_ref_quantization api_ref_sparsity -.. - api_ref_kernel - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: Contributor Guide - - contributor_guide .. toctree:: :glob: @@ -102,4 +38,3 @@ Welcome to the torchao Documentation :caption: Tutorials serialization - diff --git a/docs/source/sparsity.rst b/docs/source/sparsity.rst index 273ee5b770..0bde173b6d 100644 --- a/docs/source/sparsity.rst +++ b/docs/source/sparsity.rst @@ -1,4 +1,731 @@ Sparsity -======== +-------- -TBA +Sparsity is the technique of removing parameters from a neural network in order to reduce its memory overhead or latency. By carefully choosing how the elements are pruned, one can achieve significant reduction in memory overhead and latency, while paying a reasonably low or no price in terms of model quality (accuracy / f1). + +Goal +==== + +We feel that the main problem current sparsity researchers / users face is fragmentation. Researchers rightfully aim to show end-to-end results, but this means a lot of time is spent figuring out how to integrate with PyTorch and implementation questions like: + + +* *When should I mask?* +* *When/how should I store the compressed representation?* +* *Do I want in-place or out-of-place mask updates?* +* *How can I call sparse matmul instead of dense?* + +We feel like the above problems can be solved once by ``torchao``\ , letting researchers focus on what really matters - pushing sparse kernel performance or more accurate pruning algorithms. + +More concretely, we hope to provide tutorials and APIs for both sparse kernels (tensor subclassing) and pruning algorithms (torch.ao.pruning.Sparsifier) that users can extend. We aim to provide modular building blocks, that can be used to accelerate not only inference but training as well, and that compose nicely with ``torchao`` quantization workflows. + + +#. Train sparse models from scratch with hardware acceleration, with minimal accuracy loss. +#. Recover accuracy loss of pruned model with custom pruning algorthim. +#. Accelerate masked/pruned models on sparsity-supported hardware to realize performance improvements. + +Design +====== + +Sparsity, like quantization, is an accuracy/performance trade-off, where we care not only about the speedup but also on the accuracy degradation of our architecture optimization technique. + +In quantization, the theoretical performance gain is generally determined by the data type that we are quantizing to - quantizing from float32 to float16 yields a theoretical 2x speedup. For pruning/sparsity, the analogous variable would be the sparsity level/ sparsity pattern. For semi-structured, the sparsity level is fixed at 50%, so we expect a theoretical 2x improvement. For block-sparse matrices and unstructured sparsity, the speedup is variable and depends on the sparsity level of the tensor. + +One key difference between sparsity and quantization is in how the accuracy degradation is determined: In general, the accuracy degradation of quantization is determined by the scale and zero_point chosen. However, in pruning the accuracy degradation is determined by the mask. Sparsity and quantization are closely related and share accuracy mitigation techniques like quantization/sparsity aware training. + +By carefully choosing the specified elements and retraining the network, pruning can achieve negligible accuracy degradation and in some cases even provide a slight accuracy gain. This is an active area of research with no agreed-upon consensus. We expect users will have a target sparsity pattern and mind and to prune to that pattern. + +Given a target sparsity pattern, pruning/sparsifying a model can then be thought of as two separate subproblems: + + +* **Accuracy** - How can I find a set of sparse weights which satisfy my target sparsity pattern that minimize the accuracy degradation of my model? +* **Perforance** - How can I accelerate my sparse weights for inference and reduce memory overhead? + +Our workflow is designed to consist of two parts that answer each question independently: + + +* a frontend python user-facing API to find sparse weights for any arbitrary sparsity pattern. +* a backend collection of sparse kernels / ops to reduce memory/latency. + +The handoff point between these two pieces are sparse weights stored in a dense format, with 0 in the place of missing elements. This is a natural handoff point because sparse matrix multiplication and dense matrix multiplication with this tensor will be numerically equivalent. This lets us present a clear contract to the user for our backend, for a given sparsity pattern: + +If you can get your dense matrix into a **2:4 sparse format**, we can speed up matrix multiplication up to **1.7x** with no numerical loss. + +This also allows users with existing sparse weights in a dense format to take advantage of our fast sparse kernels. We anticipate many users to come up with their own custom frontend masking solution or to use another third party solution, as this is an active area of research. + + +.. image:: ../static/pruning_ecosystem_diagram.png + :alt: pruning_flow + + +Below, we provide an example of accelerating a model with 2:4 sparsity + bf16 using our PyTorch APIs. + +.. code-block:: python + + import torch + from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor + from torch.ao.pruning import WeightNormSparsifier + + # bfloat16 CUDA model + model = model.half().cuda() + + # Accuracy: Finding a sparse subnetwork + sparse_config = [] + for name, mod in model.named_modules(): + if isinstance(mod, torch.nn.Linear): + sparse_config.append({"tensor_fqn": f"{name}.weight"}) + + sparsifier = WeightNormSparsifier(sparsity_level=1.0, + sparse_block_shape=(1,4), + zeros_per_block=2) + + # attach FakeSparsity + sparsifier.prepare(model, sparse_config) + sparsifier.step() + sparsifier.squash_mask() + # now we have dense model with sparse weights + + # Performance: Accelerated sparse inference + for name, mod in model.named_modules(): + if isinstance(mod, torch.nn.Linear): + mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight)) + +Fundamentally, the flow works by manipulating ``torch.Tensors``. In the frontend, we specify the tensors by their fully-qualified-name in a sparse_config dictionary. The frontend is designed to follow the quantization API, with a ``prepare`` function, which attaches FakeSparsity paramerizations to the tensors specified in the config. + +FakeSparsity is a parameterization which simulates unstructured sparsity, where each element has a mask. Because of this, we can use it to simulate any sparsity pattern we want. + +The user will then train the prepared model using their own custom code, calling ``.step()`` to update the mask if necessary. Once they’ve found a suitable mask, they call ``squash_mask()`` to fuse the mask into the weights, creating a dense tensor with 0s in the right spot. + +Users will then convert their model for accelerated sparse inference by either using the quantization flow for quantized block sparse CPU inference or by calling ``to_sparse_semi_structured`` on the specified weight tensors. + +Context +======= + +This section provides some context on neural network pruning/sparsity as well as definitions for some common pruning/sparsity terms. In academia / industry, **pruning** and **sparsity** are often used interchangeably to refer to the same thing. This can be confusing, especially since sparsity is an overloaded term that can refer to many other things, such as sparse tensor representations. + +Note that this section focuses on **pruning**, instead of **sparse training**. The distinction being that in **pruning** we start with a pretrained dense model, while during **sparse training** we train a sparse model from scratch. + +In order to avoid confusion, we generally try to use sparsity to refer to tensors. Note that a sparse tensor can refer to a dense tensor with many zero values, or a tensor stored using a sparse representation. We describe the flow as **pruning** and the resultant model as a **pruned** model. + +Roughly, the flow for achieving a more performant pruned model looks like this: + + +.. image:: ../static/pruning_flow.png + :alt: flow + + +The general idea behind pruning is that we can mask out some of the weights of a trained neural network and recover any accuracy loss. The resultant pruned model can be run on optimized kernels that take advantage of this sparsity for accelerated inference. + +Zeroing out pruned parameters doesn’t affect the latency / memory overhead of the model out of the box. This is because the dense tensor itself still contains the pruned elements (the 0 elements) and will still compute using those elements during a matrix multiply. In order to realize performance gains, we need to swap out our dense kernels for sparse kernels. + +Loosely speaking, these sparse representations allow us to skip calculations involving pruned elements in order to speed up matrix multiplication. To do this, these optimized sparse kernels work on sparse matrices that are stored in a more efficient format. Some sparse tensor layouts are tightly coupled to specific backends, like NVIDIA 2:4, while others are more general and are supported by more than one backend (CSC is supported by FBGEMM and QNNPACK). + + +.. raw:: html + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Name + Description + How the sparse matrix is stored +
COO (sparse_coo) + COOrdinate format to store sparse matrices. The matrices are stored as a combination of the non-sparse data vector and the index locations of those elements in the dense matrix. + sparse matrix = {Index: Tensor of coordinate locations, + Data: Tensor of values corresponding to index locations } +
BSR (sparse_bsr) + Block sparse row format to store sparse matrices. The matrices are stored as data blocks and the index locations of those blocks in the dense matrix. Very similar to COO, except that individual data consists of blocks, not scalars. + sparse matrix = {Index: Tensor of coordinate locations, two dimensional for a matrix, + Data: Tensor of blocks corresponding to index locations } + where a block is a matrix corresponding to the sparsity pattern. +
CSR (sparse_csr) / CSC (sparse_csc) + Compressed sparse row /column format to store sparse matrices. The sparse matrices are stored as data blocks on columns / rows and indices of those rows/columns in a dense matrix. This is the most compact format for storing block sparse matrices. + sparse_matrix = {Index: 1D tensor of column indices, + IndexPtr: 1D tensor specifying the start and end indices of columns for rows, starting from row 0, + Data: Tensor of blocks corresponding to Index locations.} +
NVIDIA 2:4 compressed representation + Custom NVIDIA compressed storage format for 2:4 semi-structured sparsity. We store the sparse matrix as a compressed dense matrix (½ the size) containing the non-pruned elements and a bitmask index. When multiplying our sparse matrix by another dense matrix, we use the mask to index into the dense matrix and multiply with our compressed dense matrix. + sparse_matrix = {Bitmask: 2bit indices of pruned elements Compressed dense matrix: contains all unpruned elements, half the size of original dense matrix} +
+ + +*Table 4.1: Overview of common sparse tensor layouts.* + +While the general idea of pruning is quite simple, there are many details that a user must figure out before they can successfully prune a model. + +These can be loosely broken down as follows: + + +* **Pruning Configuration** - What layers should I prune? What sparsity level should I prune to? +* **Pruning Criteria** - How should I decide which parameters to remove? +* **Pruning Strategy** - Once I have removed parameters, how can I recover any accuracy degradation? +* **Sparsity Pattern** - Should I try to use a specific sparsity pattern when I prune my model? Different hardware backends support accelerated inference for different sparsity patterns. + +Pruning Configuration +^^^^^^^^^^^^^^^^^^^^^ + +Not all layers in a neural network are created equal. Some layers can be more sensitive to pruning than others. The user must decide what layers to prune and also the **sparsity level** for each layer, which is the % of 0s for that weight tensor. The pruning configuration has an effect on both the accuracy and speedup of the pruned model. + +Determining the best pruning configuration and sparsity level for a given model is an open problem and a general solution does not exist. This is in part because the optimal pruning configuration is dependent on the subsequent pruning criteria and strategy, and there are an infinite number of ways to decide how to prune models and how to recover lost accuracy. + +One common method to determine which layers to prune and to what degree is to perform sensitivity analysis by pruning each layer in the model at different sparsity levels and seeing the subsequent accuracy drop (without retraining). This gives a user a sparsity-accuracy curve for each layer that the user can then use as a proxy to determine the best pruning configuration. + +Pruning Criteria +^^^^^^^^^^^^^^^^ + +A user must decide on a criteria for removing parameters from a neural network. Much like determining the best pruning configuration, determining the best pruning criteria is an open research question and is dependent on the other aforementioned factors. + +The most common pruning criteria is to use weight magnitude. The idea is that low-magnitude weights contribute less than high-magnitude weights to the model output. If we want to remove parameters, we can remove the weights that have the smallest absolute value. + +However, even with a simple pruning criteria such as weight magnitude, there are additional factors that a user would have to consider: + + +* Local vs global scope + + * **Local scope** implies that the sparsity mask is only computed with respect to the layer statistics. + + * Pros: Simple mask computing + * Cons: Potentially sub-optimal accuracy vs sparsity tradeoff. + + * **Global scope** means that the sparsity statistics are not bounded by a single layer, but can span over multiple layers if needed. + + * Pros: No need for per-layer thresholds. The tensor statistics is shared across layers, and normalization is used across layers to allow for it. + * Cons: Increased complexity when computing the masks. + +* Tensors used for mask calculation + + * **Weights**\ : Just use the weight tensor in order to calculate the mask. This method is the simplest for inference as the weight tensors are constant. + * **Gradients**\ : Compute importance based on both weights and gradient norms. Common for pre-training based methods. Currently CTR_mobile_feed uses a gradient-based pruning algorithm. + * **Activations**\ : In some research papers, the norm of the activations that are applied with the weight of interest are used to compute the importance score. + +* In place or out of place mask updates + + * **In-place** updates the sparse tensor by performing W = W (Mask). Once the weight tenosr is udpated, the sparse values are zeroed out and cannot be recovered. + + * **Pros**\ : Requires only one copy of the sparse tensor to be stored (+ mask) + * **Cons**\ : Once a mask is applied to a weight, it is zeroed out, all past history is lost. These weights cannot regrow. + + * **Out-of-place** updates don't modify the tensor directly, but perform the following: W' = W (Mask) and dW'= dW (Mask) + + * **Pros**\ : The original tensor is preserved (the masked elements are not updated via backprop). Weights can regrow if the mask changes. This is necessary for PAT. + * **Cons**\ : In addition to the unmasked weights (W), the masked weights (W’) are computed and resident in memory for forward/backward computations. + + +.. raw:: html + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ Name + Description + Notes +
Magnitude / Saliency + Remove parameters that have the lowest norm (L1 is commonly used) + Shown to work well with 2:4 semi-structured sparsity. Able to achieve identical accuracy as the original model by repeating the training loop after one-shot magnitude pruning. +
Movement Pruning + These methods aim to use gradient information in order to decide what parameters to remove. The idea is to remove parameters that do not change much during fine-tuning. + Common for pretrained models. +

+ See https://arxiv.org/abs/2005.07683 +

Low-rank factorization + These methods aim to replace Wx with SQx, where S and Q are matrices with lower rank. + Usually these methods use some sort of layer-wise reconstruction, where instead of training the model to recover lost accuracy, they seek to match layer-wise statistics (Find SQx such that L2(SQx, Wx) is minimized). +
Random + Remove parameters randomly + +
+ + +*Table 4.2: Description of some common pruning criteria.* + +Pruning Strategy +^^^^^^^^^^^^^^^^ + +This is a general term that describes the method in which a user tries to recover any accuracy degradation from their pruned model. After pruning a model, it is common to see accuracy degradation of the model, so users usually retrain the pruned model in order to remediate this. The pruning strategy also determines when and how often the model is pruned during model training. + +The line between a pruning strategy and a pruning criteria is not well defined, especially in the case of pruning aware training methods, which update the mask during training. We sometimes use the term **pruning** **algorithm** to refer to the combination of these two items. These two factors, along with the pruning configuration ultimately control the final accuracy of the pruned model. + + +.. raw:: html + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Pruning Strategy + Description + Notes +
Zero-shot + Prune once, don’t retrain the model + These methods rely on more complicated pruning criteria. +

+ This is sometimes referred to as one-shot in literature, but we will use one-shot to refer to pruning once and retraining once. +

One-shot + Prune once, retrain the model once + NVIDIA has shown that one-shot 2:4 semi-structured sparsity pruning generalizes well across a range of common vision / nlp models. \ + \ + The retraining strategy is to simply repeat the training process again. +
Iterative + Prune the model, retrain, repeat + We can iteratively increase the sparsity level, or iteratively prune different layers in the model. +
Pruning Aware Training + Mask is learned during training + Used by CTR_feed for their current pruning algorithm. +
NAS / Multimask + Multiple masks are used during training. This can be thought of a form of neural architecture search. + Used by PySpeech (FastNAS) +
Layer-wise reconstruction + Instead of retraining using a loss function, we try to recover as much information as possible from each layer by using a two model approach similar to knowledge distillation. + See https://arxiv.org/pdf/2204.09656.pdf +
+ + +*Table 4.3: Description of some common pruning strategies.* + +Sparsity Pattern +^^^^^^^^^^^^^^^^ + +A sparsity pattern describes how the pruned parameters are arranged within the model / tensor. + +Recall that in general it is necessary to use optimized sparse kernels in order to achieve performance gains. Depending on the format and the sparsity level of the weight tensor, sparse matrix multiplication can be faster than its dense counterpart. It can also be slower if a tensor is not sufficiently sparse. + +At the most general level, pruning is unstructured -every parameter has it’s own mask. This gives the most flexibility but requires very high sparsity (>98%) in order to provide performance benefits. In order to provide accelerated inference at lower sparsity levels, hardware backends have added support for special sparsity patterns. + +We seek to prune the model so that the weight tensors exhibit the same sparsity pattern as our inference backend. If we are able to recover the accuracy lost while maintaining the sparsity pattern, we can run this model on sparse hardware for accelerated inference without an accuracy penalty. We can also run a model pruned to a different sparsity pattern on our target backend, at the expense of some additional accuracy loss. + +The specific backend hardware and its corresponding sparsity pattern, as well as the pruning configuration ultimately dictates the performance speedups that we observe. If we prune a model using a different pruning criteria it will have the same performance characteristics if it follows the same sparsity pattern and sparsity level. For example, if we decided to remove the highest-magnitude weights instead of the lowest-magnitude weights, we wouldn’t expect that to change the performance characteristics of the pruned model. + + +.. raw:: html + + + + + + + + + + + + + + + + + + + + + + +
Sparsity Pattern + Mask Visualization +

+ (50% sparsity level) +

Unstructured Sparsity + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ Fig 2.3: unstructured sparsity +
1 + 0 + 1 + 1 + 0 + 1 + 0 + 1 +
0 + 0 + 1 + 1 + 1 + 1 + 1 + 0 +
1 + 0 + 0 + 0 + 1 + 0 + 1 + 0 +
0 + 1 + 1 + 0 + 0 + 0 + 0 + 1 +
+ + +
2:4 Semi-Structured + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ Fig 2.4: 2:4 semi-structured sparsity +
0 + 1 + 1 + 0 + 1 + 0 + 1 + 0 +
0 + 0 + 1 + 1 + 1 + 1 + 0 + 0 +
1 + 0 + 0 + 1 + 0 + 1 + 0 + 1 +
0 + 1 + 0 + 1 + 1 + 0 + 1 + 0 +
+ +
Block Sparsity + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ Fig 2.5: 4x4 block-wise structured sparsity +
0 + 0 + 0 + 0 + 1 + 1 + 1 + 1 +
0 + 0 + 0 + 0 + 1 + 1 + 1 + 1 +
0 + 0 + 0 + 0 + 1 + 1 + 1 + 1 +
0 + 0 + 0 + 0 + 1 + 1 + 1 + 1 +
+ +
Structured Sparsity + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ Fig 2.6: row-wise structured sparsity +
1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 +
0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 +
1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 +
0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 +
+
+ +*Table 4.4: Description of some common sparsity patterns.* + +For more information on our supported APIs and benchmaks please refer `Sparsity README `_. From 602ba86e3fbff201bc32e4e8e74b9fe89321f9e2 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 23 Jan 2025 08:09:59 -0800 Subject: [PATCH 04/15] gate sparsity tests by presence of cusparselt (#1602) Summary: I have a PyTorch build without `cuSparseLt`. Adding logic to properly skip tests which depend on this library being available. Test Plan: Local testing on an H100 without cuSparseLt: ``` pytest test/prototype/test_sparse_api.py -s ``` Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_affine_quantized.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index f08ba7aa72..8be0652e9a 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -23,6 +23,10 @@ is_sm_at_least_89, ) +is_cusparselt_available = ( + hasattr(torch.backends, "cusparselt") and torch.backends.cusparselt.is_available() +) + def get_quantization_functions( do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False @@ -91,7 +95,8 @@ def test_tensor_core_layout_transpose(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize( - "apply_quant", get_quantization_functions(True, True, "cuda", True) + "apply_quant", + get_quantization_functions(is_cusparselt_available, True, "cuda", True), ) def test_weights_only(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") @@ -168,7 +173,9 @@ def apply_uint6_weight_only_quant(linear): deregister_aqt_quantized_linear_dispatch(dispatch_condition) - @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) + @common_utils.parametrize( + "apply_quant", get_quantization_functions(is_cusparselt_available, True) + ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_print_quantized_module(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") From d0e434c8d825f7ac69e26585cb2ceb002a287f24 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Thu, 23 Jan 2025 12:43:34 -0500 Subject: [PATCH 05/15] Fix broken link on doc page (#1582) --- docs/source/_templates/layout.html | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index 6bb2207266..f1d3173de2 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -2,7 +2,7 @@ {% block sidebartitle %} {% include "searchbox.html" %} {% endblock %} @@ -22,7 +22,7 @@ // to point to the torchao repo. var overwrite = function (_) { if ($(this).length > 0) { - $(this)[0].href = "https://github.com/pytorch-labs/ao" + $(this)[0].href = "https://github.com/pytorch/ao" } } // PC From e53edaa8a0d31bfc10d5a184c0178787e1a011ac Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 23 Jan 2025 12:02:44 -0800 Subject: [PATCH 06/15] pin nightlies to 20250122 (#1608) Summary: There are test failures with the 20250123 nightly: ``` if not output_graph.export: if not self.guard_manager.check(output_graph.local_scope): reasons = get_guard_fail_reason_helper( self.guard_manager, # type: ignore[arg-type] output_graph.local_scope, CompileContext.current_compile_id(), ) > raise AssertionError(f"Guard check failed: {reasons}") E AssertionError: Guard check failed: 0/0: ___check_metadata_140011526812544_c0/0 E E E You can suppress this exception and fall back to eager by setting: E import torch._dynamo E torch._dynamo.config.suppress_errors = True /home/vasiliy/.conda/envs/pt_nightly_20241006/lib/python3.11/site-packages/torch/_dynamo/guards.py:2468: AssertionError ``` full example: https://ossci-raw-job-status.s3.amazonaws.com/log/pytorch/ao/36071578472 Pin to the previous day for now until the problem is fixed in pytorch/pytorch Test Plan: Reviewers: Subscribers: Tasks: Tags: --- .github/workflows/float8_test.yml | 4 ++-- .github/workflows/nightly_smoke_test.yml | 4 ++-- .github/workflows/regression_test.yml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/float8_test.yml b/.github/workflows/float8_test.yml index 7c9e5a4b00..b77a50ed2c 100644 --- a/.github/workflows/float8_test.yml +++ b/.github/workflows/float8_test.yml @@ -25,9 +25,9 @@ jobs: include: - name: SM-89 runs-on: linux.g6.4xlarge.experimental.nvidia.gpu - torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' + torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cu124' gpu-arch-type: "cuda" - gpu-arch-version: "12.1" + gpu-arch-version: "12.4" permissions: id-token: write diff --git a/.github/workflows/nightly_smoke_test.yml b/.github/workflows/nightly_smoke_test.yml index 18d4f41af6..57486bf58f 100644 --- a/.github/workflows/nightly_smoke_test.yml +++ b/.github/workflows/nightly_smoke_test.yml @@ -21,9 +21,9 @@ jobs: include: - name: CUDA Nightly runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' + torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cu124' gpu-arch-type: "cuda" - gpu-arch-version: "12.1" + gpu-arch-version: "12.4" permissions: id-token: write diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 19c033c4d1..14c31014c3 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -25,12 +25,12 @@ jobs: include: - name: CUDA Nightly runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu124' + torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cu124' gpu-arch-type: "cuda" gpu-arch-version: "12.4" - name: CPU Nightly runs-on: linux.4xlarge - torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu' + torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cpu' gpu-arch-type: "cpu" gpu-arch-version: "" From 52280bbb69e29ccde28b529157e313f849bd9ff0 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 23 Jan 2025 15:59:23 -0800 Subject: [PATCH 07/15] [BE] Only run docs build in CI if docs have changed (#1589) only run docs build in CI if docs have changed --- .github/workflows/doc_build.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/doc_build.yml b/.github/workflows/doc_build.yml index 19c1204e6d..d16ed0340b 100644 --- a/.github/workflows/doc_build.yml +++ b/.github/workflows/doc_build.yml @@ -9,6 +9,9 @@ on: tags: - v[0-9]+.[0-9]+.[0-9] - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + paths: + - 'docs/**' + - '!docs/**' pull_request: workflow_dispatch: From 2d4c8482d306c18796fb6d478fac2bcc410f9487 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 23 Jan 2025 16:00:48 -0800 Subject: [PATCH 08/15] [float8nocompile] Add float8nocompile CI tests which only trigger on relevant code changes (#1570) add float8nocompile CI tests --- .github/workflows/float8nocompile_test.yaml | 55 +++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 .github/workflows/float8nocompile_test.yaml diff --git a/.github/workflows/float8nocompile_test.yaml b/.github/workflows/float8nocompile_test.yaml new file mode 100644 index 0000000000..75df32a5d4 --- /dev/null +++ b/.github/workflows/float8nocompile_test.yaml @@ -0,0 +1,55 @@ +name: Run Float8nocompile Tests + +on: + push: + branches: + - main + - 'gh/**' + paths: + - 'torchao/prototype/float8nocompile/**' + - '!torchao/prototype/float8nocompile/**' + pull_request: + branches: + - main + - 'gh/**' + paths: + - 'torchao/prototype/float8nocompile/**' + - '!torchao/prototype/float8nocompile/**' + +concurrency: + group: floatnocompile_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + +jobs: + test: + strategy: + fail-fast: false + matrix: + include: + - name: SM-89 + runs-on: linux.g6.4xlarge.experimental.nvidia.gpu + torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' + gpu-arch-type: "cuda" + gpu-arch-version: "12.1" + + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + timeout: 300 + runner: ${{ matrix.runs-on }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + submodules: recursive + script: | + conda create -n venv python=3.9 -y + conda activate venv + export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH + python -m pip install --upgrade pip + pip install ${{ matrix.torch-spec }} + pip install -r dev-requirements.txt + pip install . + cd torchao/prototype/float8nocompile + pytest kernels/ --verbose -s + pytest test/train_test.py --verbose -s From 4ed93b996b0dc9abd6ac105fec7c9fa52e9a23b3 Mon Sep 17 00:00:00 2001 From: Xia Weiwen Date: Thu, 23 Jan 2025 17:47:20 -0800 Subject: [PATCH 09/15] [CPU] Fix registration of int4wo linear implementation on CPU (#1578) * [CPU] Fix registration of int4wo linear implementation on CPU * Fix format issues * Fix format issues (2) * Fix bug for 3d input * fix format issue * Remove autocast from UT --- test/quantization/test_quant_api.py | 22 +++++ torchao/dtypes/affine_quantized_tensor_ops.py | 8 ++ torchao/dtypes/uintx/int4_cpu_layout.py | 86 ++++++++++++++++++- .../dtypes/uintx/tensor_core_tiled_layout.py | 12 +-- torchao/quantization/quant_api.py | 4 +- 5 files changed, 118 insertions(+), 14 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 177c357047..caba1cf31f 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -761,6 +761,28 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + def test_int4wo_cpu(self, dtype, x_dim): + from torchao.dtypes import Int4CPULayout + + device = "cpu" + m = ToyLinearModel().eval().to(dtype).to(device) + example_inputs = m.example_inputs(dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + + with torch.no_grad(): + quantize_(m, int4_weight_only(group_size=32, layout=Int4CPULayout())) + # ensure the expected op is in the code + _, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + assert "_weight_int4pack_mm_for_cpu" in code[0] + assert "aten.mm.default" not in code[0] + class TestMultiTensorFlow(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 76df949852..ef8691699e 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -28,6 +28,10 @@ _linear_fp_act_int4_weight_gemlite_check, _linear_fp_act_int4_weight_gemlite_impl, ) +from torchao.dtypes.uintx.int4_cpu_layout import ( + _linear_fp_act_uint4_weight_cpu_check, + _linear_fp_act_uint4_weight_cpu_impl, +) from torchao.dtypes.uintx.marlin_qqq_tensor import ( _linear_int8_act_int4_weight_marlin_qqq_check, _linear_int8_act_int4_weight_marlin_qqq_impl, @@ -151,6 +155,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ), + ( + _linear_fp_act_uint4_weight_cpu_check, + _linear_fp_act_uint4_weight_cpu_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 248f7e1b94..7c734a8a44 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -2,10 +2,17 @@ from typing import Optional, Tuple import torch -from torch.utils._python_dispatch import return_and_correct_aliasing +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) -from torchao.dtypes.affine_quantized_tensor import register_layout +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device +from torchao.quantization.quant_primitives import ZeroPointDomain from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, @@ -126,7 +133,7 @@ def from_plain( zero_point = zero_point.reshape(int_data.shape[0], -1) from torchao.quantization.utils import pack_tinygemm_scales_and_zeros - scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) return cls(packed_weight, scale_and_zero, False, _layout) def to(self, *args, **kwargs): @@ -231,7 +238,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: groupsize = int(original_shape[1] / scale.shape[-2]) block_size = (1, groupsize) device = self.device - original_dtype = torch.bfloat16 + original_dtype = self.scale_and_zero.dtype target_dtype = torch.int32 quant_min = 0 quant_max = 15 @@ -261,3 +268,74 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def get_layout(self) -> Layout: return self._layout + + +def _aqt_is_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == 0 + and aqt.quant_max == 15 + ) + + +def _is_float(dtype): + return dtype in (torch.float, torch.half, torch.bfloat16) + + +def _linear_fp_act_uint4_weight_cpu_check(input_tensor, weight_tensor, bias): + return ( + TORCH_VERSION_AT_LEAST_2_6 + and is_device(input_tensor.device.type, "cpu") + and is_device(weight_tensor.device.type, "cpu") + and (bias is None or is_device(bias.device.type, "cpu")) + and not is_traceable_wrapper_subclass(input_tensor) + and _is_float(input_tensor.dtype) + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_uint4(weight_tensor) + and _is_float(weight_tensor.dtype) + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT + and isinstance(weight_tensor._layout, Int4CPULayout) + ) + + +def _linear_fp_act_uint4_weight_cpu_impl(input_tensor, weight_tensor, bias): + assert ( + TORCH_VERSION_AT_LEAST_2_6 + ), f"Requires PyTorch version at least 2.6, but got: {torch.__version__}" + assert is_device( + input_tensor.device.type, "cpu" + ), f"For CPU device only but got: {input_tensor.device}" + assert ( + weight_tensor.block_size[0] == 1 + ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + act_mat = input_tensor + packed_weight = weight_tensor.tensor_impl.packed_weight + scale_and_zero = weight_tensor.tensor_impl.scale_and_zero + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape to 2D + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + + # groupwise int4 quantization + groupsize = weight_tensor.block_size[1] + y = torch.ops.aten._weight_int4pack_mm_for_cpu( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + if bias is not None: + y += bias + return y.to(orig_dtype) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 7de869df2d..378744e7e1 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -15,7 +15,6 @@ from torchao.quantization.quant_primitives import ZeroPointDomain, _get_reduction_params from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, fill_defaults, find_multiple, ) @@ -76,14 +75,9 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): # groupwise int4 quantization groupsize = weight_tensor.block_size[1] - if is_device(input_tensor.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: - y = torch.ops.aten._weight_int4pack_mm_for_cpu( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) - else: - y = torch.ops.aten._weight_int4pack_mm( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b2eff196fd..3a73b97ad1 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -725,7 +725,9 @@ def apply_int4_weight_only_quant(weight): quant_max = 15 eps = 1e-6 preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] - zero_point_dtype = torch.bfloat16 + zero_point_dtype = ( + weight.dtype if isinstance(layout, Int4CPULayout) else torch.bfloat16 + ) nonlocal zero_point_domain assert ( From 0fae69377ea9ec7e16e2e27f489e7b8c9c992b5c Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 24 Jan 2025 10:06:25 -0800 Subject: [PATCH 10/15] Add H100 to Float8 CI for testing (#1575) --- .github/workflows/float8_test.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/float8_test.yml b/.github/workflows/float8_test.yml index b77a50ed2c..3cf2d13933 100644 --- a/.github/workflows/float8_test.yml +++ b/.github/workflows/float8_test.yml @@ -28,6 +28,11 @@ jobs: torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cu124' gpu-arch-type: "cuda" gpu-arch-version: "12.4" + - name: H100 + runs-on: linux.aws.h100 + torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124' + gpu-arch-type: "cuda" + gpu-arch-version: "12.4" permissions: id-token: write From 4e4f4df091ce50d1a97a34f156f4b667f894aac4 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 24 Jan 2025 13:43:51 -0500 Subject: [PATCH 11/15] Add quick start guide for first time users (#1611) Documentation in torchao has been pretty low-level and geared towards developers so far. This commit adds a basic quick start guide for first time users to get familiar with our main quantization flow. --- .gitignore | 2 +- docs/source/contributor_guide.rst | 2 +- docs/source/getting-started.rst | 4 - docs/source/index.rst | 17 ++-- docs/source/overview.rst | 4 - docs/source/quantization.rst | 6 +- docs/source/quick_start.rst | 136 ++++++++++++++++++++++++++++++ docs/source/sparsity.rst | 6 +- scripts/quick_start.py | 61 ++++++++++++++ 9 files changed, 213 insertions(+), 25 deletions(-) delete mode 100644 docs/source/getting-started.rst delete mode 100644 docs/source/overview.rst create mode 100644 docs/source/quick_start.rst create mode 100644 scripts/quick_start.py diff --git a/.gitignore b/.gitignore index 5fa7064cbe..726d2976f6 100644 --- a/.gitignore +++ b/.gitignore @@ -262,7 +262,7 @@ docs/dev docs/build docs/source/tutorials/* docs/source/gen_modules/* -docs/source/sg_execution_times +docs/source/sg_execution_times.rst # LevelDB files *.sst diff --git a/docs/source/contributor_guide.rst b/docs/source/contributor_guide.rst index a69c410e6c..e76b9420d0 100644 --- a/docs/source/contributor_guide.rst +++ b/docs/source/contributor_guide.rst @@ -1,4 +1,4 @@ -torchao Contributor Guide +Contributor Guide ------------------------- .. toctree:: diff --git a/docs/source/getting-started.rst b/docs/source/getting-started.rst deleted file mode 100644 index 70ac60b4a0..0000000000 --- a/docs/source/getting-started.rst +++ /dev/null @@ -1,4 +0,0 @@ -Getting Started -=============== - -TBA diff --git a/docs/source/index.rst b/docs/source/index.rst index 3bbcd203fd..04a53ce454 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,26 +1,25 @@ Welcome to the torchao Documentation -======================================= +==================================== -`torchao `__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README `__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on: - -1. Getting Started -2. Developer Notes -3. API Reference -4. Tutorials +`torchao `__ is a library for custom data types and optimizations. +Quantize and sparsify weights, gradients, optimizers, and activations for inference and training +using native PyTorch. Please checkout torchao `README `__ +for an overall introduction to the library and recent highlight and updates. .. toctree:: :glob: :maxdepth: 1 :caption: Getting Started - getting-started - sparsity + quick_start .. toctree:: :glob: :maxdepth: 1 :caption: Developer Notes + quantization + sparsity contributor_guide .. toctree:: diff --git a/docs/source/overview.rst b/docs/source/overview.rst deleted file mode 100644 index 4c6d532067..0000000000 --- a/docs/source/overview.rst +++ /dev/null @@ -1,4 +0,0 @@ -Overview -======== - -TBA diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index d96a3afc18..b5e34780b7 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -1,4 +1,4 @@ -Quantization -============ +Quantization Overview +--------------------- -TBA +Coming soon! diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst new file mode 100644 index 0000000000..fea8bb912d --- /dev/null +++ b/docs/source/quick_start.rst @@ -0,0 +1,136 @@ +Quick Start Guide +----------------- + +In this quick start guide, we will explore how to perform basic quantization using torchao. +First, install the latest stable torchao release:: + + pip install torchao + +If you prefer to use the nightly release, you can install torchao using the following +command instead:: + + pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 + +torchao is compatible with the latest 3 major versions of PyTorch, which you will also +need to install (`detailed instructions `__):: + + pip install torch + + +First Quantization Example +========================== + +The main entry point for quantization in torchao is the `quantize_ `__ API. +This function mutates your model inplace to insert the custom quantization logic based +on what the user configures. All code in this guide can be found in this `example script `__. +First, let's set up our toy model: + +.. code:: py + + import copy + import torch + + class ToyLinearModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") + + # Optional: compile model for faster inference and generation + model = torch.compile(model, mode="max-autotune", fullgraph=True) + model_bf16 = copy.deepcopy(model) + +Now we call our main quantization API to quantize the linear weights +in the model to int4 inplace. More specifically, this applies uint4 +weight-only asymmetric per-group quantization, leveraging the +`tinygemm int4mm CUDA kernel `__ +for efficient mixed dtype matrix multiplication: + +.. code:: py + + # torch 2.4+ only + from torchao.quantization import int4_weight_only, quantize_ + quantize_(model, int4_weight_only(group_size=32)) + +The quantized model is now ready to use! Note that the quantization +logic is inserted through tensor subclasses, so there is no change +to the overall model structure; only the weights tensors are updated, +but `nn.Linear` modules stay as `nn.Linear` modules: + +.. code:: py + + >>> model.linear1 + Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15)) + + >>> model.linear2 + Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15)) + +First, verify that the int4 quantized model is roughly a quarter of +the size of the original bfloat16 model: + +.. code:: py + + >>> import os + >>> torch.save(model, "/tmp/int4_model.pt") + >>> torch.save(model_bf16, "/tmp/bfloat16_model.pt") + >>> int4_model_size_mb = os.path.getsize("/tmp/int4_model.pt") / 1024 / 1024 + >>> bfloat16_model_size_mb = os.path.getsize("/tmp/bfloat16_model.pt") / 1024 / 1024 + + >>> print("int4 model size: %.2f MB" % int4_model_size_mb) + int4 model size: 1.25 MB + + >>> print("bfloat16 model size: %.2f MB" % bfloat16_model_size_mb) + bfloat16 model size: 4.00 MB + +Next, we demonstrate that not only is the quantized model smaller, +it is also much faster! + +.. code:: py + + from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + benchmark_model, + unwrap_tensor_subclass, + ) + + # Temporary workaround for tensor subclass + torch.compile + # Only needed for torch version < 2.5 + if not TORCH_VERSION_AT_LEAST_2_5: + unwrap_tensor_subclass(model) + + num_runs = 100 + torch._dynamo.reset() + example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) + bf16_time = benchmark_model(model_bf16, num_runs, example_inputs) + int4_time = benchmark_model(model, num_runs, example_inputs) + + print("bf16 mean time: %0.3f ms" % bf16_time) + print("int4 mean time: %0.3f ms" % int4_time) + print("speedup: %0.1fx" % (bf16_time / int4_time)) + +On a single A100 GPU with 80GB memory, this prints:: + + bf16 mean time: 30.393 ms + int4 mean time: 4.410 ms + speedup: 6.9x + + +Next Steps +========== + +In this quick start guide, we learned how to quantize a simple model with +torchao. To learn more about the different workflows supported in torchao, +see our main `README `__. +For a more detailed overview of quantization in torchao, visit +`this page `__. + +Finally, if you would like to contribute to torchao, don't forget to check +out our `contributor guide `__ and our list of +`good first issues `__ on Github! diff --git a/docs/source/sparsity.rst b/docs/source/sparsity.rst index 0bde173b6d..d9986a3227 100644 --- a/docs/source/sparsity.rst +++ b/docs/source/sparsity.rst @@ -1,5 +1,5 @@ -Sparsity --------- +Sparsity Overview +----------------- Sparsity is the technique of removing parameters from a neural network in order to reduce its memory overhead or latency. By carefully choosing how the elements are pruned, one can achieve significant reduction in memory overhead and latency, while paying a reasonably low or no price in terms of model quality (accuracy / f1). @@ -38,7 +38,7 @@ Given a target sparsity pattern, pruning/sparsifying a model can then be thought * **Accuracy** - How can I find a set of sparse weights which satisfy my target sparsity pattern that minimize the accuracy degradation of my model? -* **Perforance** - How can I accelerate my sparse weights for inference and reduce memory overhead? +* **Performance** - How can I accelerate my sparse weights for inference and reduce memory overhead? Our workflow is designed to consist of two parts that answer each question independently: diff --git a/scripts/quick_start.py b/scripts/quick_start.py new file mode 100644 index 0000000000..f2e195fd7e --- /dev/null +++ b/scripts/quick_start.py @@ -0,0 +1,61 @@ +import copy + +import torch + +from torchao.quantization import int4_weight_only, quantize_ +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + benchmark_model, + unwrap_tensor_subclass, +) + +# ================ +# | Set up model | +# ================ + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") + +# Optional: compile model for faster inference and generation +model = torch.compile(model, mode="max-autotune", fullgraph=True) +model_bf16 = copy.deepcopy(model) + + +# ======================== +# | torchao quantization | +# ======================== + +# torch 2.4+ only +quantize_(model, int4_weight_only(group_size=32)) + + +# ============= +# | Benchmark | +# ============= + +# Temporary workaround for tensor subclass + torch.compile +# Only needed for torch version < 2.5 +if not TORCH_VERSION_AT_LEAST_2_5: + unwrap_tensor_subclass(model) + +num_runs = 100 +torch._dynamo.reset() +example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) +bf16_time = benchmark_model(model_bf16, num_runs, example_inputs) +int4_time = benchmark_model(model, num_runs, example_inputs) + +print("bf16 mean time: %0.3f ms" % bf16_time) +print("int4 mean time: %0.3f ms" % int4_time) +print("speedup: %0.1fx" % (bf16_time / int4_time)) From 70be2452f3ae4fbd13ab61609732878baa990c84 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 24 Jan 2025 11:27:48 -0800 Subject: [PATCH 12/15] Move fpx to tensor subclass (#1603) --- torchao/dtypes/__init__.py | 6 +- torchao/dtypes/affine_quantized_tensor.py | 87 +++++-------------- torchao/dtypes/floatx/__init__.py | 4 + .../floatx/floatx_tensor_core_layout.py | 57 ++++++++++++ 4 files changed, 87 insertions(+), 67 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 9cbd4cd2a0..d043a13af9 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -4,12 +4,14 @@ to_affine_quantized_floatx, to_affine_quantized_floatx_static, # experimental, will be merged into floatx in the future - to_affine_quantized_fpx, to_affine_quantized_intx, to_affine_quantized_intx_static, ) from .floatx import ( Float8Layout, + FloatxTensor, + FloatxTensorCoreLayout, + to_affine_quantized_fpx, ) from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( @@ -52,4 +54,6 @@ "MarlinQQQLayout", "Int4CPULayout", "CutlassInt4PackedLayout", + "FloatxTensor", + "FloatxTensorCoreLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e7aca34c5f..eedca7e1cb 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -14,12 +14,9 @@ MappingType, ZeroPointDomain, choose_qparams_affine, - choose_qparams_affine_floatx, choose_qparams_and_quantize_affine_hqq, dequantize_affine, - dequantize_affine_floatx, quantize_affine, - quantize_affine_floatx, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -36,7 +33,6 @@ "to_affine_quantized_floatx", "to_affine_quantized_intx_static", "to_affine_quantized_floatx_static", - "to_affine_quantized_fpx", ] @@ -126,40 +122,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor if output_dtype is None: output_dtype = self.dtype - from torchao.dtypes.floatx import FloatxTensorCoreLayout - - if isinstance(self._layout, FloatxTensorCoreLayout): - int_data, scale = self.tensor_impl.get_plain() - return dequantize_affine_floatx( - int_data, - scale, - self._layout.ebits, - self._layout.mbits, - output_dtype=output_dtype, - ) - else: - data, scale, zero_point = self.tensor_impl.get_plain() - dq = dequantize_affine( - data, - self.block_size, - scale, - zero_point, - data.dtype, - self.quant_min, - self.quant_max, - self.zero_point_domain, - output_dtype=output_dtype, - ) - from torchao.dtypes.uintx import TensorCoreTiledLayout + data, scale, zero_point = self.tensor_impl.get_plain() + dq = dequantize_affine( + data, + self.block_size, + scale, + zero_point, + data.dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain, + output_dtype=output_dtype, + ) + from torchao.dtypes.uintx import TensorCoreTiledLayout - if isinstance(self._layout, TensorCoreTiledLayout): - # need to return to original shape if tensor was padded - # in preprocessing - # TODO: we could add an API for this if there are more use cases - # (e.g. dequant_post_process) in TensorImpl or Layout - for dim, dim_size in enumerate(self.shape): - dq = dq.narrow(dim, 0, dim_size) - return dq + if isinstance(self._layout, TensorCoreTiledLayout): + # need to return to original shape if tensor was padded + # in preprocessing + # TODO: we could add an API for this if there are more use cases + # (e.g. dequant_post_process) in TensorImpl or Layout + for dim, dim_size in enumerate(self.shape): + dq = dq.narrow(dim, 0, dim_size) + return dq def __tensor_flatten__(self): return ["tensor_impl"], [ @@ -395,33 +379,6 @@ def from_hp_to_floatx_static( f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static" ) - @classmethod - def from_hp_to_fpx( - cls, - input_float: torch.Tensor, - _layout: Layout, - ): - from torchao.dtypes.floatx import FloatxTensorCoreLayout - - assert isinstance( - _layout, FloatxTensorCoreLayout - ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" - original_shape = input_float.shape - input_float = _layout.pre_process(input_float) - # per axis quantization, where axis = 1 - block_size = list(input_float.shape) - block_size[1] = 1 - - ebits, mbits = _layout.ebits, _layout.mbits - # Note: these ops are hardcoded to have per axis quantization (axis=1) right now - scale = choose_qparams_affine_floatx(input_float, ebits, mbits) - floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) - floatx_packed = _layout.post_process(floatx_unpacked) - - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) - return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) - @property def _layout(self) -> Layout: return self.tensor_impl._layout @@ -477,8 +434,6 @@ def _apply_fn_to_data(self, fn): to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static -# experimental will be merged in to floatx -to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx if TORCH_VERSION_AT_LEAST_2_5: # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 3f0a1ccd5c..4bfaa3de9e 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,7 +1,9 @@ from .float8_layout import Float8Layout from .floatx_tensor_core_layout import ( + FloatxTensor, FloatxTensorCoreLayout, from_scaled_tc_floatx, + to_affine_quantized_fpx, to_scaled_tc_floatx, ) @@ -10,4 +12,6 @@ "to_scaled_tc_floatx", "from_scaled_tc_floatx", "Float8Layout", + "to_affine_quantized_fpx", + "FloatxTensor", ] diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index 0f67e9826e..99d07fd4e0 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -11,6 +11,7 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, + get_tensor_impl_constructor, register_layout, ) from torchao.dtypes.utils import ( @@ -22,6 +23,11 @@ _floatx_unpacked_to_f32, _n_ones, ) +from torchao.quantization.quant_primitives import ( + choose_qparams_affine_floatx, + dequantize_affine_floatx, + quantize_affine_floatx, +) aten = torch.ops.aten _ONES_TABLE = [_n_ones(i) for i in range(8)] @@ -456,6 +462,54 @@ class FloatxTensorCoreLayout(Layout): mbits: int +class FloatxTensor(AffineQuantizedTensor): + """ + Floatx quantized tensor subclass which inherits AffineQuantizedTensor class. It uses floating-point format defined by ebits (exponent bits) and mbits (mantissa bits) and supports float1 - float7 tensor types. + For details about float8 tensor type, please refer to https://github.com/pytorch/ao/blob/main/torchao/dtypes/floatx/float8_layout.py. + + To see what happens during choose_qparams_and_quantize_affine_fpx, quantization and dequantization for floatx quantization, + please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py + and check the two quant primitive ops: choose_qparams_affine_floatx, quantize_affine_floatx and dequantize_affine_floatx. + """ + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + int_data, scale = self.tensor_impl.get_plain() + return dequantize_affine_floatx( + int_data, + scale, + self._layout.ebits, + self._layout.mbits, + output_dtype=output_dtype, + ) + + @classmethod + def from_hp_to_floatx( + cls, + input_float: torch.Tensor, + _layout: Layout, + ): + assert isinstance( + _layout, FloatxTensorCoreLayout + ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + # per axis quantization, where axis = 1 + block_size = list(input_float.shape) + block_size[1] = 1 + + ebits, mbits = _layout.ebits, _layout.mbits + # Note: these ops are hardcoded to have per axis quantization (axis=1) right now + scale = choose_qparams_affine_floatx(input_float, ebits, mbits) + floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) + floatx_packed = _layout.post_process(floatx_unpacked) + + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) + return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) + + @register_layout(FloatxTensorCoreLayout) class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), @@ -657,3 +711,6 @@ def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): out += bias return out.view(*act.shape[:-1], out_dim).to(act.dtype) + + +to_affine_quantized_fpx = FloatxTensor.from_hp_to_floatx From fb335e08f1c970f3c9b1f0eb7d214cfeded7fbaf Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 24 Jan 2025 11:57:33 -0800 Subject: [PATCH 13/15] Revert "Move fpx to tensor subclass" (#1616) Revert "Move fpx to tensor subclass (#1603)" This reverts commit 70be2452f3ae4fbd13ab61609732878baa990c84. --- torchao/dtypes/__init__.py | 6 +- torchao/dtypes/affine_quantized_tensor.py | 87 ++++++++++++++----- torchao/dtypes/floatx/__init__.py | 4 - .../floatx/floatx_tensor_core_layout.py | 57 ------------ 4 files changed, 67 insertions(+), 87 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index d043a13af9..9cbd4cd2a0 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -4,14 +4,12 @@ to_affine_quantized_floatx, to_affine_quantized_floatx_static, # experimental, will be merged into floatx in the future + to_affine_quantized_fpx, to_affine_quantized_intx, to_affine_quantized_intx_static, ) from .floatx import ( Float8Layout, - FloatxTensor, - FloatxTensorCoreLayout, - to_affine_quantized_fpx, ) from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( @@ -54,6 +52,4 @@ "MarlinQQQLayout", "Int4CPULayout", "CutlassInt4PackedLayout", - "FloatxTensor", - "FloatxTensorCoreLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index eedca7e1cb..e7aca34c5f 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -14,9 +14,12 @@ MappingType, ZeroPointDomain, choose_qparams_affine, + choose_qparams_affine_floatx, choose_qparams_and_quantize_affine_hqq, dequantize_affine, + dequantize_affine_floatx, quantize_affine, + quantize_affine_floatx, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -33,6 +36,7 @@ "to_affine_quantized_floatx", "to_affine_quantized_intx_static", "to_affine_quantized_floatx_static", + "to_affine_quantized_fpx", ] @@ -122,28 +126,40 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor if output_dtype is None: output_dtype = self.dtype - data, scale, zero_point = self.tensor_impl.get_plain() - dq = dequantize_affine( - data, - self.block_size, - scale, - zero_point, - data.dtype, - self.quant_min, - self.quant_max, - self.zero_point_domain, - output_dtype=output_dtype, - ) - from torchao.dtypes.uintx import TensorCoreTiledLayout + from torchao.dtypes.floatx import FloatxTensorCoreLayout - if isinstance(self._layout, TensorCoreTiledLayout): - # need to return to original shape if tensor was padded - # in preprocessing - # TODO: we could add an API for this if there are more use cases - # (e.g. dequant_post_process) in TensorImpl or Layout - for dim, dim_size in enumerate(self.shape): - dq = dq.narrow(dim, 0, dim_size) - return dq + if isinstance(self._layout, FloatxTensorCoreLayout): + int_data, scale = self.tensor_impl.get_plain() + return dequantize_affine_floatx( + int_data, + scale, + self._layout.ebits, + self._layout.mbits, + output_dtype=output_dtype, + ) + else: + data, scale, zero_point = self.tensor_impl.get_plain() + dq = dequantize_affine( + data, + self.block_size, + scale, + zero_point, + data.dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain, + output_dtype=output_dtype, + ) + from torchao.dtypes.uintx import TensorCoreTiledLayout + + if isinstance(self._layout, TensorCoreTiledLayout): + # need to return to original shape if tensor was padded + # in preprocessing + # TODO: we could add an API for this if there are more use cases + # (e.g. dequant_post_process) in TensorImpl or Layout + for dim, dim_size in enumerate(self.shape): + dq = dq.narrow(dim, 0, dim_size) + return dq def __tensor_flatten__(self): return ["tensor_impl"], [ @@ -379,6 +395,33 @@ def from_hp_to_floatx_static( f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static" ) + @classmethod + def from_hp_to_fpx( + cls, + input_float: torch.Tensor, + _layout: Layout, + ): + from torchao.dtypes.floatx import FloatxTensorCoreLayout + + assert isinstance( + _layout, FloatxTensorCoreLayout + ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + # per axis quantization, where axis = 1 + block_size = list(input_float.shape) + block_size[1] = 1 + + ebits, mbits = _layout.ebits, _layout.mbits + # Note: these ops are hardcoded to have per axis quantization (axis=1) right now + scale = choose_qparams_affine_floatx(input_float, ebits, mbits) + floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) + floatx_packed = _layout.post_process(floatx_unpacked) + + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) + return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) + @property def _layout(self) -> Layout: return self.tensor_impl._layout @@ -434,6 +477,8 @@ def _apply_fn_to_data(self, fn): to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static +# experimental will be merged in to floatx +to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx if TORCH_VERSION_AT_LEAST_2_5: # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 4bfaa3de9e..3f0a1ccd5c 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,9 +1,7 @@ from .float8_layout import Float8Layout from .floatx_tensor_core_layout import ( - FloatxTensor, FloatxTensorCoreLayout, from_scaled_tc_floatx, - to_affine_quantized_fpx, to_scaled_tc_floatx, ) @@ -12,6 +10,4 @@ "to_scaled_tc_floatx", "from_scaled_tc_floatx", "Float8Layout", - "to_affine_quantized_fpx", - "FloatxTensor", ] diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index 99d07fd4e0..0f67e9826e 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -11,7 +11,6 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, - get_tensor_impl_constructor, register_layout, ) from torchao.dtypes.utils import ( @@ -23,11 +22,6 @@ _floatx_unpacked_to_f32, _n_ones, ) -from torchao.quantization.quant_primitives import ( - choose_qparams_affine_floatx, - dequantize_affine_floatx, - quantize_affine_floatx, -) aten = torch.ops.aten _ONES_TABLE = [_n_ones(i) for i in range(8)] @@ -462,54 +456,6 @@ class FloatxTensorCoreLayout(Layout): mbits: int -class FloatxTensor(AffineQuantizedTensor): - """ - Floatx quantized tensor subclass which inherits AffineQuantizedTensor class. It uses floating-point format defined by ebits (exponent bits) and mbits (mantissa bits) and supports float1 - float7 tensor types. - For details about float8 tensor type, please refer to https://github.com/pytorch/ao/blob/main/torchao/dtypes/floatx/float8_layout.py. - - To see what happens during choose_qparams_and_quantize_affine_fpx, quantization and dequantization for floatx quantization, - please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py - and check the two quant primitive ops: choose_qparams_affine_floatx, quantize_affine_floatx and dequantize_affine_floatx. - """ - - def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: - if output_dtype is None: - output_dtype = self.dtype - int_data, scale = self.tensor_impl.get_plain() - return dequantize_affine_floatx( - int_data, - scale, - self._layout.ebits, - self._layout.mbits, - output_dtype=output_dtype, - ) - - @classmethod - def from_hp_to_floatx( - cls, - input_float: torch.Tensor, - _layout: Layout, - ): - assert isinstance( - _layout, FloatxTensorCoreLayout - ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" - original_shape = input_float.shape - input_float = _layout.pre_process(input_float) - # per axis quantization, where axis = 1 - block_size = list(input_float.shape) - block_size[1] = 1 - - ebits, mbits = _layout.ebits, _layout.mbits - # Note: these ops are hardcoded to have per axis quantization (axis=1) right now - scale = choose_qparams_affine_floatx(input_float, ebits, mbits) - floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) - floatx_packed = _layout.post_process(floatx_unpacked) - - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) - return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) - - @register_layout(FloatxTensorCoreLayout) class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), @@ -711,6 +657,3 @@ def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): out += bias return out.view(*act.shape[:-1], out_dim).to(act.dtype) - - -to_affine_quantized_fpx = FloatxTensor.from_hp_to_floatx From 6c3bc539155145de8b5dff02b68ddade0d4e67c5 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 24 Jan 2025 12:39:48 -0800 Subject: [PATCH 14/15] Update api_ref_dtypes docs (#1610) --- docs/source/api_ref_dtypes.rst | 33 ++++++++++++++--- torchao/dtypes/affine_quantized_tensor.py | 37 ++++++++++--------- torchao/dtypes/floatx/float8_layout.py | 6 +++ .../floatx/floatx_tensor_core_layout.py | 4 +- torchao/dtypes/nf4tensor.py | 4 +- torchao/dtypes/uintx/block_sparse_layout.py | 6 +++ .../uintx/cutlass_int4_packed_layout.py | 2 + torchao/dtypes/uintx/int4_cpu_layout.py | 7 ++-- torchao/dtypes/uintx/marlin_qqq_tensor.py | 6 ++- torchao/dtypes/uintx/marlin_sparse_layout.py | 11 ++++++ torchao/dtypes/uintx/semi_sparse_layout.py | 7 ++++ .../dtypes/uintx/tensor_core_tiled_layout.py | 10 ++--- torchao/dtypes/uintx/uintx_layout.py | 11 ++++++ torchao/dtypes/utils.py | 19 +++++++--- 14 files changed, 122 insertions(+), 41 deletions(-) diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index fbe680953e..26e1266c09 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -6,19 +6,42 @@ torchao.dtypes .. currentmodule:: torchao.dtypes +Layouts and Tensor Subclasses +----------------------------- +.. autosummary:: + :toctree: generated/ + :nosignatures: + + NF4Tensor + AffineQuantizedTensor + Layout + PlainLayout + SemiSparseLayout + TensorCoreTiledLayout + Float8Layout + FloatxTensor + FloatxTensorCoreLayout + MarlinSparseLayout + BlockSparseLayout + UintxLayout + MarlinQQQTensor + MarlinQQQLayout + Int4CPULayout + CutlassInt4PackedLayout + +Quantization techniques +----------------------- .. autosummary:: :toctree: generated/ :nosignatures: - to_nf4 to_affine_quantized_intx to_affine_quantized_intx_static + to_affine_quantized_fpx to_affine_quantized_floatx to_affine_quantized_floatx_static - to_affine_quantized_fpx - NF4Tensor - AffineQuantizedTensor - + to_marlinqqq_quantized_intx + to_nf4 .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation. diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e7aca34c5f..e3ac420de7 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -44,9 +44,8 @@ # Tensor Subclass Definition # ############################## class AffineQuantizedTensor(TorchAOBaseTensor): - """ - Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: - quantized_tensor = float_tensor / scale + zero_point + """Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: + quantized_tensor = float_tensor / scale + zero_point To see what happens during choose_qparams, quantization and dequantization for affine quantization, please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py @@ -56,21 +55,18 @@ class AffineQuantizedTensor(TorchAOBaseTensor): regardless of the internal representation's type or orientation. fields: - tensor_impl (AQTTensorImpl): tensor that serves as a general tensor impl storage for the quantized data, - e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device - and operator/kernel - block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam - e.g. when size is the same as the input tensor dimension, we are using per tensor quantization - shape (torch.Size): the shape for the original high precision Tensor - quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` - quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float - if zero_point is in integer domain, zero point is added to the quantized integer value during - quantization - if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) - value during quantization - default is ZeroPointDomain.INT - dtype: dtype for original high precision tensor, e.g. torch.float32 + - tensor_impl (AQTTensorImpl): tensor that serves as a general tensor impl storage for the quantized data, + e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device and operator/kernel + - block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + - shape (torch.Size): the shape for the original high precision Tensor + - quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` + - quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` + - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) value during quantization + default is ZeroPointDomain.INT + - dtype: dtype for original high precision tensor, e.g. torch.float32 """ @staticmethod @@ -207,6 +203,7 @@ def from_hp_to_intx( _layout: Layout = PlainLayout(), use_hqq: bool = False, ): + """Convert a high precision tensor to an integer affine quantized tensor.""" original_shape = input_float.shape input_float = _layout.pre_process(input_float) @@ -302,6 +299,7 @@ def from_hp_to_intx_static( zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), ): + """Create an integer AffineQuantizedTensor from a high precision tensor using static parameters.""" if target_dtype not in FP8_TYPES: assert ( zero_point_domain is not None @@ -348,6 +346,7 @@ def from_hp_to_floatx( _layout: Layout, scale_dtype: Optional[torch.dtype] = None, ): + """Convert a high precision tensor to a float8 quantized tensor.""" if target_dtype in FP8_TYPES: return cls.from_hp_to_intx( input_float=input_float, @@ -378,6 +377,7 @@ def from_hp_to_floatx_static( target_dtype: torch.dtype, _layout: Layout, ): + """Create a float8 AffineQuantizedTensor from a high precision tensor using static parameters.""" if target_dtype in FP8_TYPES: return cls.from_hp_to_intx_static( input_float=input_float, @@ -401,6 +401,7 @@ def from_hp_to_fpx( input_float: torch.Tensor, _layout: Layout, ): + """Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is represented as ebits and mbits, and supports the representation of float1-float7.""" from torchao.dtypes.floatx import FloatxTensorCoreLayout assert isinstance( diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index dd995fb157..5a7e1924b3 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -25,6 +25,12 @@ @dataclass(frozen=True) class Float8Layout(Layout): + """Represents the layout configuration for Float8 affine quantized tensors. + + Attributes: + mm_config (Optional[Float8MMConfig]): Configuration for matrix multiplication operations involving Float8 tensors. If None, default settings are used. + """ + mm_config: Optional[Float8MMConfig] = None diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index 0f67e9826e..beaa2e536e 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -450,7 +450,9 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> # quantization api integrations @dataclass(frozen=True) class FloatxTensorCoreLayout(Layout): - """Layout type for FloatxTensorCoreAQTTensorImpl""" + """FloatxTensorCoreLayout is a data class that defines the layout for a tensor with a specific number of exponent bits (ebits) and mantissa bits (mbits). + This layout is used in the context of quantization and packing of tensors optimized for TensorCore operations. + """ ebits: int mbits: int diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 14a8c2d43e..5ae06a1fe1 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -662,10 +662,9 @@ def dequantize_scalers( ) -> torch.Tensor: """Used to unpack the double quantized scalers - Args; + Args: input_tensor: Input tensor to convert to QLoRA format this is the quantized scalers in int8 format quantization_factor: Tensor of per_scaler_block quantization factors stored in inpt_weight.dtype - size: (n_scaler_blocks) scaler_block_size: Scaler block size to use for double quantization. """ @@ -953,6 +952,7 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256): + """Convert a given tensor to normalized float 4-bit tensor.""" return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size) diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 0670986b13..6681847608 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -27,6 +27,12 @@ @dataclass(frozen=True) class BlockSparseLayout(Layout): + """BlockSparseLayout is a data class that represents the layout of a block sparse matrix. + + Attributes: + blocksize (int): The size of the blocks in the sparse matrix. Default is 64. + """ + blocksize: int = 64 diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index a6412ec88c..9c0d0bb055 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -29,6 +29,8 @@ def _aqt_is_int4(aqt): @dataclass(frozen=True) class CutlassInt4PackedLayout(Layout): + """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel.""" + pass diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 7c734a8a44..d587591ccc 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -24,15 +24,16 @@ @dataclass(frozen=True) class Int4CPULayout(Layout): - """Only for PyTorch version at least 2.6""" + """Layout class for int4 CPU layout for affine quantized tensor, used by tinygemm kernels `_weight_int4pack_mm_for_cpu`. + Only for PyTorch version at least 2.6 + """ pass @register_layout(Int4CPULayout) class Int4CPUAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only, + """TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only, used by tinygemm kernels `_weight_int4pack_mm_for_cpu` It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of dimension: [n][k / 2] (uint8 dtype) diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index b75d959b41..3a4253bb3f 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -29,8 +29,7 @@ class MarlinQQQTensor(AffineQuantizedTensor): - """ - MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. + """MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py @@ -58,6 +57,7 @@ def from_hp_to_intx( zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, _layout: Optional[Layout] = None, ): + """Converts a floating point tensor to a Marlin QQQ quantized tensor.""" original_shape = input_float.shape input_float = _layout.pre_process(input_float) nbits = int(math.log2(quant_max - quant_min + 1)) @@ -81,6 +81,8 @@ def from_hp_to_intx( @dataclass(frozen=True) class MarlinQQQLayout(Layout): + """MarlinQQQLayout is a layout class for Marlin QQQ quantization.""" + pass diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py index 2a84dd1813..22763eb0c2 100644 --- a/torchao/dtypes/uintx/marlin_sparse_layout.py +++ b/torchao/dtypes/uintx/marlin_sparse_layout.py @@ -71,6 +71,17 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b @dataclass(frozen=True) class MarlinSparseLayout(Layout): + """MarlinSparseLayout is a layout class for handling sparse tensor formats + specifically designed for the Marlin sparse kernel. This layout is used + to optimize the storage and computation of affine quantized tensors with + 2:4 sparsity patterns. + + The layout ensures that the tensor data is pre-processed and stored in a + format that is compatible with the Marlin sparse kernel operations. It + provides methods for preprocessing input tensors and managing the layout + of quantized tensors. + """ + def pre_process(self, input: torch.Tensor) -> torch.Tensor: """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py index a554fd9bc6..3c35a4d8cd 100644 --- a/torchao/dtypes/uintx/semi_sparse_layout.py +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -66,6 +66,13 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( @dataclass(frozen=True) class SemiSparseLayout(Layout): + """SemiSparseLayout is a layout class for handling semi-structured sparse + matrices in affine quantized tensors. This layout is specifically designed + to work with the 2:4 sparsity pattern, where two out of every four elements + are pruned to zero. This class provides methods for preprocessing input + tensors to conform to this sparsity pattern. + """ + def pre_process(self, input: torch.Tensor) -> torch.Tensor: # prune to 2:4 if not already temp = input.detach() diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 378744e7e1..b29c9d167b 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -91,9 +91,10 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): @dataclass(frozen=True) class TensorCoreTiledLayout(Layout): - """ - inner_k_tiles is an internal argument for packing function of tensor core tiled layout - that can affect the performance of the matmul kernel + """TensorCoreTiledLayout is a layout class for handling tensor core tiled layouts in affine quantized tensors. It provides methods for pre-processing and post-processing tensors to fit the required layout for efficient computation on tensor cores. + + Attributes: + inner_k_tiles (int): An internal argument for the packing function of tensor core tiled layout that can affect the performance of the matmul kernel. Defaults to 8. """ inner_k_tiles: int = 8 @@ -149,8 +150,7 @@ def extra_repr(self): @register_layout(TensorCoreTiledLayout) class TensorCoreTiledAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, + """TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, used by tinygemm kernels `_weight_int4pack_mm` It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of diff --git a/torchao/dtypes/uintx/uintx_layout.py b/torchao/dtypes/uintx/uintx_layout.py index 29c2ae93fe..ef85319cd5 100644 --- a/torchao/dtypes/uintx/uintx_layout.py +++ b/torchao/dtypes/uintx/uintx_layout.py @@ -209,6 +209,17 @@ def _(func, types, args, kwargs): @dataclass(frozen=True) class UintxLayout(Layout): + """A layout class for Uintx tensors, which are tensors with elements packed into + smaller bit-widths than the standard 8-bit byte. This layout is used to define + how the data is stored and processed in UintxTensor objects. + + Attributes: + dtype (torch.dtype): The data type of the tensor elements, which determines + the bit-width used for packing. + pack_dim (int): The dimension along which the data is packed. Default is -1, + which indicates the last dimension. + """ + dtype: torch.dtype pack_dim: int = -1 diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 0952b2a4bf..45a0b4312d 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -27,6 +27,15 @@ @dataclass(frozen=True) class Layout: + """The Layout class serves as a base class for defining different data layouts for tensors. + It provides methods for pre-processing and post-processing tensors, as well as static + pre-processing with additional parameters like scale, zero_point, and block_size. + + The Layout class is designed to be extended by other layout classes that define specific + data representations and behaviors for tensors. It is used in conjunction with TensorImpl + classes to represent custom data layouts and how tensors interact with different operators. + """ + def pre_process(self, input: torch.Tensor) -> torch.Tensor: return input @@ -49,13 +58,13 @@ def extra_repr(self) -> str: return "" -""" -Plain Layout, the most basic Layout, also has no extra metadata, will typically be the default -""" - - @dataclass(frozen=True) class PlainLayout(Layout): + """PlainLayout is the most basic layout class, inheriting from the Layout base class. + It does not add any additional metadata or processing steps to the tensor. + Typically, this layout is used as the default when no specific layout is required. + """ + pass From 860da263936aedc153283210f2f86573830625dd Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 24 Jan 2025 15:52:22 -0500 Subject: [PATCH 15/15] Add module swap -> tensor subclass migration tutorial (#1596) Adds a migration tutorial from module swap to tensor subclass for expressing basic quantization. This is a simplified version of the existing subclass tutorials in torchao, removing layers of indirection like Layout and TensorImpl for ease of understanding. This commit also removes overlapping content from the existing contributor guide. Work was done with @bdhirsh. --- docs/source/contributor_guide.rst | 216 +-------- docs/source/index.rst | 2 + docs/source/subclass_advanced.rst | 4 + docs/source/subclass_basic.rst | 462 ++++++++++++++++++++ tutorials/examples/logging_subclass.py | 66 +++ tutorials/examples/quantized_module_swap.py | 72 +++ tutorials/examples/quantized_subclass.py | 183 ++++++++ 7 files changed, 790 insertions(+), 215 deletions(-) create mode 100644 docs/source/subclass_advanced.rst create mode 100644 docs/source/subclass_basic.rst create mode 100644 tutorials/examples/logging_subclass.py create mode 100644 tutorials/examples/quantized_module_swap.py create mode 100644 tutorials/examples/quantized_subclass.py diff --git a/docs/source/contributor_guide.rst b/docs/source/contributor_guide.rst index e76b9420d0..7d4d20cc65 100644 --- a/docs/source/contributor_guide.rst +++ b/docs/source/contributor_guide.rst @@ -125,7 +125,7 @@ On the top of the stack will be the final quantization algorithms and quantizati For demonstration purposes, let's say after previous step we have ``AffineQuantizedTensor`` and ``to_affine_quantized`` factory function defined. For simplicity, let's say ``to_affine_quantized`` takes a high precision floating point Tensor and a target_dtype (e.g. torch.int8) and converts it to an ``AffineQuantizedTensor`` with corresponding dtype. -Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in ``Tensor Subclass Developer Guide`` section. +Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in the `Writing Your Own Tensor Subclass `__ tutorial. Weight Only Quantization ######################## @@ -257,220 +257,6 @@ During Save/Load Since ``AffineQuantizedTensor`` weight is still a ``torch.Tensor``, save/load works the same way as the original high precision floating point model. See the `serialization doc `__ for more details. -Tensor Subclass Developer Guide -=============================== - -We have covered high level overview and how everything is connected together in the previous section, this section will focus on Tensor Subclasses, which is the main extension point we rely on to provide flexibility of supporting inference, training and fine tuning with low precision Tensors and composability with torch.compile, autograd, distributed primitives in these scenarios. - -Prerequisites -~~~~~~~~~~~~~ -Some externally available resources for tensor subclasses: - -* `tensor subclass doc `__ -* `Edward's podcast about tensor subclasses `__ -* `Tensor subclass zoo `__ - -Why Tensor Subclass? -~~~~~~~~~~~~~~~~~~~~ -There are multiple ways people can implement quantization techniques or new dtypes, main motivation for us to recommend the tensor subclass based approach are three things: -(1). It’s natural for quantization to be modeled as a dtype conversion, so implementing it with tensor subclass means we are not introducing new concepts but reusing existing concepts like dtype, layout that already exists in pytorch core -(2). Since tensor subclass intercepts computation at torch function or aten ops level, as long as the same function/operator is used, we will be able to quantize the model. This allows the model that’s using variants of native modules (e.g. a slightly modified version of nn.Linear) to still be compatible with quantization -(3). Tensor subclass is also the approach adopted by other techniques like sparsity and distributed, so implementing quantization or dtype conversion with tensor subclass would make it easier for it to be composable with these techniques - -Example Code for a new DType -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Please feel free to start with `tutorial `__ for a end to end working example that combines everything we talked about together and come back to the doc for clarifications and documentations. - -Basic Structure -~~~~~~~~~~~~~~~ -A tensor subclass needs to define a few basic methods: ``__new__``, ``__init__``, ``__tensor_flatten__``, ``__tensor_unflatten__`` -and also dispatch functions for torch functions ``__torch_function__`` and aten ops ``__torch_dispatch__``. - -Here is an example of basic structure:: - # check out docs in https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L437 - from torchao.utils import TorchAOBaseTensor - - class MyDTypeLayout(TorchAOBaseTensor): - # see tutorial code for details - pass - - class MyDtypeTensor(TorchAOBaseTensor): - """We need to define `__new__` for constructing a new tensor subclass instance and `__init__` for initialize - the instance. There is no requirement on what the argument list should look like here, only requirement is - that `__new__` must return a Tensor instance with `torch.Tensor._make_wrapper_subclass(cls, shape, ...)` call - """ - @staticmethod - def __new__( - cls, - tensor_impl: MyDTypeLayout, - shape: torch.Size, - dtype: Optional[torch.dtype] = None, - ): - ... - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - tensor_impl: MyDTypeLayout, - shape: torch.Size, ... - ): - self.tensor_impl = tensor_impl - - - """`__tensor_flatten__` and `__tensor_unflatten__` are used to desugar the tensor into native Tensors/attributes and - reconstruct the tensor subclass instance from the desugared tensor and attributes, these are required to define - a Tensor subclass for torch.compile support - """ - def __tensor_flatten__(self): - return ["tensor_impl"], [self.shape] - - """see https://github.com/pytorch/pytorch/blob/3bc2004f9123a32f381ef64202252d59109507f3/torch/utils/_python_dispatch.py#L289 for documentations for outer_size and outer_stride - """ - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - tensor_impl = tensor_data_dict["tensor_impl"] - shape, = tensor_attributes - return cls( - tensor_impl, - shape if outer_size is None else outer_size, - ) - - - """classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype - """ - @classmethod - def from_float( - cls, - input_float: torch.Tensor, - ): - mapping_type = MappingType.SYMMETRIC - block_size = input_float.shape - dtype = torch.int16 - scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) - int_data = (input_float / scale).to(torch.int8) - tensor_impl = MyDTypeLayout.from_plain(int_data, scale) - return cls(tensor_impl, input_float.shape) - - - """[Optional] see docs for `Layout/Packing` under `Quantized Tensors` section to understand what layout_type is - """ - @property - def _layout(self) -> LayoutType: - return self.tensor_impl._layout - - """There are two entry points that we can modify the behavior of a pytorch op: torch_function and torch_dispatch: - - __torch_function__: will be called whenever a torch level function is called on the Tensor object, for example: torch.nn.functional.linear, - tensor.detach, tensor.reshape, tensor.t etc. - - __torch_dispatch__: will be called in the C++ dispatcher, when an aten operator is called on the Tensor object, for example: - aten.mm, aten.addmm, aten.detach.default, aten.t.default etc. - you can checkout https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L361-L389 to understand what `__torch_function__` and `__torch_dispatch__` are doing, but with `TorchAoBaseTensor` user can use - some helper functions directly (see next section) - -Operator Support -~~~~~~~~~~~~~~~~ -There are two types of operator support, torch function and aten ops. For torch functions (e.g. ``torch.nn.functional.linear``), we’ll need to overwrite ``__torch_function__`` callback in the Tensor subclass, for aten ops (e.g. ``torch.ops.aten.mm``), we’ll need to overwrite ``__torch_dispatch__`` callback function. - -For a new dtype, we’d like people to define the following decorator:: - if your dtype class is inherited from `torchao.utils.TorchAoBaseTensor`, you can do: - - implements = my_dtype_tensor_cls.implements - -And we can implement the operator dispatch with the following:: - # Example for torch_function dispatch for torch.nn.functional.linear - def _quantized_linear_op(input_tensor, weight_tensor, bias): - if isinstance(input_tensor, MyDtypeTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, MyDtypeTensor): - weight_tensor = weight_tensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - - - @implements(torch.nn.functional.linear) - def _(*args, **kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - # using try/except here so that we can have a general fallback when input_tensor/weight_tensor - # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to - # make the branches easier to understand in `_quantized_linear_op` - try: - return _quantized_linear_op(input_tensor, weight_tensor, bias) - except NotImplementedError: - if isinstance(input_tensor, MyDtypeTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, MyDtypeTensor): - weight_tensor = weight_tensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - - # Example for aten op dispatch for aten.detach.default - @implements(aten.detach.default) - def _(func, *args, **kwargs): - # `return_and_correct_aliasing` should be used by wrapper tensor ``__torch_dispatch__`` subclasses that would like to - # work with torch.compile. It ensures that the subclass properly implements the aliasing behavior of every op, - # which is needed for correctness in AOTAutograd. - - # `_apply_fn_to_data` just applies the function to the tensor data in `args[0]`, `args[0]` is a tensor subclass - # of `my_dtype` - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - -What ops do we need to overwrite? This depends on the model we are trying to quantize, commonly overwritten ops are: -``__torch_function__``: ``torch.nn.functional.linear`` -``__torch_dispatch__``: ``torch.ops.aten.addmm.default``, ``torch.ops.aten.mm.default``, ``torch.ops.aten.detach.default``, ``torch.ops.aten.t.default`` - -You can also find the ops that can be overwritten in ``__torch_function__`` or ``__torch_dispatch__`` with the following code, and you can start with a model that you want to optimize, start with just overwriting the important ops like linear, and gradually expand the coverage until the test runs and you get the expected optimized generated code (see Optimized Operators section for more details):: - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = torch.nn.Linear(10, 10) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) + x - - from torch.overrides import TorchFunctionMode - class TorchFunctionLoggingMode(TorchFunctionMode): - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - print(f"TORCH_FUNC={str(func)}") - return func(*args, **kwargs) - - with TorchFunctionLoggingMode(): - m(*example_inputs) - - ## Example output - # TORCH_FUNC= - # TORCH_FUNC= - - - from torch.utils._python_dispatch import TorchDispatchMode - class TorchDispatchLoggingMode(TorchDispatchMode): - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - print(f"ATEN_FUNC={str(func)}") - return func(*args, **kwargs) - - with TorchDispatchLoggingMode(): - m(*example_inputs) - - ## Example output - # ATEN_FUNC=aten.t.default - # ATEN_FUNC=aten.addmm.default - # ATEN_FUNC=aten.add.Tensor - - # or a more polished logging for torch_dispatch (aten) ops: https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py - -Alternatively, you can run a test example (e.g. use your quantized model with tensor parallelism, FSDP etc.) and discover the missing ops and add them until the test passes. - -We are still working on a table that talks about for each feature what are the operators that need to be supported. - Adding Efficient Kernels ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/index.rst b/docs/source/index.rst index 04a53ce454..f526c77939 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -37,3 +37,5 @@ for an overall introduction to the library and recent highlight and updates. :caption: Tutorials serialization + subclass_basic + subclass_advanced diff --git a/docs/source/subclass_advanced.rst b/docs/source/subclass_advanced.rst new file mode 100644 index 0000000000..f2df5a1cf0 --- /dev/null +++ b/docs/source/subclass_advanced.rst @@ -0,0 +1,4 @@ +Writing Your Own Quantized Tensor (advanced) +-------------------------------------------- + +Coming soon! diff --git a/docs/source/subclass_basic.rst b/docs/source/subclass_basic.rst new file mode 100644 index 0000000000..e007ea5bab --- /dev/null +++ b/docs/source/subclass_basic.rst @@ -0,0 +1,462 @@ +Writing Your Own Quantized Tensor +--------------------------------- + +Quantization in torchao is built on the foundation of tensor subclasses. +They are the main extension point for torchao to provide flexible +inference and training support using low precision computation, while +composing with important PyTorch features such as torch.compile, +autograd, and distributed primitives. + +In this tutorial, we will highlight the benefits of leveraging tensor +subclasses compared to module swaps, and walk through a simple example +of how to express quantization using this approach. + +What are Tensor Subclasses? +=========================== + +Tensor subclasses are simply classes that inherit from `torch.Tensor `__. +They allow users to interpose their custom computation logic between existing +ops in their models, such that functions in the top-level torch +namespace like torch.add will continue to work seamlessly. + +An obvious alternative to the tensor subclass approach is module swaps: +simply swap all nn.Linear modules in your model with your custom +Int8QuantizedLinear modules, for example. There are a few important +benefits of using tensor subclasses compared to this approach: + +1. **Finer-grained integration point.** Module swaps intercept + computation at the module level and so will not work for models that + rely on torch functions or variants of native modules (e.g. slightly + modified versions of nn.Linear). In contrast, since tensor subclasses + intercept computation at the function/op level, we will be able to + quantize the model as long as the same function/op is used. + +2. **Better composability.** Composing multiple features using module + swaps is clunky. For example, combining two existing + Int8QuantizedLinear and DistributedLinear modules would require users + to create another linear class that duplicates these functionalities. + Tensor subclasses bypass this problem by simply wrapping one subclass + in another. This can also offer performance benefits if the outer + tensor (e.g. `DTensor `__) + is aware that the inner tensor is quantized, and so can perform + expensive allgather operations using less network and memory + bandwidth. + +3. **Reusing PyTorch components.** It is natural to express quantization + using tensor subclasses since the quantized tensors are simply + torch.Tensors with different dtypes. The model structure does not + change (nn.Linears stay as nn.Linears), and so subsequent + optimization passes can also stay exactly the same as before. + +| +In the rest of the tutorial, we will walk through an example of how to +express quantization using both approaches. For further reading on +tensor subclasses, please refer to: + +- `Tensor subclass documentation `__ +- `Tensor subclass zoo `__ +- `Tensor subclass podcast by Edward Yang `__ + +Quantization with Module Swaps +============================== + +We begin with a simple example of how to implement int8 symmetric weight +only quantization using module swaps. All code can be found in this +`example script `__. +We will use the following function for quantizing float32 tensors into +int8 tensors: + +.. code:: py + + from typing import Tuple + import torch + + def int8_symmetric_quantize( + fp32_tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Symmetrically quantize the torch.float32 tensor into torch.int8. + Return a 2-tuple of (quantized value, scale). + + input: dimensions=[M, N], dtype=torch.float32 + output: dimensions=[M, N], dtype=torch.int8 + scale: dimensions=[M, 1], dtype=torch.float32 + """ + quant_min = -128 + quant_max = 127 + min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False) + max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = scale.view(fp32_tensor.shape[0], -1) + out = torch.round(fp32_tensor * (1.0 / scale)) + out = torch.clamp(out, quant_min, quant_max).to(torch.int8) + return out, scale + +Next, we will create a new QuantizedLinear module that calls this +function to dynamically quantize the weights: + +.. code:: py + + class QuantizedLinear(torch.nn.Linear): + """ + Linear module that performs dynamic and symmetric weight-only + int8 quantization. + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + w_int8, scale = int8_symmetric_quantize(self.weight) + return torch.matmul(x, w_int8.t().to(x.dtype)) * scale.t() + + @classmethod + def from_float(cls, mod: torch.nn.Linear): + new_linear = cls(mod.in_features, mod.out_features, mod.bias) + new_linear.weight = mod.weight + return new_linear + +Then, the only thing that’s left is to swap all `nn.Linear` modules in the +model with our new QuantizedLinear. Let’s use the following toy model +for demonstration purposes: + +.. code:: py + + import copy + + class ToyModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + float_model = ToyModel(64, 128, 32).cuda() + quantized_model = copy.deepcopy(float_model) + + # Swap torch.nn.Linear with QuantizedLinear + for name, child in quantized_model.named_children(): + if type(child) == torch.nn.Linear: + new_linear = QuantizedLinear.from_float(child) + setattr(quantized_model, name, new_linear) + +Verify that the model now uses our QuantizedLinear module. This model is +now ready to use! + +.. code:: py + + >>> print(float_model) + ToyModel( + (linear1): Linear(in_features=64, out_features=128, bias=False) + (linear2): Linear(in_features=128, out_features=32, bias=False) + ) + + >>> print(quantized_model) + ToyModel( + (linear1): QuantizedLinear(in_features=64, out_features=128, bias=False) + (linear2): QuantizedLinear(in_features=128, out_features=32, bias=False) + ) + +An important drawback of this simple approach is flexibility. Currently +this only works for native PyTorch modules, but what if the model has +slightly modified linear modules that, for example, support distributed +training? It also won’t work with models that directly call the functional +version of linear (`torch.nn.functional.linear`) instead. + +Further, suppose we want to compose this feature with distribution, +which is also implemented through module swaps. There is no clean way to +do this except to create yet another module that combines both features. +These limitations can be solved with tensor subclasses, which is a more +elegant way to interpose custom computation such as quantization in your +model. + +Quantization with Tensor Subclasses +=================================== + +Here we are going to re-implement the above quantization technique, +using a `__torch_dispatch__`-based tensor subclass. + +Tensor subclasses (which often utilize `__torch_dispatch__`) are a pretty +powerful/flexible extension point in pytorch. They serve two main +purposes as an extension point: + +1) Tensor subclasses allow you to override the **implementation** of + (almost) every PyTorch API, and are used quite a bit to implement + other PyTorch offerings +2) Tensor subclasses allow you to **couple** your tensor data with + additional metadata. A few examples + + 1) [distributed] metadata on how a tensor is sharded across ranks + (`DTensor `__, + `docs `__) + 2) [quantization] scale/zero_point metadata + (`AffineQuantizedTensor `__) + 3) [raggedness] metadata on ragged structure + (`NestedTensor `__, + `docs `__) + +Some other resources on tensor subclasses for those who are interested: + +1) \__torch_dispatch_\_ docs + (`link `__) +2) What (and why) is \__torch_dispatch_\_ + (`link `__) +3) Google collab that implements a FlopCounter and MemoryTracker using + \__torch_dispatch_\_ + (`link `__) + +With that out of the way, let’s start by defining our bare-bones tensor +subclass for symmetric quantization: + +.. code:: py + + class Int8SymmetricTensor(torch.Tensor): + """ + Our subclass represents a tensor that has been quantized to int8 + It will hold two inner tensors: + int_data: int8[M, N] + scale: fp32[M, 1] + """ + + @staticmethod + @torch._dynamo.disable + def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor): + return torch.Tensor._make_wrapper_subclass( + cls, + int_data.shape, + strides=int_data.stride(), + storage_offset=int_data.storage_offset(), + dtype=scale.dtype, + device=int_data.device, + ) + + @torch._dynamo.disable + def __init__(self, int_data: torch.Tensor, scale: torch.Tensor): + # inner data expected to be quantized already + assert int_data.dtype is torch.int8 + # we could do more work to support ndim > 2! + assert int_data.ndim == 2 + assert scale.ndim == 2 + self.int_data = int_data + self.scale = scale + + def __tensor_flatten__(self) -> Tuple[List[str], Any]: + """ + Returns a tuple of: + names of all inner tensor attributes (two in our case) + any other additional, non-tensor metadata. + + Needed for PT2 support. + """ + return ["int_data", "scale"], None + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None): + """ + __tensor_unflatten__ should effectively undo __tensor_flatten__. + + inputs: + a dict mapping names of inner tensor attributes back to the tensors + the constant metadata from __tensor_flatten__ + output: + a new instance of your subclass + + Needed for PT2 support. + """ + assert extra_metadata is None + int_data = tensor_data_dict["int_data"] + scale = tensor_data_dict["scale"] + return Int8SymmetricTensor(int_data, scale) + + def __repr__(self): + return f'Int8SymmetricTensor(int_data={repr(self.int_data)}, scale={repr(self.scale)})' + + @staticmethod + def from_float(float_tensor): + """ + Actually performs the symmetric quantization. + In our simple inference example we will quantize weights "ahead-of-time", + although later in a training example we can quantize/dequantize + during model execution, inside of our __torch_dispatch__ + + input: + float32 torch.Tensor + output: + Int8SymmetricTensor + """ + int8_tensor, scale = int8_symmetric_quantize(float_tensor) + return Int8SymmetricTensor(int8_tensor, scale) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + """ + Called for each ATen operator that our subclass is passed as an input to. + We need to define our own implementation for every operator here. + """ + if kwargs is None: + kwargs = {} + if func not in op_implementations_dict: + raise AssertionError(f'Int8SymmetricTensor does not yet support op: {str(func)}') + return op_implementations_dict[func](func, *args, **kwargs) + + + # Convenience function for registering our own implementation + # to every ATen operator in PyTorch + op_implementations_dict = {} + def register_op(ops: List[torch._ops.OpOverload]): + def impl_decorator(op_impl): + global op_implementations_dict + for op in ops: + op_implementations_dict[op] = op_impl + return op_impl + + return impl_decorator + +In the above code, we have done a few things: + +1) Defined a basic “wrapper” tensor subclass - it is effectively a + container object, that holds some inner data (in particular, two + tensors that correspond to our int8 data and scales) +2) Defined a `__torch_dispatch__` implementation, which will be called + for every ATen operator our model calls on any of our subclass inputs +3) (For PT2 support) Defined a `__tensor_flatten__`/`__tensor_unflatten__` + method. This is the largest of a few requirements we have in order for + our subclass to work with torch.compile (more on this later). It + effectively tells `torch.compile` how to “desugar” our subclass into + its inner components. +4) (For PT2 support) Added a `torch._dynamo.disable` decorator to both + constructor methods (`__new__` and `__init__`) (more on this later). + +Which operators should we implement? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +PyTorch has a pretty large operator surface. Instead of trying to give +our new tensor subclass 100% coverage, let’s just focus on the ops we +need for our toy model above. + +Which operators are called in our model though, so we know what to +implement first? The brute force way is to repeatedly run the model +to see what ops error in your subclass. A more elegant way is to log +every operator that your model sees during execution. This can be +achieved through another `LoggingTensor` subclass as in `this example `__. + +Let's implement the necessary ops below: + +.. code:: py + + from torch.utils._python_dispatch import return_and_correct_aliasing + + @register_op([torch.ops.aten.mm.default]) + def int8_mm(func, x, weight): + assert isinstance(weight, Int8SymmetricTensor), "Int8SymmetricTensor: matmul currently only supports the weight in low precision, not the input!" + return torch.mm(x, weight.int_data.to(x.dtype)) * weight.scale + + @register_op([ + torch.ops.aten.detach.default, + torch.ops.aten.t.default, + ]) + def int8_view_ops(func, *args, **kwargs): + assert isinstance(args[0], Int8SymmetricTensor) + out_data = func(args[0].int_data, *args[1:], **kwargs) + out_scale = func(args[0].scale, *args[1:], **kwargs) + out = Int8SymmetricTensor(out_data, out_scale) + return return_and_correct_aliasing(func, args, kwargs, out) + +One thing you’ll notice quickly is: our model itself consists of a few +linear layers, but we see a few operations like `aten.t` and `aten.mm` +hitting our subclass. Some background: + +- We have a number of op decompositions that live in C++, that run + “above” tensor subclasses. `linear` is one such op (the decomp + lives `here `__) +- Decompositions can be good in the sense that they shrink the size of + the API that you as a subclass author have to implement. But they can + be painful if you would rather override the “higher level” operator + than the underlying operations in its decomposition. +- If you would prefer to override some operations (like Linear) at a + higher level, you can do so using `__torch_function__` + (`example `__). + It’s worth noting that if you want autograd support, then any + overrides you perform at the `__torch_function__` layer need to be + written in a way that is differentiable, while any overrides you + perform in `__torch_dispatch__` will be automatically differentiable. + +There are a few nuances in our implementations worth pointing out: + +1) You’ll notice that we no longer had to transpose our weight / scales + inside of our mm implementation. That’s because the transposition + “already happened” before we got to the `aten.mm` op. +2) Our `aten.mm` implementation does **not** return a tensor subclass + output. In that sense, the “propagation” of our quantized subclass + ends with matmuls. This maps to the fact that our weights are in low + precision, but we need to perform the matmuls themselves in high + precision. In general, subclass authors are free to choose for which + ops their subclasses do-or-do-not propagate. If you wanted every + function in your model to be quantized (including all pointwise and + reduction operations), you could write your subclass implementation + to quantize the output of every op and always return a subclass. +3) We were able to re-use the same implementation for 4 view operations. + In general, many ops might work with a pretty generic implementation: + unwrap any subclass inputs, run the underlying operator on the inner + tensor, and wrap the output back into a subclass. + + - Whether you can always re-use an implementation, though, depends + on what you are trying to do. For example, we implemented + `transpose(dim0, dim1)` on our subclass by calling the same + transpose on our inner data and inner scale tensor. This wouldn’t + work if our scale and data tensors had a different number of + dimensions, so transposition in that case would require a custom + implementation. + + +Comparing the Outputs +===================== + +And with all of that out of the way, let’s run our model with both +versions of quantization and confirm that they give the same output! + +.. code:: py + + float_model = ToyModel(64, 128, 32).cuda() + quantized_model_module_swap = copy.deepcopy(float_model) + quantized_model_subclass = copy.deepcopy(float_model) + + # Swap torch.nn.Linear with QuantizedLinear + for name, child in quantized_model_module_swap.named_children(): + if type(child) == torch.nn.Linear: + new_linear = QuantizedLinear.from_float(child) + setattr(quantized_model_module_swap, name, new_linear) + + # Swap torch.nn.Linear weights with Int8SymmetricTensor subclasses + for name, child in quantized_model_subclass.named_children(): + if type(child) == torch.nn.Linear: + subclass_param = Int8SymmetricTensor.from_float(child.weight) + child.weight = torch.nn.Parameter(subclass_param, requires_grad=True) + + with torch.no_grad(): + x = torch.randn(64, 64, 64, device='cuda') + out_module_swap = quantized_model_module_swap(x) + out = quantized_model_subclass(x) + print(torch.allclose(out, out_module_swap)) # prints True + + # We can also use torch.compile to fuse some of our quantized logic + out_compiled = torch.compile(quantized_model_subclass)(x) + print(torch.allclose(out, out_compiled)) # prints True + + +Next Steps +========== + +In this tutorial, we demonstrated how to build a simple quantized tensor +subclass. This is part one of two tutorials in this series. The +`next post `__ will discuss how to add more advanced +features to your tensor subclass, such as making it trainable, composing +with DTensors, and adding tensor parallelism support. For a more detailed +example of how `AffineQuantizedTensor` in torchao was built using tensor +subclasses, also check out `this example `__. + +If you have any questions while implementing your subclass, feel free to +file an issue `here `__. diff --git a/tutorials/examples/logging_subclass.py b/tutorials/examples/logging_subclass.py new file mode 100644 index 0000000000..ded50c56d6 --- /dev/null +++ b/tutorials/examples/logging_subclass.py @@ -0,0 +1,66 @@ +import torch +import torch.utils._pytree as pytree + + +class LoggingTensor(torch.Tensor): + @staticmethod + def __new__(cls, a): + return torch.Tensor._make_wrapper_subclass( + cls, + a.shape, + strides=a.stride(), + storage_offset=a.storage_offset(), + dtype=a.dtype, + device=a.device, + ) + + def __init__(self, a): + self.a = a + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + print("func: " + str(func)) + # Our logging subclass trivially implements *every* pytorch op. + # It does so by: + # - unwrapping any LoggingTensor arguments + # - calling the underlying function on the inner tensors + # - wrapping any tensor outputs into LoggingTensors + args_a = pytree.tree_map_only(LoggingTensor, lambda x: x.a, args) + kwargs_a = pytree.tree_map_only(LoggingTensor, lambda x: x.a, kwargs) + out_a = func(*args_a, **kwargs_a) + out_a_flat, spec = pytree.tree_flatten(out_a) + out_flat = [ + cls(o_a) if isinstance(o_a, torch.Tensor) else o_a for o_a in out_a_flat + ] + return pytree.tree_unflatten(out_flat, spec) + + +class ToyModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +if __name__ == "__main__": + # Set up toy model + float_model = ToyModel(64, 128, 32).cuda() + + # Replace any linear layer weights with our LoggingTensor + for name, child in float_model.named_children(): + if type(child) == torch.nn.Linear: + child.weight = torch.nn.Parameter( + LoggingTensor(child.weight), requires_grad=True + ) + + # run the model + with torch.no_grad(): + x = torch.randn(64, 64, 64, device="cuda") + _ = float_model(x) diff --git a/tutorials/examples/quantized_module_swap.py b/tutorials/examples/quantized_module_swap.py new file mode 100644 index 0000000000..07281a5bca --- /dev/null +++ b/tutorials/examples/quantized_module_swap.py @@ -0,0 +1,72 @@ +from typing import Tuple + +import torch + + +def int8_symmetric_quantize( + fp32_tensor: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Symmetrically quantize the torch.float32 tensor into torch.int8. + Return a 2-tuple of (quantized value, scale). + + input: dimensions=[M, N], dtype=torch.float32 + output: dimensions=[M, N], dtype=torch.int8 + scale: dimensions=[M, 1], dtype=torch.float32 + """ + quant_min = -128 + quant_max = 127 + min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False) + max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = scale.view(fp32_tensor.shape[0], -1) + out = torch.round(fp32_tensor * (1.0 / scale)) + out = torch.clamp(out, quant_min, quant_max).to(torch.int8) + return out, scale + + +class QuantizedLinear(torch.nn.Linear): + """ + Linear module that performs dynamic and symmetric weight-only + int8 quantization. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + w_int8, scale = int8_symmetric_quantize(self.weight) + return torch.matmul(x, w_int8.t().to(x.dtype)) * scale.t() + + @classmethod + def from_float(cls, mod: torch.nn.Linear): + new_linear = cls(mod.in_features, mod.out_features, mod.bias) + new_linear.weight = mod.weight + return new_linear + + +class ToyModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +if __name__ == "__main__": + # Set up toy model + model = ToyModel(64, 128, 32).cuda() + example_inputs = torch.randn((1, 64), dtype=torch.float32, device="cuda") + + # Swap torch.nn.Linear with QuantizedLinear + for name, child in model.named_children(): + if type(child) == torch.nn.Linear: + new_linear = QuantizedLinear.from_float(child) + setattr(model, name, new_linear) + + print("quantized model: ", model) + print("output: ", model(example_inputs)) diff --git a/tutorials/examples/quantized_subclass.py b/tutorials/examples/quantized_subclass.py new file mode 100644 index 0000000000..e256068294 --- /dev/null +++ b/tutorials/examples/quantized_subclass.py @@ -0,0 +1,183 @@ +import copy +from typing import Any, List, Tuple + +import torch + + +def int8_symmetric_quantize( + fp32_tensor: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Symmetrically quantize the torch.float32 tensor into torch.int8. + Return a 2-tuple of (quantized value, scale). + """ + quant_min = -128 + quant_max = 127 + min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False) + max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = scale.view(fp32_tensor.shape[0], -1) + out = torch.round(fp32_tensor * (1.0 / scale)) + out = torch.clamp(out, quant_min, quant_max).to(torch.int8) + return out, scale + + +# Our subclass represents a tensor that has been quantized to int8 +# It will hold two inner tensors: +# - int_data: int8[M, N] +# - scale: fp32[M, 1] +class Int8SymmetricTensor(torch.Tensor): + @staticmethod + @torch._dynamo.disable + def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor): + return torch.Tensor._make_wrapper_subclass( + cls, + int_data.shape, + strides=int_data.stride(), + storage_offset=int_data.storage_offset(), + dtype=scale.dtype, + device=int_data.device, + ) + + @torch._dynamo.disable + def __init__(self, int_data: torch.Tensor, scale: torch.Tensor): + # inner data expected to be quantized already + assert int_data.dtype is torch.int8 + # we could do more work to support ndim > 2! + assert int_data.ndim == 2 + assert scale.ndim == 2 + self.int_data = int_data + self.scale = scale + + # __tensor_flatten__ returns a tuple of: + # - names of all inner tensor attributes (two in our case) + # - any other additional, non-tensor metadata. + def __tensor_flatten__(self) -> Tuple[List[str], Any]: + return ["int_data", "scale"], None + + # __tensor_unflatten__ should effectively undo __tensor_flatten__. + # inputs: + # - a dict mapping names of inner tensor attributes back to the tensors + # - the constant metadata from __tensor_flatten__ + # output: + # - a new instance of your subclass + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None + ): + assert extra_metadata is None + int_data = tensor_data_dict["int_data"] + scale = tensor_data_dict["scale"] + return Int8SymmetricTensor(int_data, scale) + + def __repr__(self): + return f"Int8SymmetricTensor(int_data={repr(self.int_data)}, scale={repr(self.scale)})" + + # Actually performs the symmetric quantization. + # In our simple inference example we will quantize weights "ahead-of-time", + # although later in a training example we can quantize/dequantize + # during model execution, inside of our __torch_dispatch__ + # input: + # - float32 torch.Tensor + # output: + # - Int8SymmetricTensor + @staticmethod + def from_float(float_tensor): + int8_tensor, scale = int8_symmetric_quantize(float_tensor) + return Int8SymmetricTensor(int8_tensor, scale) + + # __torch_dispatch__ gets called for ATen operator + # that our subclass is passed as an input to. + # We need to define our own implementation for every operator here. + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + if func not in op_implementations_dict: + raise AssertionError( + f"Int8SymmetricTensor does not yet support op: {str(func)}" + ) + return op_implementations_dict[func](func, *args, **kwargs) + + +# Convenience function for registering our own implementation +# to every ATen operator in PyTorch +op_implementations_dict = {} + + +def register_op(ops: List[torch._ops.OpOverload]): + def impl_decorator(op_impl): + global op_implementations_dict + for op in ops: + op_implementations_dict[op] = op_impl + return op_impl + + return impl_decorator + + +from torch.utils._python_dispatch import return_and_correct_aliasing + + +# matmul impl +@register_op([torch.ops.aten.mm.default]) +def int8_mm(func, x, weight): + assert isinstance( + weight, Int8SymmetricTensor + ), "Int8SymmetricTensor: matmul currently only supports the weight in low precision, not the input!" + return torch.mm(x, weight.int_data.to(x.dtype)) * weight.scale + + +# implementation of most view operations +@register_op( + [ + torch.ops.aten.detach.default, + torch.ops.aten.t.default, + torch.ops.aten.view.default, + torch.ops.aten._unsafe_view.default, + ] +) +def int8_view_ops(func, *args, **kwargs): + assert isinstance(args[0], Int8SymmetricTensor) + out_data = func(args[0].int_data, *args[1:], **kwargs) + out_scale = func(args[0].scale, *args[1:], **kwargs) + out = Int8SymmetricTensor(out_data, out_scale) + # "return_and_correct_aliasing" here is needed for torch.compile support. + # It effectively tells the compiler that the output of this view op aliases its input. + # At some point, we're hoping to infer this automatically and kill this extra API! + return return_and_correct_aliasing(func, args, kwargs, out) + + +class ToyModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +if __name__ == "__main__": + # Set up toy model + float_model = ToyModel(64, 128, 32).cuda() + quantized_model_subclass = copy.deepcopy(float_model) + + # Swap torch.nn.Linear weights with Int8SymmetricTensor subclasses + for name, child in quantized_model_subclass.named_children(): + if type(child) == torch.nn.Linear: + subclass_param = Int8SymmetricTensor.from_float(child.weight) + child.weight = torch.nn.Parameter(subclass_param, requires_grad=True) + + with torch.no_grad(): + x = torch.randn(64, 64, 64, device="cuda") + out = quantized_model_subclass(x) + + # We can also use torch.compile to fuse some of our quantized logic + # run with TORCH_LOGS="output_code" to see the generated inductor code + out_compiled = torch.compile(quantized_model_subclass)(x) + print(torch.allclose(out, out_compiled))