From 945c578f958d7eae62f2ec4094e8639d18813aeb Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 8 Jan 2025 14:26:13 -0800
Subject: [PATCH] [cleanup][4/x] unify weight casting

Summary:

Unify the weight casting code paths: for tensorwise scaling, the weight
scale computation and the cast to float8 now happen inside the autograd
function that wraps the matmul, instead of being precomputed in the
linear module's forward.

Not ready for review yet: there is a performance regression because the
tensorwise abs+max and the weight cast happen twice between the forward
and the backward. Possibly a limitation of something in the PT2 stack?

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 27996f8401a77ca2fc5fdf1bb2b200d3b9fd41a7
ghstack-comment-id: 2568319095
Pull Request resolved: https://github.com/pytorch/ao/pull/1481
---
 torchao/float8/float8_linear.py | 62 +++++++++++++++++++++++++--------
 1 file changed, 47 insertions(+), 15 deletions(-)

diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index 18aebaeada..f8a5381fae 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -36,17 +36,22 @@ def _get_weight_scale(
     if tensor_already_casted_to_fp8(weight):
         return None
     assert scaling_type_weight is ScalingType.DYNAMIC
-    return tensor_to_scale(weight, config.cast_config_weight.target_dtype)
+    # inductor kernels for tensorwise max are faster when `weight` is
+    # contiguous.
+    # context: https://github.com/pytorch/pytorch/issues/144431
+    weight_t = weight.t()
+    assert weight_t.is_contiguous()
+    return tensor_to_scale(weight_t, config.cast_config_weight.target_dtype)
 
 
-def _cast_weight_to_float8_t(
+def _cast_weight_to_float8(
     weight: torch.Tensor,
     config: Float8LinearConfig,
     linear_mm_config: LinearMMConfig,
     weight_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     if tensor_already_casted_to_fp8(weight):
-        return weight.t()
+        return weight
     weight_fp8 = hp_tensor_and_scale_to_float8(
         weight,
         weight_scale,
@@ -54,7 +59,7 @@ def _cast_weight_to_float8_t(
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
     )
-    return weight_fp8.t()
+    return weight_fp8
 
 
 @torch._dynamo.allow_in_graph
@@ -103,16 +108,43 @@ def forward(
         elif c.cast_config_weight.scaling_type is ScalingType.DISABLED:
             weight_maybe_fp8_t = weight_hp_t
         else:
-            weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
-                weight_hp_t,
-                c.cast_config_weight.target_dtype,
-                linear_mm_config,
-                gemm_input_role=GemmInputRole.WEIGHT,
-                scaling_granularity=c.cast_config_weight.scaling_granularity,
-                axiswise_dim=get_maybe_axiswise_dim(
-                    0, c.cast_config_weight.scaling_granularity
-                ),
-            )
+            # non-axiswise
+            if (
+                config.cast_config_weight.scaling_granularity
+                is ScalingGranularity.TENSORWISE
+            ):
+                # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
+                # weight_scale should be saved.
+                weight_scale = _get_weight_scale(
+                    weight_hp_t, config.cast_config_weight.scaling_type, config
+                )
+
+                if config.force_recompute_fp8_weight_in_bwd:
+                    weight_maybe_fp8_t = checkpoint.checkpoint(
+                        _cast_weight_to_float8,
+                        weight_hp_t,
+                        config,
+                        linear_mm_config,
+                        weight_scale,
+                    )
+                else:
+                    weight_maybe_fp8_t = _cast_weight_to_float8(
+                        weight_hp_t,
+                        config,
+                        linear_mm_config,
+                        weight_scale,
+                    )
+            else:
+                weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
+                    weight_hp_t,
+                    c.cast_config_weight.target_dtype,
+                    linear_mm_config,
+                    gemm_input_role=GemmInputRole.WEIGHT,
+                    scaling_granularity=c.cast_config_weight.scaling_granularity,
+                    axiswise_dim=get_maybe_axiswise_dim(
+                        0, c.cast_config_weight.scaling_granularity
+                    ),
+                )
 
         # the reshapes are needed in order to make the shapes compatible with
         # torch.mm
@@ -310,7 +342,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
         # TODO(future PR): check for axiswise scaling for input, weight,
         # grad_output separately instead of together
-        if not has_any_axiswise_scaling:
+        if not has_any_axiswise_scaling and False:
             # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
             # weight_scale should be saved.
             weight_scale = _get_weight_scale(
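
A minimal standalone sketch of the pattern the new TENSORWISE branch relies on:
compute the tensorwise scale once and keep it for the backward, and, when
force_recompute_fp8_weight_in_bwd is set, wrap only the cast in
torch.utils.checkpoint so the fp8 weight is recomputed from the saved scale
instead of being stored. This is not the torchao implementation:
toy_tensorwise_scale and toy_cast_to_float8 are hypothetical stand-ins for
tensor_to_scale and _cast_weight_to_float8, and the snippet assumes a PyTorch
build that ships torch.float8_e4m3fn (>= 2.1).

import torch
from torch.utils import checkpoint

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def toy_tensorwise_scale(w: torch.Tensor) -> torch.Tensor:
    # tensorwise abs+max -> scale; computed once per forward and saved,
    # analogous to _get_weight_scale above
    amax = torch.max(torch.abs(w))
    return E4M3_MAX / torch.clamp(amax, min=1e-12)


def toy_cast_to_float8(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # cast with a precomputed scale, analogous to _cast_weight_to_float8
    return (w * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)


def cast_weight(w: torch.Tensor, scale: torch.Tensor, force_recompute: bool) -> torch.Tensor:
    if force_recompute:
        # only the cast is checkpointed, so the fp8 weight is not saved for
        # backward and is re-derived from (w, scale) during recomputation
        return checkpoint.checkpoint(toy_cast_to_float8, w, scale, use_reentrant=False)
    return toy_cast_to_float8(w, scale)


w = torch.randn(4096, 4096)
scale = toy_tensorwise_scale(w)  # torchao computes this on a contiguous view of the weight
w_fp8 = cast_weight(w, scale, force_recompute=False)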