From 69cb271be683776d6c6e067babf3585bd6dfae9e Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 8 Jan 2025 00:21:58 -0500 Subject: [PATCH] Fix DeprecationWarning --- torchao/float8/float8_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index b7a3449277..c492aece38 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -309,7 +309,7 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor: if torch.is_autocast_enabled(): # For now, hardcode to GPU's autocast dtype # if we need CPU support in the future, we can add it - autocast_dtype = torch.get_autocast_gpu_dtype() + autocast_dtype = torch.get_autocast_gpu_dtype('cuda') input = input.to(autocast_dtype) assert self.scaling_type_input is ScalingType.DYNAMIC