Skip to content

Commit

Permalink
Fix DeprecationWarning
Browse files Browse the repository at this point in the history
  • Loading branch information
jcaip authored Jan 8, 2025
1 parent eb49333 commit 69cb271
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
autocast_dtype = torch.get_autocast_gpu_dtype()
autocast_dtype = torch.get_autocast_gpu_dtype('cuda')
input = input.to(autocast_dtype)

assert self.scaling_type_input is ScalingType.DYNAMIC
Expand Down

0 comments on commit 69cb271

Please sign in to comment.