
Commit

Update
[ghstack-poisoned]
vkuzo committed Dec 23, 2024
1 parent eab345c commit 09821f0
Showing 3 changed files with 13 additions and 10 deletions.
17 changes: 10 additions & 7 deletions torchao/float8/float8_linear.py
@@ -312,13 +312,16 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        assert self.scaling_type_input is ScalingType.DYNAMIC
-        input_fp8 = hp_tensor_to_float8_dynamic(
-            input,
-            self.config.cast_config_input.target_dtype,
-            self.linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )
+        if tensor_already_casted_to_fp8(input):
+            input_fp8 = input
+        else:
+            assert self.scaling_type_input is ScalingType.DYNAMIC
+            input_fp8 = hp_tensor_to_float8_dynamic(
+                input,
+                self.config.cast_config_input.target_dtype,
+                self.linear_mm_config,
+                gemm_input_role=GemmInputRole.INPUT,
+            )
         return input_fp8
 
     def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
2 changes: 0 additions & 2 deletions torchao/float8/float8_scaling_utils.py
@@ -52,8 +52,6 @@ def hp_tensor_to_float8_dynamic(
         scaling_granularity: Defines the scaling granularity
         axiswise_dim: if axiswise granularity is used, defines the dim to scale across
     """
-    if tensor_already_casted_to_fp8(hp_tensor):
-        return hp_tensor
     scale = tensor_to_scale(
         hp_tensor,
         float8_dtype,
4 changes: 3 additions & 1 deletion torchao/float8/stateful_float8_linear.py
@@ -153,7 +153,9 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        if self.scaling_type_input is ScalingType.DELAYED:
+        if tensor_already_casted_to_fp8(input):
+            input_fp8 = input
+        elif self.scaling_type_input is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 input,
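
Taken together, the three hunks hoist the "is this already float8?" check out of hp_tensor_to_float8_dynamic and into its callers, so cast_input_to_float8 returns a pre-cast input unchanged instead of relying on the utility to short-circuit. The sketch below illustrates that caller-side pattern under stated assumptions: it uses plain PyTorch (float8 dtypes require PyTorch 2.1+), and SimpleFloat8Tensor, already_casted, and cast_input are hypothetical stand-ins for torchao's Float8Tensor, tensor_already_casted_to_fp8, and cast_input_to_float8; it is not the library's actual implementation.

# Minimal sketch, not torchao's real code: SimpleFloat8Tensor, already_casted,
# and cast_input are hypothetical stand-ins used to show the short-circuit.
import torch


class SimpleFloat8Tensor:
    """Toy wrapper marking a tensor that has already been cast to float8."""

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self.data = data    # float8 payload
        self.scale = scale  # per-tensor scale used for the cast


def already_casted(t) -> bool:
    # Stand-in for tensor_already_casted_to_fp8(); here it just checks the type.
    return isinstance(t, SimpleFloat8Tensor)


def cast_input(t, float8_dtype=torch.float8_e4m3fn):
    # Caller-side behavior after this commit: reuse an input that is already
    # float8 instead of re-deriving a scale and casting it a second time.
    if already_casted(t):
        return t
    # Dynamic scaling: derive a per-tensor scale from the current max magnitude.
    amax = t.abs().max().to(torch.float32)
    scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)
    data = (t.to(torch.float32) * scale).to(float8_dtype)
    return SimpleFloat8Tensor(data, scale)


x = torch.randn(16, 32)
fp8_once = cast_input(x)
fp8_twice = cast_input(fp8_once)  # short-circuits; no double cast
assert fp8_twice is fp8_once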
