
Commit

Update
[ghstack-poisoned]
danielvegamyhre committed Jan 24, 2025
1 parent fd57fdf commit c6cdbab
Showing 2 changed files with 8 additions and 7 deletions.
6 changes: 3 additions & 3 deletions test/quantization/test_quant_primitives.py
@@ -14,16 +14,16 @@
 from torchao.dtypes.utils import is_device
 from torchao.float8.float8_utils import EPS as float8_eps
 from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
     choose_qparams_affine,
     choose_qparams_affine_float8,
     dequantize_affine,
     dequantize_affine_float8,
     fake_quantize_affine,
     fake_quantize_affine_cachemask,
-    MappingType,
     quantize_affine,
     quantize_affine_float8,
-    ZeroPointDomain,
 )
 
 # TODO: remove test for utils?
@@ -34,11 +34,11 @@
     quantize_activation_per_token_absmax,
 )
 from torchao.utils import (
-    is_fbcode,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_fbcode,
 )
 
 _SEED = 1234
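The float8 primitives imported above (choose_qparams_affine_float8, quantize_affine_float8, dequantize_affine_float8) are presumably exercised as a quantize/dequantize round trip in this test file. A minimal sketch of such a round trip is shown below; only the choose_qparams_affine_float8 signature appears in this diff, so the argument names for quantize_affine_float8 and dequantize_affine_float8 are assumptions.

import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

# Hedged sketch of a float8 round trip; signatures beyond
# choose_qparams_affine_float8 are assumptions, not taken from this diff.
x = torch.randn(64, 128, dtype=torch.float32)

# Tensorwise scale for the e4m3 float8 format (signature shown in this diff).
scale = choose_qparams_affine_float8(x, float8_dtype=torch.float8_e4m3fn)

# Assumed signatures: quantize to float8, then dequantize back to float32.
x_fp8 = quantize_affine_float8(x, scale, float8_dtype=torch.float8_e4m3fn)
x_dq = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.float32)

# The round-trip error should be small relative to the tensor's magnitude.
print((x - x_dq).abs().max())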
9 changes: 5 additions & 4 deletions torchao/quantization/quant_primitives.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-from enum import auto, Enum
+from enum import Enum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -16,11 +16,11 @@
     _n_ones,
 )
 from torchao.utils import (
-    _is_float8_type,
-    _register_custom_op,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    _is_float8_type,
+    _register_custom_op,
 )
 
 __all__ = [
@@ -1306,7 +1306,8 @@ def dequantize_affine_floatx(
 
 
 def choose_qparams_affine_float8(
-    tensor: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn,
+    tensor: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
 ) -> torch.Tensor:
     """
     Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
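The docstring above describes tensorwise scaling: a single scale derived from the whole tensor. As an illustration only (not the body of choose_qparams_affine_float8, which this diff does not show), a tensorwise float8 scale is conventionally computed from the tensor's absolute maximum and the largest value representable in the target float8 dtype:

import torch

# Hedged sketch: one common way to compute a tensorwise float8 scale.
# Illustrative only; it may differ from torchao's actual implementation.
def tensorwise_float8_scale(
    tensor: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    # Largest magnitude in the high-precision tensor (one value -> tensorwise granularity).
    amax = tensor.abs().max().to(torch.float32)
    # Largest finite value representable in the float8 dtype (448.0 for e4m3fn).
    fp8_max = torch.finfo(float8_dtype).max
    # The scale maps the observed range onto the representable float8 range;
    # clamping amax away from zero avoids division by zero.
    scale = amax.clamp(min=1e-12) / fp8_max
    return scale

# Quantization would then divide by the scale and cast:
#   x_fp8 = (x / scale).clamp(-fp8_max, fp8_max).to(float8_dtype)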
