Commit f7f20e9

Lint fixes for kernel folder (#1487)

jainapurva authored Jan 3, 2025
1 parent c59bce5 commit f7f20e9
Showing 5 changed files with 43 additions and 40 deletions.
ruff.toml — 1 change: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ include = [
     "torchao/profiler/**/*.py",
     "torchao/testing/**/*.py",
     "torchao/_models/**/*.py",
+    "torchao/kernel/**/*.py",
     "torchao/prototype/low_bit_optim/**.py",
     "torchao/utils.py",
     "torchao/ops.py",
torchao/kernel/__init__.py — 3 changes: 1 addition & 2 deletions

@@ -1,5 +1,4 @@
-from torchao.kernel.intmm import int_scaled_matmul
-from torchao.kernel.intmm import safe_int_mm
+from torchao.kernel.intmm import int_scaled_matmul, safe_int_mm

 __all__ = [
     "safe_int_mm",
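Note: merging the two imports (an isort-style fix under ruff's import rules) does not change the package surface. A minimal usage sketch, assuming a torchao build with these kernels available:

import torch

from torchao.kernel import int_scaled_matmul, safe_int_mm  # both names still resolve

a = torch.randint(-128, 127, (16, 32), dtype=torch.int8)
b = torch.randint(-128, 127, (32, 8), dtype=torch.int8)
print(safe_int_mm(a, b).dtype)  # torch.int32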
torchao/kernel/autotuner.py — 3 changes: 1 addition & 2 deletions

@@ -1,7 +1,6 @@
 import logging
 import os
 import pathlib
-import pickle

 import torch
 import triton
@@ -173,7 +172,7 @@ def wrapped_fn():
     # Run it once and skip if it crashes or is 100x slower
     try:
         time = do_bench_basic(wrapped_fn, 1)
-    except RuntimeError as e:
+    except RuntimeError:
         time = None
     except triton.runtime.OutOfResources:
         time = None
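Note: the second hunk drops an exception binding that was never read (ruff reports this as F841, local variable assigned but never used). A minimal sketch of the pattern with illustrative names, not the autotuner's actual helpers:

def bench_or_none(fn):
    # Run fn once; report None instead of propagating a crash, mirroring how
    # the autotuner skips configs that fail. Writing "except RuntimeError as e:"
    # here would trip F841 because e is never used; the bare clause is the
    # lint-clean equivalent.
    try:
        return fn()
    except RuntimeError:
        return None


def boom():
    raise RuntimeError("simulated kernel failure")


print(bench_or_none(lambda: 42))  # 42
print(bench_or_none(boom))        # None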
torchao/kernel/intmm.py — 26 changes: 18 additions & 8 deletions

@@ -1,5 +1,5 @@
 import itertools
 import os

 import torch

 from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_6
@@ -21,6 +21,7 @@
 if TORCH_VERSION_AT_LEAST_2_2:
     from torch._dynamo import is_compiling as dynamo_is_compiling
     from torch._higher_order_ops.out_dtype import out_dtype
+
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         Performs a safe integer matrix multiplication, considering different paths for
@@ -40,7 +41,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
             if input.device.type == "cpu":
                 # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend
-                return out_dtype(torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float())
+                return out_dtype(
+                    torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float()
+                )
             return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)

         # error checking for cublas path
@@ -60,9 +63,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:

         if device_cpu or bad_dimensions_for_cublas:
             # fallback path
-            return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
-                input.device.type
-            )
+            return torch.matmul(
+                input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)
+            ).to(input.device.type)

         # cublas paths
         if not mat2.is_contiguous():  # silently gives incorrect result without this
@@ -78,8 +81,11 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         except Exception:
             # fallback path, would run on H100 for float8 dtypes
             # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
-            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(
+                torch.int32
+            )
 else:
+
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         Performs a fallback integer matrix multiplication for torch versions before 2.2.
@@ -93,7 +99,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         # We can improve on this by writing Triton code that works for older versions of Triton
         # that ship with 2.1 or 2.0.
-        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
+        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(
+            torch.int32
+        )


 def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
@@ -113,7 +121,9 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     return safe_int_mm(a, b)


-def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
+def int_scaled_matmul(
+    a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor
+) -> torch.Tensor:
     """
     Performs scaled integer matrix multiplication.
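Note: the reflowed signatures are behavior-preserving; both entry points wrap the same int8-by-int8 matmul with int32 accumulation, and int_scaled_matmul additionally applies scales to the result. A hypothetical call sketch (the [M, 1] scales layout is an assumption inferred from the signature, not stated in this diff):

import torch

from torchao.kernel import int_scaled_matmul, safe_int_mm

M, K, N = 64, 32, 16
a = torch.randint(-128, 127, (M, K), dtype=torch.int8)
b = torch.randint(-128, 127, (K, N), dtype=torch.int8)

c = safe_int_mm(a, b)      # int32 result via the cuBLAS or fallback path
scales = torch.rand(M, 1)  # assumed: one scale per output row
d = int_scaled_matmul(a, b, scales)
print(c.dtype, d.shape)    # torch.int32 torch.Size([64, 16])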
torchao/kernel/intmm_triton.py — 50 changes: 22 additions & 28 deletions

@@ -1,42 +1,36 @@
 import itertools
 import os

 import torch

 import triton
 import triton.language as tl

 from torchao.kernel.autotuner import get_best_config_fn
 from torchao.utils import TORCH_VERSION_AFTER_2_5

 # TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE to enable exhaustive option
-int8_mm_kernel_configs = (
-    sum(
-        [
-            # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
-            [
-                (i, j, k, 1, 1),
-                (i, j, k, 1, 2),
-                (i, j, k, 2, 2),
-                (i, j, k, 1, 4),
-                (i, j, k, 2, 4),
-                (i, j, k, 3, 4),
-                (i, j, k, 4, 4),
-                (i, j, k, 1, 8),
-                (i, j, k, 2, 8),
-                (i, j, k, 3, 8),
-                (i, j, k, 4, 8),
-                (i, j, k, 5, 8),
-                (i, j, k, 6, 8),
-                (i, j, k, 7, 8),
-                (i, j, k, 8, 8),
-            ]
-            for (i, j, k) in itertools.product(
-                [32, 64, 128, 256], repeat=3
-            )
-        ],
-        []
-    )
-)
+int8_mm_kernel_configs = sum(
+    [
+        # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
+        [
+            (i, j, k, 1, 1),
+            (i, j, k, 1, 2),
+            (i, j, k, 2, 2),
+            (i, j, k, 1, 4),
+            (i, j, k, 2, 4),
+            (i, j, k, 3, 4),
+            (i, j, k, 4, 4),
+            (i, j, k, 1, 8),
+            (i, j, k, 2, 8),
+            (i, j, k, 3, 8),
+            (i, j, k, 4, 8),
+            (i, j, k, 5, 8),
+            (i, j, k, 6, 8),
+            (i, j, k, 7, 8),
+            (i, j, k, 8, 8),
+        ]
+        for (i, j, k) in itertools.product([32, 64, 128, 256], repeat=3)
+    ],
+    [],
+)

 if TORCH_VERSION_AFTER_2_5:
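Note: the reformatted comprehension is easy to sanity-check: sum(list_of_lists, []) concatenates the inner lists, so the search space is 4**3 = 64 (BLOCK_M, BLOCK_N, BLOCK_K) triples times 15 (num_stages, num_warps) variants, i.e. 960 candidate configs. A trimmed-down sketch of the same flattening:

import itertools

# Three (num_stages, num_warps) variants instead of the fifteen above.
variants = [(1, 1), (1, 2), (2, 2)]
configs = sum(
    [
        [(i, j, k, s, w) for (s, w) in variants]
        for (i, j, k) in itertools.product([32, 64, 128, 256], repeat=3)
    ],
    [],
)
print(len(configs))  # 192 == 4**3 * 3; the full table yields 4**3 * 15 == 960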
