From b9b606e69a344494c1aa43ac5b917cc71825c9b1 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Thu, 25 Jul 2024 14:48:12 -0700
Subject: [PATCH] add per-gemm config to `Float8LinearConfig` (#334)

Summary:
Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/334

Previously the per-gemm configuration had to be hardcoded in library code.
This PR exposes it to the top-level UX by adding a `Float8GemmConfig` field
to `Float8LinearConfig`.

Note that today the only supported configuration option is `use_fast_accum`.
In the future, configuring output_dtype and whether to keep a gemm in higher
precision would go here.

Reviewed By: weifengpy

Differential Revision: D60252069

fbshipit-source-id: bca34eb49e1bf046f937e32b11b2369b535d56e6
---
 float8_experimental/__init__.py      |  2 ++
 float8_experimental/config.py        | 19 +++++++++++++++++++
 float8_experimental/float8_linear.py | 18 +++++++++++-------
 3 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py
index 4c1f255..08c0ac4 100644
--- a/float8_experimental/__init__.py
+++ b/float8_experimental/__init__.py
@@ -6,6 +6,7 @@
 # Lets define a few top level things here
 from float8_experimental.config import (
     DelayedScalingConfig,
+    Float8GemmConfig,
     Float8LinearConfig,
     Float8TensorCastConfig,
     TensorScalingType,
@@ -33,6 +34,7 @@
     # configuration
     "DelayedScalingConfig",
     "TensorScalingType",
+    "Float8GemmConfig",
     "Float8LinearConfig",
     "Float8TensorCastConfig",
     # top level UX
diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index ea088e3..6408ac7 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -53,6 +53,17 @@ def __post_init__(self):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
 
 
+@dataclass(frozen=True)
+class Float8GemmConfig:
+    """
+    Configuration for a float8 gemm.
+    """
+
+    # If True, fast accumulation in lower precision is used.
+    # Note: this flag is currently a no-op if emulation is turned on.
+    use_fast_accum: bool = False
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -67,6 +78,14 @@ class Float8LinearConfig:
     cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig()
     cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig()
 
+    #
+    # Per-gemm configuration for gemms calculating `output`, `grad_input` and
+    # `grad_weight`
+    #
+    gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
+    gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
+    gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig()
+
     #
     # Per-linear configuration
     #
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 581f9f3..c598a93 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -168,24 +168,28 @@ def __init__(self, *args, **kwargs):
 
         self.create_buffers()
 
-        # TODO(future): user level configuration of gemms
         self.linear_mm_config = LinearMMConfig(
-            # input
+            # output
             ScaledMMConfig(
                 emulate,
-                True if not emulate else False,
+                self.config.gemm_config_output.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
-            # weight
+            # grad_input
             ScaledMMConfig(
                 emulate,
-                True if not emulate else False,
+                self.config.gemm_config_grad_input.use_fast_accum,
+                False,
+                self.config.pad_inner_dim,
+            ),
+            # grad_weight
+            ScaledMMConfig(
+                emulate,
+                self.config.gemm_config_grad_weight.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
-            # grad_output
-            ScaledMMConfig(emulate, False, False, self.config.pad_inner_dim),
         )
 
         # Note: is_amax_initialized is not a buffer to avoid data dependent
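
Below is a minimal usage sketch (not part of the patch itself) of how the new per-gemm knobs could be set from user code once this change lands. It only constructs the config objects introduced by this diff; the module-swap entry point that would consume the config is mentioned only in a comment, since its exact name is an assumption here and is not part of this diff.

# Minimal sketch, assuming float8_experimental is installed with this patch
# applied. Only the config construction is defined by this diff; the
# conversion helper mentioned at the end is an assumed, illustrative name.
import torch.nn as nn

from float8_experimental import Float8GemmConfig, Float8LinearConfig

# Enable fast accumulation only for the forward (`output`) gemm and keep the
# backward gemms (`grad_input`, `grad_weight`) in regular accumulation. These
# values match the post-patch defaults and are spelled out for illustration.
config = Float8LinearConfig(
    gemm_config_output=Float8GemmConfig(use_fast_accum=True),
    gemm_config_grad_input=Float8GemmConfig(use_fast_accum=False),
    gemm_config_grad_weight=Float8GemmConfig(use_fast_accum=False),
)

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

# The config would then be passed to the library's top-level swap helper,
# e.g. something like convert_to_float8_training(model, config=config); that
# call is omitted because the helper's name is not confirmed by this diff.

Since the explicit keyword arguments above reproduce the defaults added by this patch (fast accumulation on for `output`, off for `grad_input` and `grad_weight`), omitting them yields the same behavior.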