Skip to content

Commit

Permalink
Refactor Affine Quantized Tensor (pytorch#1234)
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva authored and sunjiweiswift committed Nov 25, 2024
1 parent 41b24e4 commit 1f33e64
Show file tree
Hide file tree
Showing 28 changed files with 2,438 additions and 2,230 deletions.
2 changes: 1 addition & 1 deletion test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_to_device(self, apply_quant):

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_register_new_dispatch(self):
-from torchao.dtypes.affine_quantized_tensor import (
+from torchao.dtypes.affine_quantized_tensor_ops import (
register_aqt_quantized_linear_dispatch,
deregister_aqt_quantized_linear_dispatch,
)
Expand Down
3 changes: 1 addition & 2 deletions test/dtypes/test_floatx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
run_tests,
)
from torchao.dtypes.floatx import (
-FloatxTensorCoreAQTTensorImpl,
FloatxTensorCoreLayout,
to_scaled_tc_floatx,
from_scaled_tc_floatx,
)
-from torchao.dtypes.floatx.floatx import _pack_tc_floatx, _pack_tc_fp6
+from torchao.dtypes.floatx.floatx_tensor_core_layout import _pack_tc_floatx, _pack_tc_fp6, FloatxTensorCoreAQTTensorImpl
from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32
from torchao.quantization import (
quantize_,
Expand Down
2 changes: 1 addition & 1 deletion test/dtypes/test_uint4.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
-from torchao.dtypes.uint4 import (
+from torchao.dtypes.uintx.uint4_layout import (
UInt4Tensor,
PerChannelSymmetricWeightUInt4Tensor,
)
Expand Down
2 changes: 1 addition & 1 deletion test/dtypes/test_uintx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

-from torchao.dtypes.uintx import to_uintx
+from torchao.dtypes.uintx.uintx_layout import to_uintx
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
Expand Down
7 changes: 1 addition & 6 deletions test/hqq/test_hqq_affine.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import unittest
import torch
-from torchao.dtypes.affine_quantized_tensor import (
-to_affine_quantized_intx,
+from torchao.quantization import (
ZeroPointDomain,
-PlainAQTTensorImpl,
-PlainLayout,
-TensorCoreTiledAQTTensorImpl,
-TensorCoreTiledLayout,
MappingType,
)

Expand Down
2 changes: 1 addition & 1 deletion test/prototype/test_sparse_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_sparse(self, compile):
quantize_(model_copy, int8_dynamic_activation_int8_weight())
reference = model_copy(input)

-from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout
+from torchao.dtypes import BlockSparseLayout

quantize_(
model,
Expand Down
34 changes: 21 additions & 13 deletions torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
+from . import affine_quantized_tensor_ops
from .affine_quantized_tensor import (
AffineQuantizedTensor,
-Float8AQTTensorImpl,
-Float8Layout,
-Layout,
-MarlinQQQLayout,
-MarlinSparseLayout,
-PlainLayout,
-SemiSparseLayout,
-TensorCoreTiledLayout,
+MarlinQQQTensor,
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
# experimental, will be merged into floatx in the future
Expand All @@ -16,15 +10,26 @@
to_affine_quantized_intx_static,
to_marlinqqq_quantized_intx,
)
+from .floatx import (
+Float8Layout,
+)
from .nf4tensor import NF4Tensor, to_nf4

-# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
-from .uint4 import UInt4Tensor
+from .uintx import (
+BlockSparseLayout,
+MarlinQQQLayout,
+MarlinSparseLayout,
+SemiSparseLayout,
+TensorCoreTiledLayout,
+UintxLayout,
+)
+from .utils import (
+Layout,
+PlainLayout,
+)

__all__ = [
"NF4Tensor",
"to_nf4",
-"UInt4Tensor",
"AffineQuantizedTensor",
"to_affine_quantized_intx",
"to_affine_quantized_intx_static",
Expand All @@ -37,7 +42,10 @@
"SemiSparseLayout",
"TensorCoreTiledLayout",
"Float8Layout",
-"Float8AQTTensorImpl",
"MarlinSparseLayout",
+"affine_quantized_tensor_ops",
+"BlockSparseLayout",
+"UintxLayout",
+"MarlinQQQTensor",
"MarlinQQQLayout",
]
Loading

0 comments on commit 1f33e64

Please sign in to comment.