This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

updates to enable static weight quantization/dynamic activation quantization
drisspg committed Jun 20, 2024
1 parent edae9a3 commit 9bee8a3
Showing 2 changed files with 27 additions and 4 deletions.
22 changes: 21 additions & 1 deletion float8_experimental/float8_dynamic_linear.py
@@ -72,14 +72,30 @@ def forward(self, x):
         y = cast_to_float8_e5m2_bw(y, self.backward_config)
         return y
 
+    def static_quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> None:
+        """Used to perform static quantization; useful for inference where weights are not updated."""
+
+        scale = tensor_to_scale(self.weight, dtype)
+        quantized_weight = to_fp8_no_autograd(
+            self.weight,
+            scale,
+            dtype,
+            self.forward_config,
+        )
+        self.weight = nn.Parameter(quantized_weight)
+
     @classmethod
-    def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
+    def from_float(
+        cls, mod, emulate: bool = False, static_quantize_weight: bool = False
+    ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
+            static_quantize_weight (bool): whether to quantize the weight statically; this is useful
+                for inference where weights are not updated.
         """
         with torch.device("meta"):
             super_kwargs = {
@@ -96,6 +112,10 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             )
         else:
             new_mod.weight = mod.weight
+
+        if static_quantize_weight:
+            new_mod.static_quantize_weight()
+
         new_mod.bias = mod.bias
         return new_mod

9 changes: 6 additions & 3 deletions float8_experimental/float8_linear_utils.py
@@ -6,7 +6,7 @@
 import copy
 import logging
 from enum import auto, Enum
-from typing import Callable, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type
 
 import torch
 import torch.distributed as dist
@@ -100,6 +100,7 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    from_float_kwargs: Dict[str, Any] = None,
 ) -> nn.Module:
     """
     Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
@@ -122,7 +123,7 @@
         raise AssertionError(
             f"Does not support a root nn.Linear with children: {module}"
         )
-        return module_cls.from_float(module, emulate=emulate)
+        return module_cls.from_float(module, emulate=emulate, **from_float_kwargs)
 
     # Mark all modules to skip as visited
     root_module = module
@@ -146,7 +147,9 @@ def post_order_traversal(
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
-            float8linear_module = module_cls.from_float(module, emulate=emulate)
+            float8linear_module = module_cls.from_float(
+                module, emulate=emulate, **from_float_kwargs
+            )
             setattr(parent_module, module_name, float8linear_module)
 
     post_order_traversal(root_module, "", None)
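
A companion sketch for the swap utility follows; it is also not part of the commit. The toy model is illustrative, the positional arguments (root module, then the Float8 linear class) follow the function's existing signature, which is only partially visible in these hunks, and because the diff unpacks **from_float_kwargs directly, the sketch always passes a dict rather than relying on the None default.

import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# A toy model whose nn.Linear layers will be swapped in place.
model = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 8),
)

# Forward the new keyword through to every from_float() call so that each
# swapped layer statically quantizes its weight at conversion time.
model = swap_linear_with_float8_linear(
    model,
    Float8DynamicLinear,
    emulate=True,
    from_float_kwargs={"static_quantize_weight": True},
)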
