Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat]: Add support for kleidiai quantization schemes #1447

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

ng-05
Copy link

@ng-05 ng-05 commented Dec 19, 2024

Description:

  1. Allow Int4WeightOnlyQuantizer to work with channelwise and groupwise symmetric quantization schemes
  2. KleidiAI supports channelwise and 32 groupwise quantized matmul kernels

Needs : pytorch/pytorch#134124

Copy link

pytorch-bot bot commented Dec 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1447

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link

Hi @ng-05!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@ng-05 ng-05 marked this pull request as draft December 19, 2024 10:44
@ng-05
Copy link
Author

ng-05 commented Jan 8, 2025

Hello @jerryzh168 ,
We want to support two diff type of int4 schemes.

  1. symmetric_groupwise -> groupsize [ 32, 64, 128 etc ]
  2. symmetric_channelwise -> groupsize is equal to channelsize of the matmul weights

How should we take this input from user regarding quantization schemes. Groupsize parameter can not server the purpose as channelsize will change for diff matmuls in a model?

Currently I am using "scheme" parameter to differentiate between the two.
aarch64_cpu_channelwise.json
aarch64_cpu_groupwise.json

@jerryzh168
Copy link
Contributor

jerryzh168 commented Jan 8, 2025

How should we take this input from user regarding quantization schemes. Groupsize parameter can not server the purpose as channelsize will change for diff matmuls in a model?

yeah, you can use https://github.com/pytorch/ao/blob/main/torchao/quantization/granularity.py: PerGroup and PerAxis(axis=0) (assuming channel dimension is 0), examples:

granularity: Optional[
,
weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)

@ng-05
Copy link
Author

ng-05 commented Jan 9, 2025

How should we take this input from user regarding quantization schemes. Groupsize parameter can not server the purpose as channelsize will change for diff matmuls in a model?

yeah, you can use https://github.com/pytorch/ao/blob/main/torchao/quantization/granularity.py: PerGroup and PerAxis(axis=0) (assuming channel dimension is 0), examples:

granularity: Optional[

,

weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)

Thanks for the inputs @jerryzh168.

I have initial change ready which extends int4_weight_only quantizer.

The 4 bit KleidiAI kernels quantizes the weight in torchao and input to 8 bit within the kernel itself instead of quantizing the input in the torchao the way int8_dynamic_activation_int4_weight does.
For this reason I am extending the int4_weight_only api. I am slightly confused if the intention of this api is to convey NO input quantisation to user?

Currently neither int4_weight_only nor int8_dynamic_activation_int4_weight fully aligns with the way kelidiai 4 bit kernels are working.

I feel int4_weight_only is closest to what we want to do, what are your thoughts on this?

@jerryzh168
Copy link
Contributor

jerryzh168 commented Jan 9, 2025

I feel int4_weight_only is closest to what we want to do, what are your thoughts on this?

yeah int4_weight_only means no input quantization, I think it aligns better with int8_dynamic_activation_int4_weight, you can use a different layout and customize the logic for input quantization.

we also have

def int8_dynamic_activation_intx_weight(
that is the same as your use case. there is some ongoing refactors/updates there as well right now

You can also check out: #995

ng-05 added 2 commits January 11, 2025 01:32
Description:
1. Allow Int4WeightOnlyQuantizer to work with channelwise and groupwise
symmetric quantization schemes
2. KleidiAI supports channelwise and 32 groupwise quantized matmul
   kernels

Signed-off-by: Nikhil Gupta <[email protected]>
@ng-05
Copy link
Author

ng-05 commented Jan 11, 2025

Hello @jerryzh168 , I am planning to migrate int8_dynamic_activation_intx_weight api to int8_dynamic_activation_intx_weight_v2.
For now I have kept the API separate for review and testing.

Can you please review this change, specially the change the in _get_linear_subclass_inserter which allow bias propagation. The bias needed by torch.ops.aten._dyn_quant_pack_4bit_weight.

I am also not sure if int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently? Do you have an example how torchchat can pass args like granularity, mapping_type from torchchat cli to torchao ?

target: Target

# Allow bias access via layout
bias: Optional[torch.Tensor] = None
Copy link
Contributor

@jerryzh168 jerryzh168 Jan 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

layout is more of a "type" actually, why is bias Tensor passed here?

the corresponding "storage" is TensorImpl

"""Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs)
to the weight of linear module
"""

def insert_subclass(lin):
requires_grad = allow_requires_grad and lin.weight.requires_grad
args = [lin.weight]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I feel putting optional args in kwargs might be better

Copy link
Contributor

@jerryzh168 jerryzh168 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me overall, can you add some tests?

@jerryzh168
Copy link
Contributor

I am also not sure if int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently? Do you have an example how torchchat can pass args like granularity, mapping_type from torchchat cli to torchao ?

I don't think we need to expose these fine grained args to torchchat cli, we just need these high level args like: https://github.com/pytorch/torchchat/blob/main/torchchat/quant_config/mobile.json

we are also working on migrating torchchat to use torchao quant api btw

@@ -100,6 +110,12 @@ def _pack_weights_native(
torch.empty(0, group_size, dtype=torch.int8),
]

if TORCH_VERSION_AT_LEAST_2_6 and layout.target == Target.ATEN:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If torch version is not 2.6 but layout.target == aten, then what happens? Should you just assert that it is not supported?

), "Target.ATEN requires torch >= 2.6.0"
# aten supports bias for kleidiAI but not for default fallback op
if not torch.backends.kleidiai.is_available():
print("TODO bias == None")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert bias == None,

return torch.ops.aten._dyn_quant_matmul_4bit(
input_tensor, packed_weight, group_size, k_, n)

if input_tensor.dim() == 2:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have this requirement?

_intx_granularity = Union[PerGroup, PerRow]


def int8_dynamic_activation_intx_weight_v2(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm currently refactoring int8_dynamic_activation_intx_weight quantizer to use layout instead of target for the packing format: #1553. I think this should provide more flexibility longterm.

@@ -153,7 +169,7 @@ def get_layout(self) -> Layout:
def get_plain(
self,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
if self.get_layout().target == Target.FALLBACK:
if self.get_layout().target == Target.FALLBACK or self.get_layout().target == Target.ATEN:
return self.packed_weight, self.scale, self.zero_point
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC when using Target.ATEN, self.packed_weight is not the int_data, so I'm not sure get_plain is correct here?

_intx_granularity = Union[PerGroup, PerRow]


def int8_dynamic_activation_intx_weight_v2(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a significant diffrence between this and int8_dynamic_activation_intx_weight?

@metascroy
Copy link
Contributor

Hello @jerryzh168 , I am planning to migrate int8_dynamic_activation_intx_weight api to int8_dynamic_activation_intx_weight_v2. For now I have kept the API separate for review and testing.

Can you please review this change, specially the change the in _get_linear_subclass_inserter which allow bias propagation. The bias needed by torch.ops.aten._dyn_quant_pack_4bit_weight.

I am also not sure if int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently? Do you have an example how torchchat can pass args like granularity, mapping_type from torchchat cli to torchao ?

torchchat does not currently use int8_dynamic_activation_intx_weight, but instead a submodule swap API here: https://github.com/pytorch/ao/blob/main/torchao/experimental/quant_api.py#L438

We will be switching torchchat to use int8_dynamic_activation_intx_weight instead, but I first need to land some changes for perf/clarity: #1553

Copy link
Contributor

@kimishpatel kimishpatel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that this quant API now connects kernels we landed in aten with quant API. If the kernels you guys landed in aten are actually new ops, unlike int4pack_mm and friends, then why did we land them there in the first place. In order to reach those kernels you need ao dep anyway? (@digantdesai I know you tagged me on that PR but i never really deep dived into that so maybe you have context here)

Besides taht i have a couple of questions.

  • In the current form it is only making aten op you guys added available via tensor subclass api, so what happens to say torch.compile (maybe this works?) or AOTI usecase?
  • I would also like to see if we can leverage this op in executorch, for which integration into AO would have been a better choice compared to this being aten op
  • If kleidi's op performs better than whats in this repo (and note that @digantdesai has actually integrated some of the kleidi ops that I guess you guys are aware of), then can we just use that op directly or have a path to kleidi's impl for the cpu ops that exist under experimental/ops?

@kimishpatel
Copy link
Contributor

Hello @jerryzh168 , I am planning to migrate int8_dynamic_activation_intx_weight api to int8_dynamic_activation_intx_weight_v2. For now I have kept the API separate for review and testing.
Can you please review this change, specially the change the in _get_linear_subclass_inserter which allow bias propagation. The bias needed by torch.ops.aten._dyn_quant_pack_4bit_weight.
I am also not sure if int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently? Do you have an example how torchchat can pass args like granularity, mapping_type from torchchat cli to torchao ?

torchchat does not currently use int8_dynamic_activation_intx_weight, but instead a submodule swap API here: main/torchao/experimental/quant_api.py#L438

We will be switching torchchat to use int8_dynamic_activation_intx_weight instead, but I first need to land some changes for perf/clarity: #1553

Any specific reason why use subclass API instead of module swap?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants