diff --git a/README.md b/README.md
index fa093c3..ff19b93 100644
--- a/README.md
+++ b/README.md
@@ -71,7 +71,7 @@ m = Model(...)
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
 # type
 swap_linear_with_float8_linear(
-    m, 
+    m,
     Float8Linear,
     scaling_type_x=TensorScalingType.DELAYED,
     scaling_type_w=TensorScalingType.DELAYED,
diff --git a/benchmarks/bench_multi_gpu.py b/benchmarks/bench_multi_gpu.py
index 12a1ddb..00a549c 100644
--- a/benchmarks/bench_multi_gpu.py
+++ b/benchmarks/bench_multi_gpu.py
@@ -14,7 +14,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py
index 1ef5478..503a01a 100644
--- a/benchmarks/profile_linear_float8.py
+++ b/benchmarks/profile_linear_float8.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py
index f48424c..7f44363 100644
--- a/float8_experimental/float8_dynamic_utils.py
+++ b/float8_experimental/float8_dynamic_utils.py
@@ -9,10 +9,7 @@
 
 from typing import Any, Optional, Tuple
 
-import float8_experimental.config as config
-
 import torch
-import torch.nn as nn
 import torch.utils._pytree as pytree
 
 from float8_experimental.float8_tensor import (
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index 945f7a6..5d49e65 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -3,10 +3,8 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import copy
 import logging
-from enum import auto, Enum
-from typing import Callable, List, Optional, Type, Union
+from typing import Callable, List, Optional
 
 import torch
 import torch.distributed as dist
diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index 2088b78..8aada4b 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -15,7 +15,7 @@
 import torch.nn.functional as F
 
 from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_tensor_parallel import (
diff --git a/test/test_fsdp.py b/test/test_fsdp.py
index 79bba19..48b28da 100644
--- a/test/test_fsdp.py
+++ b/test/test_fsdp.py
@@ -21,7 +21,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
@@ -149,7 +149,7 @@ def forward_backward(model, optim, is_fp8, i):
         model_fp8 = torch.compile(model_fp8)
     y_local = forward_backward(model, optimizer, is_fp8=False, i=i)
     y_local_fp8 = forward_backward(model_fp8, optimizer_fp8, is_fp8=True, i=i)
-    local_sqnr = compute_error(y_local, y_local_fp8)
+    local_sqnr = compute_error(y_local, y_local_fp8)  # noqa: F841
 
     # get global y
     y_global = [
diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py
index c20e8cc..9d42b56 100644
--- a/test/test_fsdp2/test_fsdp2_common.py
+++ b/test/test_fsdp2/test_fsdp2_common.py
@@ -1,12 +1,11 @@
 import contextlib
-from typing import List, Type
+from typing import List
 
 import float8_experimental.config as config
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear
 
 
 def check_parity_no_mp(
diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py
index 57123cd..5ca483f 100644
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2_eager.py
@@ -1,5 +1,4 @@
 import copy
-import itertools
 import threading
 import unittest
 from typing import Any, List
@@ -9,7 +8,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from test_fsdp2_common import (
     check_parity_bf16_mp,
diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py
index 715db29..3f1b5dc 100644
--- a/test/test_fsdp_compile.py
+++ b/test/test_fsdp_compile.py
@@ -18,7 +18,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 from float8_experimental import config
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
diff --git a/test/test_inference_flows.py b/test/test_inference_flows.py
index 1dd09d9..55543ae 100644
--- a/test/test_inference_flows.py
+++ b/test/test_inference_flows.py
@@ -13,7 +13,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import compute_error
diff --git a/test/test_numerics_integration.py b/test/test_numerics_integration.py
index 401d0fd..845c9ea 100644
--- a/test/test_numerics_integration.py
+++ b/test/test_numerics_integration.py
@@ -14,7 +14,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,