Commit
add fused transpose and non-transpose kernel and use it for grad output (#1497)
danielvegamyhre authored Jan 8, 2025
1 parent 4996101 commit 070345d
Showing 6 changed files with 275 additions and 19 deletions.
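
At a high level, this patch fuses the two fp8 casts of grad_output (row-major and row-major transposed) into a single Triton kernel, so the tensor is scaled and converted only once in the backward pass. Below is a minimal plain-PyTorch sketch of what the fused cast conceptually returns; the function name and the eps-style clamp are illustrative assumptions, not part of the patch.

# Hedged reference sketch, not the Triton kernel: cast once, return both layouts.
import torch

def fused_row_major_t_and_non_t_reference(hp, fp8_dtype=torch.float8_e4m3fn):
    # tensorwise dynamic scale from the max absolute value (eps guard assumed)
    amax = hp.abs().max().to(torch.float32)
    scale = torch.finfo(fp8_dtype).max / torch.clamp(amax, min=1e-12)
    vals = (hp.to(torch.float32) * scale).clamp(
        torch.finfo(fp8_dtype).min, torch.finfo(fp8_dtype).max
    )
    fp8 = vals.to(fp8_dtype)
    # both outputs are row-major; the second materializes the transpose contiguously
    return fp8, fp8.t().contiguous()
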
Empty file.
25 changes: 10 additions & 15 deletions torchao/prototype/float8nocompile/float8nocompile_linear.py
@@ -16,8 +16,7 @@
ToFP8ColumnMajor,
ToFP8ColumnMajorT,
ToFP8RowAndColumnMajor,
ToFP8RowMajor,
ToFP8RowMajorT,
ToFP8RowMajorTAndNonT,
)
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
KernelAlgorithm,
@@ -138,12 +137,14 @@ def backward(ctx, grad_output):
input_fp8_col_major, weight_hp = ctx.saved_tensors

# cast grad output to float8_e5m2 for backward
grad_output_fp8_row_major = ToFP8RowMajor.apply(
grad_output,
ctx.config.cast_config_grad_output.target_dtype,
ctx.linear_mm_config,
GemmInputRole.GRAD_OUTPUT,
ctx.kernel_algo,
grad_output_fp8_row_major, grad_output_t_row_major = (
ToFP8RowMajorTAndNonT.apply(
grad_output,
ctx.config.cast_config_grad_output.target_dtype,
ctx.linear_mm_config,
GemmInputRole.GRAD_OUTPUT,
ctx.kernel_algo,
)
)

# grad_input = grad_output @ weight
@@ -159,12 +160,6 @@
# grad_weight = grad_output_t @ input
# apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output`
# source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
grad_output_t_row_major = ToFP8RowMajorT.apply(
grad_output,
ctx.config.cast_config_grad_output.target_dtype,
ctx.linear_mm_config,
GemmInputRole.GRAD_OUTPUT,
ctx.kernel_algo,
)
grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)

return grad_input, grad_weight, None, None, None
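
For context, the two matmuls in backward() consume the two layouts produced by the fused cast: the non-transposed grad_output for grad_input and the transposed grad_output for grad_weight. A hedged, plain-precision sketch (fp8 casting and Float8Tensor wrapping omitted; names are illustrative):

import torch

def linear_backward_reference(grad_output, input, weight):
    # grad_input = grad_output @ weight          -> shape (N, in_features)
    grad_input = torch.mm(grad_output, weight)
    # grad_weight = grad_output_t @ input        -> shape (out_features, in_features)
    grad_weight = torch.mm(grad_output.t(), input)
    return grad_input, grad_weight
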
36 changes: 32 additions & 4 deletions torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py
@@ -10,17 +10,15 @@

import torch

from torchao.float8.float8_tensor import (
GemmInputRole,
LinearMMConfig,
)
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
KernelAlgorithm,
hp_to_fp8_col_major,
hp_to_fp8_col_major_t,
hp_to_fp8_row_and_col_major,
hp_to_fp8_row_major,
hp_to_fp8_row_major_t,
hp_to_fp8_row_major_t_and_non_t,
)


@@ -172,3 +170,33 @@ def forward(
@staticmethod
def backward(ctx, g):
return g, None, None, None, None


class ToFP8RowMajorTAndNonT(torch.autograd.Function):
"""
A differentiable conversion to fp8.
* forward: convert from high precision to float8 and produce both row-major (non-transposed) and row-major transposed outputs
* backward: pass the gradient without changes
"""

@staticmethod
def forward(
ctx,
tensor: torch.Tensor,
float8_dtype: torch.dtype,
linear_mm_config: LinearMMConfig,
gemm_input_role: GemmInputRole,
kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
):
fp8_row_major, fp8_row_major_t = hp_to_fp8_row_major_t_and_non_t(
tensor,
float8_dtype,
linear_mm_config,
gemm_input_role,
algo=kernel_algo,
)
return fp8_row_major, fp8_row_major_t

@staticmethod
def backward(ctx, g):
return g, None, None, None, None
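
A hedged usage sketch of the new autograd function, mirroring the backward() call site in float8nocompile_linear.py; grad_output, linear_mm_config, and the e5m2 target dtype are taken as assumptions from the surrounding diff rather than a standalone API:

# hypothetical call site; everything other than the class itself is assumed
grad_output_fp8_row_major, grad_output_t_row_major = ToFP8RowMajorTAndNonT.apply(
    grad_output,                 # high precision gradient w.r.t. the layer output
    torch.float8_e5m2,           # gradient target dtype
    linear_mm_config,
    GemmInputRole.GRAD_OUTPUT,
    KernelAlgorithm.ATOMIC_MAX,  # default scaling algorithm
)
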
Empty file.
158 changes: 158 additions & 0 deletions torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -305,6 +305,82 @@ def _to_fp8_row_and_col_major(
tl.store(col_major_out_ptr + col_major_offs, fp8_vals, mask=mask)


@triton.autotune(
configs=kernel_configs_2D,
key=["num_elements"],
)
@triton.jit
def _to_fp8_row_major_t_and_non_t(
input_ptr,
row_major_out_ptr,
row_major_t_out_ptr,
scale_ptr,
num_elements: int,
fp8_dtype_min: float,
fp8_dtype_max: float,
input_num_rows: int,
input_num_cols: int,
input_stride_row: int,
input_stride_col: int,
row_major_out_stride_row: int,
row_major_out_stride_col: int,
row_major_t_out_stride_row: int,
row_major_t_out_stride_col: int,
input_dtype: tl.constexpr,
output_dtype: tl.constexpr,
BLOCK_SIZE_ROWS: tl.constexpr,
BLOCK_SIZE_COLS: tl.constexpr,
EPS: tl.constexpr,
):
"""
Reads a row-major, high precision input tensor and writes 2 output tensors:
1) fp8 row major tensor
2) fp8 row major tensor, transposed
"""
block_row_id = tl.program_id(axis=0)
block_col_id = tl.program_id(axis=1)

# load scaling factor
scale = tl.load(scale_ptr).to(tl.float32)

# load block of input tensor
block_row_start = block_row_id * BLOCK_SIZE_ROWS
block_col_start = block_col_id * BLOCK_SIZE_COLS
block_row_offs = block_row_start + tl.arange(0, BLOCK_SIZE_ROWS)
block_col_offs = block_col_start + tl.arange(0, BLOCK_SIZE_COLS)
input_offs = (
block_row_offs[:, None] * input_stride_row
+ block_col_offs[None, :] * input_stride_col
)
mask = (block_row_offs[:, None] < input_num_rows) & (
block_col_offs[None, :] < input_num_cols
)
vals = tl.load(input_ptr + input_offs, mask=mask).to(input_dtype)

# perform conversion
vals = vals * scale
fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)

# write row-major output
row_major_offs = (
block_row_offs[:, None] * row_major_out_stride_row
+ block_col_offs[None, :] * row_major_out_stride_col
)
tl.store(row_major_out_ptr + row_major_offs, fp8_vals, mask=mask)

# write transposed row-major output
row_major_t_num_rows = input_num_cols
row_major_t_num_cols = input_num_rows
row_major_t_offs = (
block_col_offs[:, None] * row_major_t_out_stride_row
+ block_row_offs[None, :] * row_major_t_out_stride_col
)
# bounds check for the transposed tile: block_col_offs index its rows, block_row_offs its columns
mask = (block_col_offs[:, None] < row_major_t_num_rows) & (
block_row_offs[None, :] < row_major_t_num_cols
)
tl.store(row_major_t_out_ptr + row_major_t_offs, fp8_vals.trans(1, 0), mask=mask)
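
Conceptually, each 2D program instance converts one tile and writes it twice: at its original (row, col) offsets in the row-major output and at swapped offsets in the transposed output, with masks guarding partial edge tiles. A hedged PyTorch sketch of that per-tile behaviour (slicing stands in for the masked loads/stores; names are illustrative):

import torch

def write_tile_reference(hp, out, out_t, scale, r0, c0, block_rows, block_cols,
                         fp8_dtype=torch.float8_e4m3fn):
    # scale, clamp and cast one tile of the high precision input
    tile = hp[r0:r0 + block_rows, c0:c0 + block_cols].to(torch.float32) * scale
    tile = tile.clamp(torch.finfo(fp8_dtype).min, torch.finfo(fp8_dtype).max).to(fp8_dtype)
    # write the tile to the row-major output and its transpose to the transposed output
    out[r0:r0 + block_rows, c0:c0 + block_cols] = tile
    out_t[c0:c0 + block_cols, r0:r0 + block_rows] = tile.t()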


@triton.autotune(configs=kernel_configs_1D, key=["num_elements"])
@triton.jit
def _amax_atomic(
@@ -701,6 +777,88 @@ def hp_to_fp8_row_and_col_major(
return fp8_tensor_row_major, fp8_tensor_col_major


def hp_to_fp8_row_major_t_and_non_t(
hp_tensor: torch.Tensor,
fp8_dtype: torch.dtype,
linear_mm_config: LinearMMConfig,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
) -> tuple[Float8Tensor, Float8Tensor]:
assert hp_tensor.is_contiguous(), "input tensor must be contiguous"

tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]

fp8_dtype_min = torch.finfo(fp8_dtype).min
fp8_dtype_max = torch.finfo(fp8_dtype).max

# compute scaling factor for tensor
scale = _hp_tensor_to_scale(
hp_tensor,
tl_input_dtype,
fp8_dtype_max,
algo,
)

# perform fp8 conversion
input_num_rows, input_num_cols = hp_tensor.shape
transposed_num_rows, transposed_num_cols = input_num_cols, input_num_rows
num_elements = hp_tensor.numel()

# preallocate necessary output tensors
fp8_output_row_major = torch.empty(
(input_num_rows, input_num_cols), dtype=fp8_dtype, device=hp_tensor.device
)
fp8_output_row_major_t = torch.empty(
(transposed_num_rows, transposed_num_cols),
dtype=fp8_dtype,
device=hp_tensor.device,
)

# launch triton kernel to perform conversion
grid = lambda meta: (
triton.cdiv(input_num_rows, meta["BLOCK_SIZE_ROWS"]),
triton.cdiv(input_num_cols, meta["BLOCK_SIZE_COLS"]),
)
_to_fp8_row_major_t_and_non_t[grid](
hp_tensor,
fp8_output_row_major,
fp8_output_row_major_t,
scale,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_num_rows,
input_num_cols,
hp_tensor.stride(0),
hp_tensor.stride(1),
fp8_output_row_major.stride(0),
fp8_output_row_major.stride(1),
fp8_output_row_major_t.stride(0),
fp8_output_row_major_t.stride(1),
input_dtype=tl_input_dtype,
output_dtype=tl_output_dtype,
EPS=EPS,
)

# wrap outputs in Float8Tensors
fp8_tensor_row_major = Float8Tensor(
fp8_output_row_major,
scale,
orig_dtype=hp_tensor.dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
)
fp8_tensor_row_major_t = Float8Tensor(
fp8_output_row_major_t,
scale,
orig_dtype=hp_tensor.dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
)
return fp8_tensor_row_major, fp8_tensor_row_major_t
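
A hedged usage sketch of the new wrapper (shapes, device, and dtype choices are illustrative and mirror the test below rather than prescribe an API):

import torch

hp = torch.randn(512, 4096, dtype=torch.bfloat16, device="cuda")
fp8_row_major, fp8_row_major_t = hp_to_fp8_row_major_t_and_non_t(
    hp,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    algo=KernelAlgorithm.ATOMIC_MAX,
)
# the second output is the transpose, materialized in row-major layout
assert fp8_row_major.shape == hp.shape
assert fp8_row_major_t.shape == (hp.shape[1], hp.shape[0])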


def _hp_tensor_to_scale(
hp_tensor: torch.Tensor,
tl_input_dtype: tl.core.dtype,
@@ -11,6 +11,7 @@
hp_to_fp8_row_and_col_major,
hp_to_fp8_row_major,
hp_to_fp8_row_major_t,
hp_to_fp8_row_major_t_and_non_t,
)


@@ -335,3 +336,77 @@ def test_fp8_hp_to_fp8_row_and_col_major(
torch.float8_e4m3fn,
LinearMMConfig(),
)


@pytest.mark.parametrize(
"algo",
[KernelAlgorithm.REDUCTION, KernelAlgorithm.ATOMIC_MAX],
)
@pytest.mark.parametrize(
"input_shape",
[(2, 4), (32, 16), (512, 512)],
)
def test_fp8_hp_to_fp8_row_major_t_and_non_t(
input_shape: tuple[int, int], algo: KernelAlgorithm
):
assert torch.cuda.is_available()
device = "cuda"
input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
x_bf16 = input_bf16.clone().detach().to(device)
y_bf16 = input_bf16.clone().detach().to(device)

# production implementation
x_fp8_row_major = hp_tensor_to_float8_dynamic(
x_bf16,
torch.float8_e4m3fn,
LinearMMConfig(),
)
x_fp8_row_major_t = x_fp8_row_major.t().contiguous()

# float8nocompile triton implementation
y_fp8_row_major, y_fp8_row_major_t = hp_to_fp8_row_major_t_and_non_t(
y_bf16,
torch.float8_e4m3fn,
LinearMMConfig(),
algo=algo,
)

# check scales
assert torch.eq(x_fp8_row_major._scale, y_fp8_row_major._scale)
assert torch.eq(x_fp8_row_major_t._scale, y_fp8_row_major_t._scale)

# check data
assert torch.all(torch.eq(x_fp8_row_major._data, y_fp8_row_major._data))
assert torch.all(torch.eq(x_fp8_row_major_t._data, y_fp8_row_major_t._data))

# check shapes
assert x_fp8_row_major.shape == y_fp8_row_major.shape
assert x_fp8_row_major_t.shape == y_fp8_row_major_t.shape

# check strides
assert x_fp8_row_major.stride() == y_fp8_row_major.stride()
assert x_fp8_row_major_t.stride() == y_fp8_row_major_t.stride()

# check memory layout
assert is_row_major(x_fp8_row_major.stride())
assert is_row_major(y_fp8_row_major.stride())
assert is_row_major(x_fp8_row_major_t.stride())
assert is_row_major(y_fp8_row_major_t.stride())

# check underlying memory layout
assert (
x_fp8_row_major._data.storage().tolist()
== y_fp8_row_major._data.storage().tolist()
)
assert (
x_fp8_row_major_t._data.storage().tolist()
== y_fp8_row_major_t._data.storage().tolist()
)

# assert that error is raised when input tensor is not contiguous
with pytest.raises(AssertionError, match="tensor must be contiguous"):
hp_to_fp8_row_major_t_and_non_t(
y_bf16.t(), # transpose so tensor memory layout is no longer contiguous
torch.float8_e4m3fn,
LinearMMConfig(),
)
