Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vkuzo committed Dec 23, 2024
1 parent 09821f0 commit fb3b255
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions torchao/float8/float8_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)

from torchao.float8.config import ScalingType, e4m3_dtype
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_scaling_utils import (
NoopFwToFloat8BwDynamic,
hp_tensor_to_float8_dynamic,
Expand Down Expand Up @@ -46,12 +47,13 @@ def _prepare_input_fn(
input_tensor, device_mesh, input_layouts, run_check=False
)

input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
mod.config.cast_config_input.target_dtype,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)
if not tensor_already_casted_to_fp8(input_tensor):
input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
mod.config.cast_config_input.target_dtype,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)

# transform the input layouts to the desired layouts of ColwiseParallel
if input_layouts != desired_input_layouts:
Expand Down Expand Up @@ -104,12 +106,13 @@ def _prepare_input_fn(
input_tensor, device_mesh, input_layouts, run_check=False
)

input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
mod.config.cast_config_input.target_dtype,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)
if not tensor_already_casted_to_fp8(input_tensor):
input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
mod.config.cast_config_input.target_dtype,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)

if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(
Expand Down

0 comments on commit fb3b255

Please sign in to comment.