From f7f20e9e8ba28a4910efb62e7b18243143d8d855 Mon Sep 17 00:00:00 2001
From: Apurva Jain
Date: Fri, 3 Jan 2025 13:42:42 -0800
Subject: [PATCH] Lint fixes for kernel folder (#1487)

---
 ruff.toml                      |  1 +
 torchao/kernel/__init__.py     |  3 +-
 torchao/kernel/autotuner.py    |  3 +-
 torchao/kernel/intmm.py        | 26 ++++++++++++------
 torchao/kernel/intmm_triton.py | 50 +++++++++++++++-------------------
 5 files changed, 43 insertions(+), 40 deletions(-)

diff --git a/ruff.toml b/ruff.toml
index 4c3e00c9e6..cef8b0d076 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -10,6 +10,7 @@ include = [
     "torchao/profiler/**/*.py",
     "torchao/testing/**/*.py",
     "torchao/_models/**/*.py",
+    "torchao/kernel/**/*.py",
     "torchao/prototype/low_bit_optim/**.py",
     "torchao/utils.py",
     "torchao/ops.py",
diff --git a/torchao/kernel/__init__.py b/torchao/kernel/__init__.py
index 3eeba6f535..409da72601 100644
--- a/torchao/kernel/__init__.py
+++ b/torchao/kernel/__init__.py
@@ -1,5 +1,4 @@
-from torchao.kernel.intmm import int_scaled_matmul
-from torchao.kernel.intmm import safe_int_mm
+from torchao.kernel.intmm import int_scaled_matmul, safe_int_mm
 
 __all__ = [
     "safe_int_mm",
diff --git a/torchao/kernel/autotuner.py b/torchao/kernel/autotuner.py
index ac6ab6feb0..87d644b26c 100644
--- a/torchao/kernel/autotuner.py
+++ b/torchao/kernel/autotuner.py
@@ -1,7 +1,6 @@
 import logging
 import os
 import pathlib
-import pickle
 
 import torch
 import triton
@@ -173,7 +172,7 @@ def wrapped_fn():
     # Run it once and skip if it crashes or is 100x slower
     try:
         time = do_bench_basic(wrapped_fn, 1)
-    except RuntimeError as e:
+    except RuntimeError:
         time = None
     except triton.runtime.OutOfResources:
         time = None
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index afc5bcfa3f..3079cf0104 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -1,5 +1,5 @@
-import itertools
 import os
+
 import torch
 
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_6
@@ -21,6 +21,7 @@
 if TORCH_VERSION_AT_LEAST_2_2:
     from torch._dynamo import is_compiling as dynamo_is_compiling
     from torch._higher_order_ops.out_dtype import out_dtype
+
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         Performs a safe integer matrix multiplication, considering different paths for
@@ -40,7 +41,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
             if input.device.type == "cpu":
                 # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend
-                return out_dtype(torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float())
+                return out_dtype(
+                    torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float()
+                )
             return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
 
         # error checking for cublas path
@@ -60,9 +63,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
 
         if device_cpu or bad_dimensions_for_cublas:
             # fallback path
-            return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
-                input.device.type
-            )
+            return torch.matmul(
+                input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)
+            ).to(input.device.type)
 
         # cublas paths
         if not mat2.is_contiguous():  # silently gives incorrect result without this
@@ -78,8 +81,11 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         except Exception:
             # fallback path, would run on H100 for float8 dtypes
             # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
-            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(
+                torch.int32
+            )
 else:
+
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         Performs a fallback integer matrix multiplication for torch versions before 2.2.
@@ -93,7 +99,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         # We can improve on this by writing Triton code that works for older versions of Triton
         # that ship with 2.1 or 2.0.
-        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
+        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(
+            torch.int32
+        )
 
 
 def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
@@ -113,7 +121,9 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     return safe_int_mm(a, b)
 
 
-def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
+def int_scaled_matmul(
+    a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor
+) -> torch.Tensor:
     """
     Performs scaled integer matrix multiplication.
 
diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py
index f6f42e2f53..bb73a0f5db 100644
--- a/torchao/kernel/intmm_triton.py
+++ b/torchao/kernel/intmm_triton.py
@@ -1,8 +1,6 @@
 import itertools
-import os
 
 import torch
-
 import triton
 import triton.language as tl
 
@@ -10,33 +8,29 @@ from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 # TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE to enable exhaustive option
-int8_mm_kernel_configs = (
-    sum(
+int8_mm_kernel_configs = sum(
+    [
+        # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
         [
-            # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
-            [
-                (i, j, k, 1, 1),
-                (i, j, k, 1, 2),
-                (i, j, k, 2, 2),
-                (i, j, k, 1, 4),
-                (i, j, k, 2, 4),
-                (i, j, k, 3, 4),
-                (i, j, k, 4, 4),
-                (i, j, k, 1, 8),
-                (i, j, k, 2, 8),
-                (i, j, k, 3, 8),
-                (i, j, k, 4, 8),
-                (i, j, k, 5, 8),
-                (i, j, k, 6, 8),
-                (i, j, k, 7, 8),
-                (i, j, k, 8, 8),
-            ]
-            for (i, j, k) in itertools.product(
-                [32, 64, 128, 256], repeat=3
-            )
-        ],
-        []
-    )
+            (i, j, k, 1, 1),
+            (i, j, k, 1, 2),
+            (i, j, k, 2, 2),
+            (i, j, k, 1, 4),
+            (i, j, k, 2, 4),
+            (i, j, k, 3, 4),
+            (i, j, k, 4, 4),
+            (i, j, k, 1, 8),
+            (i, j, k, 2, 8),
+            (i, j, k, 3, 8),
+            (i, j, k, 4, 8),
+            (i, j, k, 5, 8),
+            (i, j, k, 6, 8),
+            (i, j, k, 7, 8),
+            (i, j, k, 8, 8),
+        ]
+        for (i, j, k) in itertools.product([32, 64, 128, 256], repeat=3)
+    ],
+    [],
 )
 
 
 if TORCH_VERSION_AFTER_2_5:
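A usage sketch for reviewers, not part of the patch itself: it exercises the two
public entry points touched by these lint fixes, on CPU, to confirm the pure
reformatting preserves behavior. The shapes are illustrative, and the [M, 1]
float32 layout and contiguity of scales1 are assumptions read off the docstrings,
not guarantees stated in this diff:

    import torch

    from torchao.kernel import int_scaled_matmul, safe_int_mm

    # int8 x int8 -> int32; safe_int_mm chooses the path internally
    # (out_dtype under compile, cuBLAS on CUDA, or a CPU/float32 fallback).
    a = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
    b = torch.randint(-128, 127, (32, 16), dtype=torch.int8)
    c = safe_int_mm(a, b)
    assert c.dtype == torch.int32 and c.shape == (64, 16)

    # Scaled variant: assumed one contiguous float32 scale per output row.
    scales1 = torch.rand(64, 1, dtype=torch.float32)
    d = int_scaled_matmul(a, b, scales1)
    assert d.shape == (64, 16)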