Skip to content

Commit

Permalink
add axiswise scaling to Float8Linear
Browse files Browse the repository at this point in the history
Summary:

This PR: support scaling of all arguments of all gemms to be axiswise,
and ensure that training with axiswise scaling works e2e.

Future PR: support more granular configurability and optimize
performance, add docs

Test Plan:

```
// tests pass
./test/float8/test_everything.sh

// sanity check on torchtitan with LLaMa 3 8B on 4 H100s with float8:
// 1. verify performance does not regress with tensorwise scaling
// 2. smoke test that axiswise scaling works and numerics are sane, performance isn't there though
// logs: https://gist.github.com/vkuzo/70fa5eb3c23375f307d11e7bae48682f
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 304a5427739966a9601fa860ed248fc2bb902d67
ghstack-comment-id: 2368837904
Pull Request resolved: #920
  • Loading branch information
vkuzo committed Sep 23, 2024
1 parent 5711a01 commit d759f81
Show file tree
Hide file tree
Showing 9 changed files with 450 additions and 41 deletions.
32 changes: 27 additions & 5 deletions benchmarks/float8/bench_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@

import torch
import torch.utils.benchmark as benchmark
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingType,
ScalingGranularity,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
linear_requires_sync,
Expand Down Expand Up @@ -107,35 +112,49 @@ def main(
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
):
device = "cuda"
print(f"Compile is set to | {compile}")

scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(scaling_type=scaling_type_input)
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
Expand Down Expand Up @@ -167,7 +186,7 @@ def main(
copy.deepcopy(linear_ref),
config=config,
)
scaling_repr = linear_float8.scaling_repr()
scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"

if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
Expand Down Expand Up @@ -310,6 +329,7 @@ def invoke_main() -> None:
parser.add_argument("--scaling_type_input", type=str, required=False)
parser.add_argument("--scaling_type_weight", type=str, required=False)
parser.add_argument("--scaling_type_grad_output", type=str, required=False)
parser.add_argument("--scaling_granularity", type=str, required=False)
args = parser.parse_args()
output_path = Path(args.output_path) if args.output_path is not None else None
kwargs = {}
Expand All @@ -327,6 +347,8 @@ def invoke_main() -> None:
kwargs["scaling_type_weight"] = args.scaling_type_weight
if args.scaling_type_grad_output is not None:
kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
if args.scaling_granularity is not None:
kwargs["scaling_granularity"] = args.scaling_granularity
main(
output_path,
not args.disable_compile,
Expand Down
13 changes: 11 additions & 2 deletions benchmarks/float8/bench_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import torch.nn as nn
import torch.utils.benchmark as benchmark

from torchao.float8.config import ScalingGranularity

from utils import (
get_name_to_shapes_iter,
profiler_output_to_filtered_time_by_kernel_name,
Expand Down Expand Up @@ -75,6 +77,7 @@ def run(
K: Optional[int] = None,
N: Optional[int] = None,
use_gpu_kernel_time: bool = False,
scaling_granularity: str = "tensorwise",
):
device = "cuda"

Expand All @@ -84,6 +87,7 @@ def run(
dtype = torch.bfloat16
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
fast_accum_vals = [True, False]
scaling_granularity = ScalingGranularity(scaling_granularity)

for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)):
if n_limit is not None and idx >= n_limit:
Expand All @@ -109,8 +113,13 @@ def run(
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
A = torch.zeros(M, K, device=device, dtype=d1)
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
if scaling_granularity == ScalingGranularity.TENSORWISE:
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
else:
assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)

def do_matmul(A, B):
nonlocal scale_a
Expand Down
27 changes: 23 additions & 4 deletions benchmarks/float8/profile_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingType,
ScalingGranularity,
)
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
Expand Down Expand Up @@ -252,6 +257,7 @@ def main(
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
model_type: str = "linear",
dtype_filter: str = "both",
add_inductor_metadata_to_trace: bool = True,
Expand All @@ -263,28 +269,41 @@ def main(
scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(scaling_type=scaling_type_input)
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
Expand Down
33 changes: 30 additions & 3 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,10 @@ def _test_linear_impl(
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.parametrize(
"scaling_granularity",
[ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
)
@pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("linear_bias", [False, True])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
Expand All @@ -340,6 +344,7 @@ def test_linear(
scaling_type_input: ScalingType,
scaling_type_weight: ScalingType,
scaling_type_grad_output: ScalingType,
scaling_granularity: ScalingGranularity,
linear_dtype: torch.dtype,
linear_bias: bool,
):
Expand All @@ -352,30 +357,52 @@ def test_linear(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()
if scaling_granularity is ScalingGranularity.AXISWISE:
if (
scaling_type_input != ScalingType.DYNAMIC or
scaling_type_weight != ScalingType.DYNAMIC or
scaling_type_grad_output != ScalingType.DYNAMIC or
linear_dtype != torch.bfloat16 or
(not is_cuda_9_0)
):
pytest.skip()

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

if scaling_type_input is ScalingType.STATIC:
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_input = CastConfig(scaling_type=scaling_type_input)
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_weight = CastConfig(scaling_type=scaling_type_weight)
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output)
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
Expand Down
Loading

0 comments on commit d759f81

Please sign in to comment.