diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index 18aebaeada..000538986e 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -36,25 +36,12 @@ def _get_weight_scale(
     if tensor_already_casted_to_fp8(weight):
         return None
     assert scaling_type_weight is ScalingType.DYNAMIC
-    return tensor_to_scale(weight, config.cast_config_weight.target_dtype)
-
-
-def _cast_weight_to_float8_t(
-    weight: torch.Tensor,
-    config: Float8LinearConfig,
-    linear_mm_config: LinearMMConfig,
-    weight_scale: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    if tensor_already_casted_to_fp8(weight):
-        return weight.t()
-    weight_fp8 = hp_tensor_and_scale_to_float8(
-        weight,
-        weight_scale,
-        config.cast_config_weight.target_dtype,
-        linear_mm_config,
-        gemm_input_role=GemmInputRole.WEIGHT,
-    )
-    return weight_fp8.t()
+    # inductor kernels for tensorwise max are faster when `weight` is
+    # contiguous.
+    # context: https://github.com/pytorch/pytorch/issues/144431
+    weight_t = weight.t()
+    assert weight_t.is_contiguous()
+    return tensor_to_scale(weight_t, config.cast_config_weight.target_dtype)
 
 
 @torch._dynamo.allow_in_graph
@@ -102,6 +89,33 @@ def forward(
             weight_maybe_fp8_t = weight_hp_t
         elif c.cast_config_weight.scaling_type is ScalingType.DISABLED:
             weight_maybe_fp8_t = weight_hp_t
+        elif (
+            config.cast_config_weight.scaling_granularity
+            is ScalingGranularity.TENSORWISE
+        ):
+            # Special case tensorwise scaling to allow the checkpointing of
+            # float8 casted weight, to prevent blowing up peak memory usage
+            # in FSDP.
+            weight_scale = _get_weight_scale(
+                weight_hp_t, config.cast_config_weight.scaling_type, config
+            )
+            if config.force_recompute_fp8_weight_in_bwd:
+                weight_maybe_fp8_t = checkpoint.checkpoint(
+                    hp_tensor_and_scale_to_float8,
+                    weight_hp_t,
+                    weight_scale,
+                    config.cast_config_weight.target_dtype,
+                    linear_mm_config,
+                    GemmInputRole.WEIGHT,
+                )
+            else:
+                weight_maybe_fp8_t = hp_tensor_and_scale_to_float8(
+                    weight_hp_t,
+                    weight_scale,
+                    config.cast_config_weight.target_dtype,
+                    linear_mm_config,
+                    GemmInputRole.WEIGHT,
+                )
         else:
             weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
@@ -294,50 +308,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        has_any_axiswise_scaling = any(
-            cc.scaling_granularity is ScalingGranularity.AXISWISE
-            for cc in [
-                self.config.cast_config_input,
-                self.config.cast_config_weight,
-                self.config.cast_config_grad_output,
-                self.config.cast_config_input_for_grad_weight,
-                self.config.cast_config_weight_for_grad_input,
-                self.config.cast_config_grad_output_for_grad_weight,
-            ]
-        )
-
-        weight_maybe_fp8_t = self.weight.t()
-
-        # TODO(future PR): check for axiswise scaling for input, weight,
-        # grad_output separately instead of together
-        if not has_any_axiswise_scaling:
-            # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
-            # weight_scale should be saved.
-            weight_scale = _get_weight_scale(
-                self.weight, self.scaling_type_weight, self.config
-            )
-
-            if self.config.force_recompute_fp8_weight_in_bwd:
-                weight_fp8_t = checkpoint.checkpoint(
-                    _cast_weight_to_float8_t,
-                    self.weight,
-                    self.config,
-                    self.linear_mm_config,
-                    weight_scale,
-                )
-            else:
-                weight_fp8_t = _cast_weight_to_float8_t(
-                    self.weight,
-                    self.config,
-                    self.linear_mm_config,
-                    weight_scale,
-                )
-
-            weight_maybe_fp8_t = weight_fp8_t
-
         output = matmul_with_hp_or_float8_args.apply(
             input,
-            weight_maybe_fp8_t,
+            self.weight.t(),
            self.linear_mm_config,
            self.config,
        )
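
Note on the new TENSORWISE branch above: the scale is computed eagerly (and therefore saved for backward), while only the high-precision to float8 cast goes through `checkpoint.checkpoint`, so when `force_recompute_fp8_weight_in_bwd` is set the float8 copy of the weight is rebuilt during backward instead of being kept alive by autograd, which is what keeps peak memory down under FSDP. The sketch below only illustrates that pattern and is not torchao code: `tensorwise_scale`, `fake_fp8_cast`, and `cast_weight` are hypothetical names, and the cast is simulated in high precision so the example runs without float8 matmul support (torchao's real cast, `hp_tensor_and_scale_to_float8`, returns a `Float8Tensor`).

```python
import torch
from torch.utils import checkpoint

# Max representable value of the e4m3 float8 format (~448).
F8_MAX = torch.finfo(torch.float8_e4m3fn).max


def tensorwise_scale(weight_t: torch.Tensor) -> torch.Tensor:
    # One scale for the whole tensor, from the global amax. Transposing back
    # to the contiguous layout mirrors the _get_weight_scale change above.
    amax = weight_t.detach().t().contiguous().abs().max()
    return F8_MAX / torch.clamp(amax, min=1e-12)


def fake_fp8_cast(weight_t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for hp_tensor_and_scale_to_float8: scale and
    # saturate, but stay in high precision so the example is runnable.
    return (weight_t * scale).clamp(-F8_MAX, F8_MAX) / scale


def cast_weight(weight_t: torch.Tensor, force_recompute_in_bwd: bool) -> torch.Tensor:
    # The scale is computed outside the checkpointed region, so it is saved;
    # only the cast itself is recomputed in backward when recompute is forced.
    scale = tensorwise_scale(weight_t)
    if force_recompute_in_bwd:
        return checkpoint.checkpoint(
            fake_fp8_cast, weight_t, scale, use_reentrant=False
        )
    return fake_fp8_cast(weight_t, scale)


if __name__ == "__main__":
    w = torch.randn(128, 64, requires_grad=True)  # (out_features, in_features)
    x = torch.randn(32, 64)                       # (batch, in_features)
    y = x @ cast_weight(w.t(), force_recompute_in_bwd=True)
    y.sum().backward()
    print(w.grad.shape)  # torch.Size([128, 64])
```

As in the patch, only the float8 cast is recomputed; the removed comment in the old `Float8Linear.forward` path ("we only recompute the fp8 weight, weight_scale should be saved") describes exactly this split.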