remove prototype import
jerryzh168 committed Nov 20, 2024
1 parent ae08ed7 · commit be822ae
Showing 7 changed files with 36 additions and 30 deletions.
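The change drops the experimental _autoquant_v2 re-exports from the stable torchao and torchao.quantization namespaces; callers now import the API from the prototype package directly. A minimal before/after sketch of the import change (both paths are taken from the diffs below):

# before this commit
from torchao.quantization import _autoquant_v2

# after this commit
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2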
2 changes: 0 additions & 2 deletions torchao/__init__.py
@@ -32,7 +32,6 @@
 
 from torchao.quantization import (
     autoquant,
-    _autoquant_v2,
     quantize_,
 )
 from . import dtypes
@@ -41,7 +40,6 @@
 __all__ = [
     "dtypes",
     "autoquant",
-    "_autoquant_v2",
     "quantize_",
     "testing",
 ]
8 changes: 4 additions & 4 deletions torchao/_models/llama/generate.py
@@ -208,7 +208,6 @@ def main(
     from torchao.quantization import (
         quantize_,
         autoquant,
-        _autoquant_v2,
         int8_weight_only,
         int8_dynamic_activation_int8_weight,
         int4_weight_only,
@@ -218,6 +217,7 @@ def main(
         float8_weight_only,
         float8_dynamic_activation_float8_weight,
     )
+    from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
     from torchao.utils import unwrap_tensor_subclass
 
     from torchao.quantization.granularity import PerTensor, PerRow
@@ -330,11 +330,11 @@ def main(
         )
 
         if "autoquant_v2-int4" == quantization:
-            model = _autoquant_v2(model, manual=True, qtensor_class_list = torchao.quantization.V2_DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
+            model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
         elif "autoquant_v2-float8" == quantization:
-            model = _autoquant_v2(model, manual=True, qtensor_class_list = torchao.quantization.V2_OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
+            model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
         else:
-            model = _autoquant_v2(model, manual=True, example_input=inputs)
+            model = autoquant_v2(model, manual=True, example_input=inputs)
 
     print("running generate")
     generate(
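For reference, a minimal sketch of the updated generate.py call pattern, assuming model and inputs are defined as above; the class list is imported from the prototype module (its names are in that module's __all__) rather than referenced through the torchao.quantization namespace:

from torchao.prototype.quantization.autoquant_v2 import (
    DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
    autoquant_v2,
)

# manual=True defers quantization until the model has seen example inputs
model = autoquant_v2(
    model,
    manual=True,
    qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
    example_input=inputs,
)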
8 changes: 4 additions & 4 deletions torchao/_models/sam/eval_combo.py
@@ -15,8 +15,8 @@
     int8_dynamic_activation_int8_weight,
     int4_weight_only,
     autoquant,
-    _autoquant_v2,
 )
+from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
 from torchao.sparsity import sparsify_, apply_fake_sparsity, semi_sparse_weight
 from torchao.dtypes import SemiSparseLayout, MarlinSparseLayout
 from torchao.utils import unwrap_tensor_subclass
@@ -347,11 +347,11 @@ def mlp_only(mod, name):
    elif compress is not None and "autoquant_v2" in compress:
        example_input = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device=device)
        if "autoquant_v2-int4" == compress:
-            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.V2_DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
+            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
        elif "autoquant_v2-float8" == compress:
-            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.V2_OTHER_AUTOQUANT_CLASS_LIST)
+            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST)
        else:
-            _autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True)
+            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True)
 
        predictor.model.image_encoder(example_input)
        predictor.model.image_encoder.finalize_autoquant()
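The eval_combo.py hunk also shows the full manual flow the prototype API expects: wrap the module, run it once on a representative input to profile each linear, then finalize. A condensed sketch, assuming encoder is any torch.nn.Module on a CUDA device (the variable name is illustrative):

import torch
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

example_input = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device="cuda")
autoquant_v2(encoder, example_input=example_input, manual=True)  # swap weights to autoquantizable subclasses
encoder(example_input)            # profiling pass over the example input
encoder.finalize_autoquant()      # pick the best qtensor class per linear and quantize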
29 changes: 26 additions & 3 deletions torchao/prototype/quantization/autoquant_v2.py
@@ -46,6 +46,10 @@
     debug_linears_for_float8,
     prepare_target_folder,
 )
+from torchao.quantization.subclass import QuantizedLinearWeightBase
+from torchao.quantization.autoquant import AutoQuantizableLinearWeight as AutoQuantizableLinearWeightV1
+from torchao.dtypes import AffineQuantizedTensor
+from torchao.quantization import LinearActivationQuantizedTensor
 
 logging.basicConfig(level=logging.ERROR)  # Set the root logger level to ERROR
 
@@ -61,8 +65,28 @@
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
+    "_is_linear",
 ]
 
+def _is_linear(mod, *args):
+    # avoid circular dependencies
+    from torchao.quantization.qat.affine_fake_quantized_tensor import (
+        AffineFakeQuantizedTensor,
+    )
+
+    # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
+    # when it is shared by multiple linear modules
+    return (
+        isinstance(mod, torch.nn.Linear)
+        and hasattr(mod, "weight")
+        and not isinstance(mod.weight, QuantizedLinearWeightBase)
+        and not isinstance(mod.weight, AutoQuantizableLinearWeightV1)
+        and not isinstance(mod.weight, AffineQuantizedTensor)
+        and not isinstance(mod.weight, LinearActivationQuantizedTensor)
+        and not isinstance(mod.weight, AffineFakeQuantizedTensor)
+        and not isinstance(mod, torch.nn.modules.linear.NonDynamicallyQuantizableLinear)
+    )
+
 
 # TODO: use SubgraphMatcher
 def _graph_equals(g1, g2):
@@ -88,8 +112,7 @@ def _graph_equals(g1, g2):
 # This is a flag to control whether we do some rewrite for graph
 # to account for different batch sizes, it's a temporary solution for llama model
 # we'll need to think about how to support this more generally
-LLAMA = False
-
+LLAMA = True
 
 def check_cache(gm, cls, shapes_and_dtype):
     for gm_, cls_, shapes_and_dtype_ in AUTOQUANT_CACHE.keys():
@@ -981,7 +1004,7 @@ def _change_linears_to_autoquantizable(
     AutoQuantizableLinearWeight tensor subclass. Expectation is that this is followed
     by running the model and then calling _change_autoquantizable_to_quantized
     """
-    from torchao.quantization.quant_api import _is_linear
+    # from torchao.quantization.quant_api import _is_linear
 
     filter_fn = kwargs.pop("filter_fn", _is_linear)
     _ = kwargs.pop(
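The _is_linear helper vendored into autoquant_v2.py above is the default filter_fn for _change_linears_to_autoquantizable: it matches plain nn.Linear modules whose weights have not already been swapped for a quantized or autoquantizable tensor subclass, which keeps a weight shared by several linears from being quantized twice. A small sketch of the filter's behavior, assuming it is imported from the prototype module (it is now in that module's __all__):

import torch
from torchao.prototype.quantization.autoquant_v2 import _is_linear

print(_is_linear(torch.nn.Linear(16, 16)))   # True: plain linear, plain weight

# MultiheadAttention's out_proj is a NonDynamicallyQuantizableLinear,
# which the last clause of the filter excludes
attn = torch.nn.MultiheadAttention(16, 4)
print(_is_linear(attn.out_proj))             # False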
@@ -67,15 +67,15 @@ def maybe_short_name(torch_fn):
 
 
 def get_meta_val(n: torch.fx.Node):
-    # from https://fburl.com/code/hcwdl994
+    # from https://github.com/pytorch/pytorch/blob/8d708090c0eb306facfd8f85d58c578a8cbbe689/torch/fx/graph.py#L644-L647
     meta_val = n.meta.get(
         "val", n.meta.get("tensor_meta", n.meta.get("example_value", None))
     )
     return meta_val
 
 
 def get_stack_summary(n: torch.fx.Node):
-    # from https://fburl.com/code/yify7y7f
+    # from https://github.com/pytorch/pytorch/blob/8d708090c0eb306facfd8f85d58c578a8cbbe689/torch/fx/graph.py#L609
     if n.stack_trace:
         parsed_stack_trace = torch.fx.graph._parse_stack_trace(n.stack_trace)
         summary = parsed_stack_trace.get_summary_str()
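The get_meta_val helper above reads the same node metadata that FX shape propagation populates. A toy sketch of the lookup it performs, assuming a module traced with torch.fx and run through ShapeProp (the Toy module is illustrative):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = symbolic_trace(Toy())
ShapeProp(gm).propagate(torch.randn(2, 3))  # fills n.meta["tensor_meta"] for each node
for n in gm.graph.nodes:
    # same precedence as get_meta_val: "val", then "tensor_meta", then "example_value"
    meta = n.meta.get("val", n.meta.get("tensor_meta", n.meta.get("example_value", None)))
    print(n.op, n.name, meta)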
11 changes: 0 additions & 11 deletions torchao/quantization/__init__.py
@@ -15,12 +15,6 @@
     OTHER_AUTOQUANT_CLASS_LIST,
     autoquant,
 )
-from torchao.prototype.quantization.autoquant_v2 import (
-    DEFAULT_AUTOQUANT_CLASS_LIST as V2_DEFAULT_AUTOQUANT_CLASS_LIST,
-    DEFAULT_INT4_AUTOQUANT_CLASS_LIST as V2_DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
-    OTHER_AUTOQUANT_CLASS_LIST as V2_OTHER_AUTOQUANT_CLASS_LIST,
-    autoquant_v2 as _autoquant_v2,
-)
 from .GPTQ import (
     Int4WeightOnlyGPTQQuantizer,
     Int4WeightOnlyQuantizer,
@@ -96,11 +90,6 @@
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
-    # experimental api
-    "_autoquant_v2",
-    "V2_DEFAULT_AUTOQUANT_CLASS_LIST",
-    "V2_DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
-    "V2_OTHER_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
     "int8_dynamic_activation_int4_weight",
4 changes: 0 additions & 4 deletions torchao/quantization/quant_api.py
@@ -54,8 +54,6 @@
 )
 
 from .autoquant import AutoQuantizableLinearWeight, autoquant
-from torchao.prototype.quantization.autoquant_v2 import AutoQuantizableLinearWeight as AutoQuantizableLinearWeightV2
-from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
 from .GPTQ import (
     Int4WeightOnlyGPTQQuantizer,
     Int4WeightOnlyQuantizer,
@@ -92,7 +90,6 @@
     "Int4WeightOnlyGPTQQuantizer",
     "Int4WeightOnlyQuantizer",
     "autoquant",
-    "autoquant_v2",
     "_get_subclass_inserter",
     "quantize_",
     "int8_dynamic_activation_int4_weight",
@@ -252,7 +249,6 @@ def _is_linear(mod, *args):
         and hasattr(mod, "weight")
         and not isinstance(mod.weight, QuantizedLinearWeightBase)
         and not isinstance(mod.weight, AutoQuantizableLinearWeight)
-        and not isinstance(mod.weight, AutoQuantizableLinearWeightV2)
         and not isinstance(mod.weight, AffineQuantizedTensor)
         and not isinstance(mod.weight, LinearActivationQuantizedTensor)
         and not isinstance(mod.weight, AffineFakeQuantizedTensor)
