diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 7447f5afc8..77616c1c6a 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -14,16 +14,16 @@
 from torchao.dtypes.utils import is_device
 from torchao.float8.float8_utils import EPS as float8_eps
 from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
     choose_qparams_affine,
     choose_qparams_affine_float8,
     dequantize_affine,
     dequantize_affine_float8,
     fake_quantize_affine,
     fake_quantize_affine_cachemask,
-    MappingType,
     quantize_affine,
     quantize_affine_float8,
-    ZeroPointDomain,
 )
 
 # TODO: remove test for utils?
@@ -34,11 +34,11 @@
     quantize_activation_per_token_absmax,
 )
 from torchao.utils import (
-    is_fbcode,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_fbcode,
 )
 
 _SEED = 1234
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index ea372a3ceb..8b0ce28434 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-from enum import auto, Enum
+from enum import Enum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -16,11 +16,11 @@
     _n_ones,
 )
 from torchao.utils import (
-    _is_float8_type,
-    _register_custom_op,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    _is_float8_type,
+    _register_custom_op,
 )
 
 __all__ = [
@@ -1306,7 +1306,8 @@ def dequantize_affine_floatx(
 
 
 def choose_qparams_affine_float8(
-    tensor: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn,
+    tensor: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
 ) -> torch.Tensor:
     """
     Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
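Note on the final hunk: `choose_qparams_affine_float8` computes a tensorwise float8 scale for a high-precision tensor. The diff only reformats its signature, so as context, here is a minimal sketch of how such a tensorwise scale is typically derived; the helper name `_tensorwise_float8_scale` and the exact eps clamping are illustrative assumptions, not the torchao implementation.

```python
import torch

def _tensorwise_float8_scale(
    tensor: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    # Hypothetical sketch: map the tensor's absolute maximum onto the largest
    # finite value representable in the target float8 dtype (448.0 for e4m3fn).
    amax = tensor.abs().max().to(torch.float32)
    max_repr = torch.finfo(float8_dtype).max
    # Clamp amax so an all-zero tensor does not produce a zero scale.
    return torch.clamp(amax, min=torch.finfo(torch.float32).eps) / max_repr

# Usage: divide by the scale, saturate to the representable range, then cast.
x = torch.randn(64, 64)
scale = _tensorwise_float8_scale(x)
x_f8 = (x / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
```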