This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

bring back torch.autograd.Function for float8 matmul #344

Closed
float8_experimental/float8_linear.py: 57 additions & 1 deletion
@@ -71,6 +71,62 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
scale.copy_(new_scale)


# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
@torch._dynamo.allow_in_graph
class manual_float8_matmul(torch.autograd.Function):
    """
    Like torch.matmul, but with the arguments in float8
    """

    @staticmethod
    def forward(
        ctx,
        input_fp8,
        weight_fp8_t,
    ):
        ctx.save_for_backward(input_fp8, weight_fp8_t)
        # the reshapes are needed in order to make the shapes compatible with
        # torch.mm
        orig_shape = input_fp8.shape
        input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
        res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
        return res_bits

    @staticmethod
    def backward(ctx, grad_output_fp8):
        input_fp8, weight_fp8_t = ctx.saved_tensors

        # the reshapes are needed in order to make the shapes compatible with
        # torch.mm
        grad_output_fp8_orig_shape = grad_output_fp8.shape
        grad_output_fp8_reshaped = grad_output_fp8.reshape(
            -1, grad_output_fp8_orig_shape[-1]
        )

        # calculate grad_input
        grad_input = torch.mm(
            grad_output_fp8_reshaped,
            weight_fp8_t.t(),
        )
        grad_input = grad_input.reshape(
            *grad_output_fp8_orig_shape[:-1], grad_input.shape[-1]
        )

        input_fp8_orig_shape = input_fp8.shape
        input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1])

        # calculate grad_weight
        # Note: the variant below is slightly faster on LLaMa 3 8B pretraining
        # than calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped`
        grad_weight = torch.mm(
            grad_output_fp8_reshaped.t(),
            input_fp8_reshaped,
        )

        return grad_input, grad_weight.t()


@torch._dynamo.allow_in_graph
class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
"""
@@ -393,7 +449,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
         weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)

-        output = torch.matmul(input_fp8, weight_fp8.t())
+        output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())

         # Cast grad_output to float8_e5m2 during backward
         output = self.cast_output_to_float8_in_bw(output)
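
Regarding the cast_output_to_float8_in_bw call: per the comment above, it routes the output through a Function like the NoopFwToFloat8E5M2Bw class shown earlier, which passes the forward value through unchanged and casts only the incoming gradient to float8_e5m2 in the backward. A self-contained sketch of that pattern, with assumptions labeled: NoopFwCastBwSketch is a hypothetical name, and the round-trip back to the original dtype stands in for the Float8Tensor (with scale metadata) that the real class returns.

import torch

class NoopFwCastBwSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # no-op: the forward value passes through unchanged
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # emulate the e5m2 cast of grad_output (needs PyTorch >= 2.1);
        # the real class returns a Float8Tensor carrying scale metadata
        return grad_output.to(torch.float8_e5m2).to(grad_output.dtype)

x = torch.randn(4, 8, requires_grad=True)
out = NoopFwCastBwSketch.apply(x)
out.backward(torch.randn_like(out))
print(x.grad.dtype)  # original dtype; values were quantized through e5m2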