From 9bee8a3842c6b8fbbc728c8443744d211cfc8249 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Thu, 20 Jun 2024 16:22:06 -0700
Subject: [PATCH] updates to enable static weight quantization/dynamic activation quantization

---
 float8_experimental/float8_dynamic_linear.py | 22 +++++++++++++++++++-
 float8_experimental/float8_linear_utils.py   |  9 +++++---
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 701ae8a..3e0f5b8 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -72,14 +72,30 @@ def forward(self, x):
         y = cast_to_float8_e5m2_bw(y, self.backward_config)
         return y
 
+    def static_quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> None:
+        """Statically quantize the weight; useful for inference where weights are not updated."""
+
+        scale = tensor_to_scale(self.weight, dtype)
+        quantized_weight = to_fp8_no_autograd(
+            self.weight,
+            scale,
+            dtype,
+            self.forward_config,
+        )
+        self.weight = nn.Parameter(quantized_weight)
+
     @classmethod
-    def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
+    def from_float(
+        cls, mod, emulate: bool = False, static_quantize_weight: bool = False
+    ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
+            static_quantize_weight (bool): whether to statically quantize the weight;
+                useful for inference where weights are not updated.
         """
         with torch.device("meta"):
             super_kwargs = {
@@ -96,6 +112,10 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             )
         else:
             new_mod.weight = mod.weight
+
+        if static_quantize_weight:
+            new_mod.static_quantize_weight()
+
         new_mod.bias = mod.bias
         return new_mod
 
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index 4d85045..61eef1e 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -6,7 +6,7 @@
 import copy
 import logging
 from enum import auto, Enum
-from typing import Callable, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type
 
 import torch
 import torch.distributed as dist
@@ -100,6 +100,7 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    from_float_kwargs: Optional[Dict[str, Any]] = None,
 ) -> nn.Module:
     """
     Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
@@ -122,7 +123,7 @@
             raise AssertionError(
                 f"Does not support a root nn.Linear with children: {module}"
             )
-        return module_cls.from_float(module, emulate=emulate)
+        return module_cls.from_float(module, emulate=emulate, **(from_float_kwargs or {}))
 
     # Mark all modules to skip as visited
     root_module = module
@@ -146,7 +147,9 @@ def post_order_traversal(
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
-            float8linear_module = module_cls.from_float(module, emulate=emulate)
+            float8linear_module = module_cls.from_float(
+                module, emulate=emulate, **(from_float_kwargs or {})
+            )
             setattr(parent_module, module_name, float8linear_module)
 
     post_order_traversal(root_module, "", None)
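
Example usage (not part of the patch): a minimal sketch assuming the patched `swap_linear_with_float8_linear` and `Float8DynamicLinear.from_float` signatures above; the toy model and tensor sizes are purely illustrative.

```python
import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Toy inference-only model; any nn.Module containing nn.Linear layers would do.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16)).eval()

# Swap nn.Linear -> Float8DynamicLinear. from_float_kwargs is forwarded to
# Float8DynamicLinear.from_float, so static_quantize_weight=True casts each
# weight to fp8 once at swap time, while activations are still quantized
# dynamically on every forward pass. emulate=True runs the fp8 matmul logic in
# fp32 so this sketch does not require fp8-capable hardware.
model = swap_linear_with_float8_linear(
    model,
    Float8DynamicLinear,
    emulate=True,
    from_float_kwargs={"static_quantize_weight": True},
)

with torch.no_grad():
    out = model(torch.randn(16, 64))
```

Because `from_float_kwargs` is simply unpacked into `module_cls.from_float`, the same hook can carry other per-class `from_float` options without further changes to the swap utility.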