Skip to content

Commit

Permalink
[cleanup][4/x] unify weight casting
Browse files Browse the repository at this point in the history
Summary:

Not ready for review yet, performance regression because tensorwise
abs+max and weight casting is happening twice between fwd and bwd.
Limitation of something in PT2 stack?

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: e3da31f48634640b1b569228ad3a5d3964860acb
ghstack-comment-id: 2568319095
Pull Request resolved: #1481
  • Loading branch information
vkuzo committed Jan 8, 2025
1 parent 5b857dc commit 2291617
Showing 1 changed file with 34 additions and 61 deletions.
95 changes: 34 additions & 61 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,12 @@ def _get_weight_scale(
if tensor_already_casted_to_fp8(weight):
return None
assert scaling_type_weight is ScalingType.DYNAMIC
return tensor_to_scale(weight, config.cast_config_weight.target_dtype)


def _cast_weight_to_float8_t(
weight: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if tensor_already_casted_to_fp8(weight):
return weight.t()
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
config.cast_config_weight.target_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return weight_fp8.t()
# inductor kernels for tensorwise max are faster when `weight` is
# contiguous.
# context: https://github.com/pytorch/pytorch/issues/144431
weight_t = weight.t()
assert weight_t.is_contiguous()
return tensor_to_scale(weight_t, config.cast_config_weight.target_dtype)


@torch._dynamo.allow_in_graph
Expand Down Expand Up @@ -102,6 +89,33 @@ def forward(
weight_maybe_fp8_t = weight_hp_t
elif c.cast_config_weight.scaling_type is ScalingType.DISABLED:
weight_maybe_fp8_t = weight_hp_t
elif (
config.cast_config_weight.scaling_granularity
is ScalingGranularity.TENSORWISE
):
# Special case tensorwise scaling to allow the checkpointing of
# float8 casted weight, to prevent blowing up peak memory usage
# in FSDP.
weight_scale = _get_weight_scale(
weight_hp_t, config.cast_config_weight.scaling_type, config
)
if config.force_recompute_fp8_weight_in_bwd:
weight_maybe_fp8_t = checkpoint.checkpoint(
hp_tensor_and_scale_to_float8,
weight_hp_t,
weight_scale,
config.cast_config_weight.target_dtype,
linear_mm_config,
GemmInputRole.WEIGHT,
)
else:
weight_maybe_fp8_t = hp_tensor_and_scale_to_float8(
weight_hp_t,
weight_scale,
config.cast_config_weight.target_dtype,
linear_mm_config,
GemmInputRole.WEIGHT,
)
else:
weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
weight_hp_t,
Expand Down Expand Up @@ -294,50 +308,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
autocast_dtype = torch.get_autocast_gpu_dtype()
input = input.to(autocast_dtype)

has_any_axiswise_scaling = any(
cc.scaling_granularity is ScalingGranularity.AXISWISE
for cc in [
self.config.cast_config_input,
self.config.cast_config_weight,
self.config.cast_config_grad_output,
self.config.cast_config_input_for_grad_weight,
self.config.cast_config_weight_for_grad_input,
self.config.cast_config_grad_output_for_grad_weight,
]
)

weight_maybe_fp8_t = self.weight.t()

# TODO(future PR): check for axiswise scaling for input, weight,
# grad_output separately instead of together
if not has_any_axiswise_scaling:
# If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
# weight_scale should be saved.
weight_scale = _get_weight_scale(
self.weight, self.scaling_type_weight, self.config
)

if self.config.force_recompute_fp8_weight_in_bwd:
weight_fp8_t = checkpoint.checkpoint(
_cast_weight_to_float8_t,
self.weight,
self.config,
self.linear_mm_config,
weight_scale,
)
else:
weight_fp8_t = _cast_weight_to_float8_t(
self.weight,
self.config,
self.linear_mm_config,
weight_scale,
)

weight_maybe_fp8_t = weight_fp8_t

output = matmul_with_hp_or_float8_args.apply(
input,
weight_maybe_fp8_t,
self.weight.t(),
self.linear_mm_config,
self.config,
)
Expand Down

0 comments on commit 2291617

Please sign in to comment.