[cleanup][4/x] unify weight casting
Summary:

Not ready for review yet: there is a performance regression because the tensorwise
abs+max and the weight cast are each happening twice, once in the forward and once
in the backward. Possibly a limitation of something in the PT2 stack? See the sketch below.
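
For context, a minimal sketch of the unified tensorwise casting path this diff moves toward (function and config names are taken from the changed file; the surrounding forward logic is simplified):

    # sketch only, simplified from the diff below
    # the tensorwise scale (abs + max) is computed once, outside the
    # checkpointed region, so only the fp8 cast itself should be recomputed
    # in the backward when force_recompute_fp8_weight_in_bwd is set
    weight_scale = _get_weight_scale(
        weight_hp_t, config.cast_config_weight.scaling_type, config
    )
    if config.force_recompute_fp8_weight_in_bwd:
        weight_maybe_fp8_t = checkpoint.checkpoint(
            _cast_weight_to_float8, weight_hp_t, config, linear_mm_config, weight_scale
        )
    else:
        weight_maybe_fp8_t = _cast_weight_to_float8(
            weight_hp_t, config, linear_mm_config, weight_scale
        )

The regression described above is that the abs+max and the cast still appear to run in both the forward and the backward despite this structure.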

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 27996f8401a77ca2fc5fdf1bb2b200d3b9fd41a7
ghstack-comment-id: 2568319095
Pull Request resolved: #1481
vkuzo committed Jan 8, 2025
1 parent 8c73d60 commit 945c578
Showing 1 changed file with 47 additions and 15 deletions.
torchao/float8/float8_linear.py (47 additions, 15 deletions)

@@ -36,25 +36,30 @@ def _get_weight_scale(
     if tensor_already_casted_to_fp8(weight):
         return None
     assert scaling_type_weight is ScalingType.DYNAMIC
-    return tensor_to_scale(weight, config.cast_config_weight.target_dtype)
+    # inductor kernels for tensorwise max are faster when `weight` is
+    # contiguous.
+    # context: https://github.com/pytorch/pytorch/issues/144431
+    weight_t = weight.t()
+    assert weight_t.is_contiguous()
+    return tensor_to_scale(weight_t, config.cast_config_weight.target_dtype)
 
 
-def _cast_weight_to_float8_t(
+def _cast_weight_to_float8(
     weight: torch.Tensor,
     config: Float8LinearConfig,
     linear_mm_config: LinearMMConfig,
     weight_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     if tensor_already_casted_to_fp8(weight):
-        return weight.t()
+        return weight
     weight_fp8 = hp_tensor_and_scale_to_float8(
         weight,
         weight_scale,
         config.cast_config_weight.target_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
     )
-    return weight_fp8.t()
+    return weight_fp8
 
 
 @torch._dynamo.allow_in_graph
@@ -103,16 +108,43 @@ def forward(
         elif c.cast_config_weight.scaling_type is ScalingType.DISABLED:
             weight_maybe_fp8_t = weight_hp_t
         else:
-            weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
-                weight_hp_t,
-                c.cast_config_weight.target_dtype,
-                linear_mm_config,
-                gemm_input_role=GemmInputRole.WEIGHT,
-                scaling_granularity=c.cast_config_weight.scaling_granularity,
-                axiswise_dim=get_maybe_axiswise_dim(
-                    0, c.cast_config_weight.scaling_granularity
-                ),
-            )
+            # non-axiswise
+            if (
+                config.cast_config_weight.scaling_granularity
+                is ScalingGranularity.TENSORWISE
+            ):
+                # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
+                # weight_scale should be saved.
+                weight_scale = _get_weight_scale(
+                    weight_hp_t, config.cast_config_weight.scaling_type, config
+                )
+
+                if config.force_recompute_fp8_weight_in_bwd:
+                    weight_maybe_fp8_t = checkpoint.checkpoint(
+                        _cast_weight_to_float8,
+                        weight_hp_t,
+                        config,
+                        linear_mm_config,
+                        weight_scale,
+                    )
+                else:
+                    weight_maybe_fp8_t = _cast_weight_to_float8(
+                        weight_hp_t,
+                        config,
+                        linear_mm_config,
+                        weight_scale,
+                    )
+            else:
+                weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
+                    weight_hp_t,
+                    c.cast_config_weight.target_dtype,
+                    linear_mm_config,
+                    gemm_input_role=GemmInputRole.WEIGHT,
+                    scaling_granularity=c.cast_config_weight.scaling_granularity,
+                    axiswise_dim=get_maybe_axiswise_dim(
+                        0, c.cast_config_weight.scaling_granularity
+                    ),
+                )
 
         # the reshapes are needed in order to make the shapes compatible with
         # torch.mm
@@ -310,7 +342,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
         # TODO(future PR): check for axiswise scaling for input, weight,
         # grad_output separately instead of together
-        if not has_any_axiswise_scaling:
+        if not has_any_axiswise_scaling and False:
            # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
            # weight_scale should be saved.
            weight_scale = _get_weight_scale(
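
For readers unfamiliar with the recompute mechanism relied on above: a standalone sketch (not from this repository; cast_like_fn is a hypothetical stand-in for _cast_weight_to_float8) showing that torch.utils.checkpoint.checkpoint does not save the wrapped function's output for backward and instead re-runs the function during the backward pass:

    import torch
    from torch.utils import checkpoint

    def cast_like_fn(w):
        # hypothetical stand-in for the fp8 cast: any differentiable op works
        return w * 2.0

    w = torch.randn(4, 4, requires_grad=True)
    # `out` is not saved for backward; cast_like_fn runs again during backward()
    out = checkpoint.checkpoint(cast_like_fn, w, use_reentrant=False)
    out.sum().backward()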
