From fb3b255eb9cef8c4d0dd6d46b226e36b373479f6 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Mon, 23 Dec 2024 13:42:39 -0800
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torchao/float8/float8_tensor_parallel.py | 27 +++++++++++++-----------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py
index 37cb67c7e7..9d45196cf3 100644
--- a/torchao/float8/float8_tensor_parallel.py
+++ b/torchao/float8/float8_tensor_parallel.py
@@ -9,6 +9,7 @@
 )
 
 from torchao.float8.config import ScalingType, e4m3_dtype
+from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
     NoopFwToFloat8BwDynamic,
     hp_tensor_to_float8_dynamic,
@@ -46,12 +47,13 @@ def _prepare_input_fn(
             input_tensor, device_mesh, input_layouts, run_check=False
         )
 
-        input_tensor = hp_tensor_to_float8_dynamic(
-            input_tensor,
-            mod.config.cast_config_input.target_dtype,
-            mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )  # DTensor(Float8Tensor)
+        if not tensor_already_casted_to_fp8(input_tensor):
+            input_tensor = hp_tensor_to_float8_dynamic(
+                input_tensor,
+                mod.config.cast_config_input.target_dtype,
+                mod.linear_mm_config,
+                gemm_input_role=GemmInputRole.INPUT,
+            )  # DTensor(Float8Tensor)
 
         # transform the input layouts to the desired layouts of ColwiseParallel
         if input_layouts != desired_input_layouts:
@@ -104,12 +106,13 @@ def _prepare_input_fn(
             input_tensor, device_mesh, input_layouts, run_check=False
         )
 
-        input_tensor = hp_tensor_to_float8_dynamic(
-            input_tensor,
-            mod.config.cast_config_input.target_dtype,
-            mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )  # DTensor(Float8Tensor)
+        if not tensor_already_casted_to_fp8(input_tensor):
+            input_tensor = hp_tensor_to_float8_dynamic(
+                input_tensor,
+                mod.config.cast_config_input.target_dtype,
+                mod.linear_mm_config,
+                gemm_input_role=GemmInputRole.INPUT,
+            )  # DTensor(Float8Tensor)
 
         if input_layouts != desired_input_layouts:
             input_tensor = input_tensor.redistribute(