[cleanup][4/x] unify weight casting
Summary:

Not ready for review yet; there is a performance regression because the
tensorwise abs+max and the weight cast happen twice between fwd and bwd.
Possibly a limitation of something in the PT2 stack?
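
For illustration, a minimal standalone sketch (not part of this commit) of the
suspected behavior: wrapping the cast in activation checkpointing makes the
tensorwise abs+max and the cast run once in the forward and once more when the
backward recomputes the checkpointed region. `scale_and_cast` and `cast_count`
below are made-up stand-ins for tensor_to_scale + hp_tensor_and_scale_to_float8,
and the fp8 round trip is only simulated.

import torch
from torch.utils import checkpoint

cast_count = 0

def scale_and_cast(w: torch.Tensor) -> torch.Tensor:
    # stand-in for tensor_to_scale + hp_tensor_and_scale_to_float8:
    # one tensorwise abs+max reduction, then a simulated fp8 round trip
    global cast_count
    cast_count += 1
    amax = w.abs().max()
    scale = 448.0 / amax.clamp(min=1e-12)  # 448 ~ max of float8_e4m3fn
    return (w * scale) / scale

w = torch.randn(16, 16, requires_grad=True)
x = torch.randn(4, 16)

# checkpointing drops the casted weight after the forward and recomputes it
# during the backward, so the abs+max and the cast run twice per step
w_cast = checkpoint.checkpoint(scale_and_cast, w, use_reentrant=False)
out = x @ w_cast.t()
out.sum().backward()
print(cast_count)  # 2: once in forward, once recomputed in backward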

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d013d859f3f4230e28207e70b8aafcfd907d5c45
ghstack-comment-id: 2568319095
Pull Request resolved: #1481
vkuzo committed Jan 8, 2025
1 parent f99bd4b commit 231bb13
Showing 1 changed file with 35 additions and 71 deletions.
106 changes: 35 additions & 71 deletions torchao/float8/float8_linear.py
@@ -28,35 +28,6 @@
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor


def _get_weight_scale(
weight: torch.Tensor,
scaling_type_weight: ScalingType,
config: Float8LinearConfig,
) -> Optional[torch.Tensor]:
if tensor_already_casted_to_fp8(weight):
return None
assert scaling_type_weight is ScalingType.DYNAMIC
return tensor_to_scale(weight, config.cast_config_weight.target_dtype)


def _cast_weight_to_float8_t(
weight: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if tensor_already_casted_to_fp8(weight):
return weight.t()
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
config.cast_config_weight.target_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return weight_fp8.t()


@torch._dynamo.allow_in_graph
class matmul_with_hp_or_float8_args(torch.autograd.Function):
"""
@@ -102,6 +73,40 @@ def forward(
weight_maybe_fp8_t = weight_hp_t
elif c.cast_config_weight.scaling_type is ScalingType.DISABLED:
weight_maybe_fp8_t = weight_hp_t
elif (
config.cast_config_weight.scaling_granularity
is ScalingGranularity.TENSORWISE
):
# Special case tensorwise to allow the checkpointing of float8
# casted weight, to prevent blowing up peak memory usage in FSDP.

# inductor kernels for tensorwise max are faster when `weight` is
# contiguous.
# context: https://github.com/pytorch/pytorch/issues/144431
weight_hp_t_t = weight_hp_t.t()
assert weight_hp_t_t.is_contiguous()
weight_scale = tensor_to_scale(
weight_hp_t_t, config.cast_config_weight.target_dtype
)

if config.force_recompute_fp8_weight_in_bwd:
weight_maybe_fp8_t = checkpoint.checkpoint(
hp_tensor_and_scale_to_float8,
weight_hp_t,
weight_scale,
config.cast_config_weight.target_dtype,
linear_mm_config,
GemmInputRole.WEIGHT,
)
else:
weight_maybe_fp8_t = hp_tensor_and_scale_to_float8(
weight_hp_t,
weight_scale,
config.cast_config_weight.target_dtype,
linear_mm_config,
GemmInputRole.WEIGHT,
)

else:
weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
weight_hp_t,
@@ -294,50 +299,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
autocast_dtype = torch.get_autocast_gpu_dtype()
input = input.to(autocast_dtype)

has_any_axiswise_scaling = any(
cc.scaling_granularity is ScalingGranularity.AXISWISE
for cc in [
self.config.cast_config_input,
self.config.cast_config_weight,
self.config.cast_config_grad_output,
self.config.cast_config_input_for_grad_weight,
self.config.cast_config_weight_for_grad_input,
self.config.cast_config_grad_output_for_grad_weight,
]
)

weight_maybe_fp8_t = self.weight.t()

# TODO(future PR): check for axiswise scaling for input, weight,
# grad_output separately instead of together
if not has_any_axiswise_scaling:
# If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
# weight_scale should be saved.
weight_scale = _get_weight_scale(
self.weight, self.scaling_type_weight, self.config
)

if self.config.force_recompute_fp8_weight_in_bwd:
weight_fp8_t = checkpoint.checkpoint(
_cast_weight_to_float8_t,
self.weight,
self.config,
self.linear_mm_config,
weight_scale,
)
else:
weight_fp8_t = _cast_weight_to_float8_t(
self.weight,
self.config,
self.linear_mm_config,
weight_scale,
)

weight_maybe_fp8_t = weight_fp8_t

output = matmul_with_hp_or_float8_args.apply(
input,
weight_maybe_fp8_t,
self.weight.t(),
self.linear_mm_config,
self.config,
)
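
For context, a hedged usage sketch of the flag exercised by the new tensorwise
branch above; it assumes the existing torchao.float8 conversion entry point and
an fp8-capable GPU, and is a minimal example rather than a test from this PR.

import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# force_recompute_fp8_weight_in_bwd is the flag read by the new tensorwise
# branch in this diff; the rest of the setup is an assumed, minimal example
config = Float8LinearConfig(force_recompute_fp8_weight_in_bwd=True)

m = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False))
m = m.to(device="cuda", dtype=torch.bfloat16)

# swaps nn.Linear modules for Float8Linear in place, using the config above
convert_to_float8_training(m, config=config)

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()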
