Commit

Merge branch 'pytorch:main' into main

weifengpy authored Apr 28, 2024
2 parents 6f834ce + 8f67d29 commit a8a5aaa
Showing 9 changed files with 583 additions and 15 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/nightly_smoke_test.yml
@@ -0,0 +1,40 @@
name: PyTorch CUDA Nightly Smoke Test

on:
  schedule:
    # 6 am PST every day
    - cron: "0 14 * * *"
  workflow_dispatch:

concurrency:
  group: regression_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
  cancel-in-progress: true

env:
  HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}

jobs:
  test:
    strategy:
      fail-fast: false
      matrix:
        include:
          - name: CUDA Nightly
            runs-on: linux.g5.12xlarge.nvidia.gpu
            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
            gpu-arch-type: "cuda"
            gpu-arch-version: "12.1"

    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    with:
      runner: ${{ matrix.runs-on }}
      gpu-arch-type: ${{ matrix.gpu-arch-type }}
      gpu-arch-version: ${{ matrix.gpu-arch-version }}
      script: |
        python -m pip install --upgrade pip
        pip install ${{ matrix.torch-spec }}
        pip install -r requirements.txt
        pip install -r dev-requirements.txt
        python setup.py install
        pytest test --verbose -s
4 changes: 2 additions & 2 deletions .github/workflows/regression_test.yml
@@ -31,9 +31,9 @@ jobs:
           torch-spec: 'torch==2.3.0'
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.1"
-        - name: CUDA Nightly
+        - name: CUDA 2.4.0.dev20240421
           runs-on: linux.g5.12xlarge.nvidia.gpu
-          torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
+          torch-spec: '--pre torch==2.4.0.dev20240421+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.1"
         - name: CPU 2.2.2
129 changes: 129 additions & 0 deletions benchmarks/benchmark_sam.py
@@ -0,0 +1,129 @@
import pandas as pd
import torch
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer
from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,
    _is_linear,
    QuantizedLinearWeightBase,
    Int8DynamicallyQuantizedLinearWeight,
)
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from torchao.sparsity import (
    apply_sparse_semi_structured,
    apply_fake_sparsity,
)
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from itertools import product
from tqdm import tqdm

sam_checkpoint_base_path = "/home/jessecai/local/MODELS"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"

torch._inductor.config.epilogue_fusion = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True

@torch.no_grad()
def benchmark(f, *args, **kwargs):
    for _ in range(3):
        f(*args, **kwargs)
        torch.cuda.synchronize()

    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
    return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}

def get_sam_model(only_one_block=False, batchsize=1):
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
    model = sam.image_encoder.eval()
    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')

    # code to use just a single block of the model
    if only_one_block:
        model = model.blocks[0]
        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
    return model, image

def qkv_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'qkv' in name

def proj_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'proj' in name

def lin1_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'lin1' in name

def lin2_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'lin2' in name

SUBCLASSES = {
    "quant": Int8DynamicallyQuantizedLinearWeight,
    "quant+sparse (cutlass)": Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight,
    "quant+sparse (cusparselt)": Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight,
    "sparse (cutlass)": SparseSemiStructuredTensorCUTLASS,
    "sparse (cusparselt)": SparseSemiStructuredTensorCUSPARSELT,
}

def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=None, lin1=None, lin2=None):
    res = {
        "block_only": block_only,
        "batchsize": batchsize,
        "dtype": dtype,
        "compile": compile,
        "qkv": qkv,
        "proj": proj,
        "lin1": lin1,
        "lin2": lin2,
    }
    with torch.no_grad():
        model, image = get_sam_model(block_only, batchsize)
        model = model.to(dtype)
        image = image.to(dtype)

        # 2:4 prune model
        apply_fake_sparsity(model)
        option_and_filter_fn = zip([qkv, proj, lin1, lin2], [qkv_only, proj_only, lin1_only, lin2_only])

        for option, filter_fn in option_and_filter_fn:
            subclass = SUBCLASSES.get(option, None)
            if subclass and issubclass(subclass, SparseSemiStructuredTensor):
                # replace with to_sparse_semi_structured
                for name, mod in model.named_modules():
                    if filter_fn(mod, name):
                        mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight))
            elif subclass and issubclass(subclass, QuantizedLinearWeightBase):
                _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn)

        if compile:
            model = torch.compile(model, mode='max-autotune')

        res.update(benchmark(model, image))
        res["img/s"] = 1 / (res['time'] / 1000 / res['batchsize'])
        return res

if __name__ == "__main__":
    print("BENCHMARKING")
    ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")]
    # for option in tqdm(SUBCLASSES)]
    # ALL_RUNS = [
    #     run_once(),
    #     run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"),
    #     run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"),
    #     run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
    #     run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"),
    #     run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
    #     run_once(qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
    #     run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
    # ]
    df = pd.DataFrame(ALL_RUNS)
    df.to_csv("sam_benchmark_results.csv")
    print(df)
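
A note on the pattern above: run_once() accelerates a layer by swapping the weight of a matching nn.Linear for a tensor subclass in place. A minimal standalone sketch of that swap, assuming a CUDA device, fp16, and a toy Linear(128, 128) layer (a stand-in for the SAM layers, not part of the benchmark):

import torch
from torch.sparse import SparseSemiStructuredTensorCUTLASS

lin = torch.nn.Linear(128, 128).half().cuda()

# Fake 2:4 prune: zero two of every four contiguous weights per row,
# which satisfies the semi-structured sparsity constraint.
w = lin.weight.detach().clone()
w.view(-1, 4)[:, :2] = 0

# Swap the dense weight for the sparse subclass, as run_once() does;
# subsequent forward passes dispatch to the 2:4 sparse CUTLASS kernels.
lin.weight = torch.nn.Parameter(SparseSemiStructuredTensorCUTLASS.from_dense(w))
out = lin(torch.rand(128, 128, device='cuda', dtype=torch.half))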
5 changes: 5 additions & 0 deletions benchmarks/sam_benchmark_results.csv
@@ -0,0 +1,5 @@
,block_only,batchsize,dtype,compile,qkv,proj,lin1,lin2,time,memory,img/s
0,False,32,torch.bfloat16,True,,,,,1457.0417301729321,28.280423936,21.96230851686177
1,False,32,torch.bfloat16,True,quant,quant,quant,quant,1318.5919532552361,28.261341696,24.268311300551254
2,False,32,torch.bfloat16,True,quant+sparse (cusparselt),quant,quant+sparse (cutlass),quant+sparse (cutlass),1253.1237555667758,28.18694656,25.536184960061433
3,False,32,torch.bfloat16,True,quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),1290.4946617782116,27.837008896,24.796693041648258
12 changes: 0 additions & 12 deletions test/integration/test_integration.py
@@ -641,8 +641,6 @@ def test__int_mm(self):
         torch.testing.assert_close(y_ref, y_opt, atol=0, rtol=0)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)")
-
     def test__int_mm_eager_and_torch_compile_numerics(self):
         def __int_mm_ref(x, w):
             x = x.cpu().to(torch.int32)
@@ -950,7 +948,6 @@ def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)")
     def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
             AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
@@ -1024,8 +1021,6 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype):
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)")
-
     def test_int8_weight_only_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
             change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
@@ -1092,7 +1087,6 @@ def test_weight_only_quant(self):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @torch.no_grad()
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)")
     def test_weight_only_quant_force_mixed_mm(self, device, dtype):
         if device != "cuda":
             self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}")
@@ -1119,8 +1113,6 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)")
-
     def test_weight_only_quant_use_mixed_mm(self, device, dtype):
         if device != "cuda":
             self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}")
@@ -1357,8 +1349,6 @@ class TestAutoQuant(unittest.TestCase):
         # (256, 256, 128), TODO: Runs out of shared memory on T4
     ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)")
-
     def test_autoquant_one_input(self, device, dtype, m, k, n):
         print("(m, k, n): ", (m, k, n))
         if device != "cuda" or not torch.cuda.is_available():
@@ -1392,8 +1382,6 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
         (32, 32, 128, 128),
     ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)")
-
     def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
         if device != "cuda" or not torch.cuda.is_available():
             self.skipTest(f"autoquant currently does not support {device}")
71 changes: 71 additions & 0 deletions test/sparsity/test_sparse_api.py
@@ -0,0 +1,71 @@
import logging
import unittest

import torch
from torch import nn

from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,
    _is_linear,
)
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
from torch.testing._internal.common_utils import TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

class TestSemiStructuredSparse(TestCase):

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_sparse(self):
        input = torch.rand((128, 128)).half().cuda()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .cuda()
        )

        apply_fake_sparsity(model)
        dense_result = model(input)

        apply_sparse_semi_structured(model)
        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)


class TestQuantSemiSparse(TestCase):

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quant_semi_sparse(self):
        input = torch.rand((128, 128)).half().cuda()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .cuda()
        )

        apply_fake_sparsity(model)
        dense_result = model(input)

        _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)


if __name__ == "__main__":
    unittest.main()
5 changes: 4 additions & 1 deletion torchao/sparsity/__init__.py
@@ -6,8 +6,11 @@
 
 from .wanda import WandaSparsifier # noqa: F403
 from .utils import PerChannelNormObserver # noqa: F403
+from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity
 
 __all__ = [
     "WandaSparsifier",
-    "PerChannelNormObserver"
+    "PerChannelNormObserver",
+    "apply_sparse_semi_structured",
+    "apply_fake_sparsity",
 ]
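
With these exports, the new sparse API is two calls. A minimal usage sketch, mirroring test/sparsity/test_sparse_api.py above and assuming a CUDA device with PyTorch 2.3+:

import torch
from torch import nn
from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured

model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).half().cuda()

apply_fake_sparsity(model)           # prune each linear weight to a 2:4 pattern
apply_sparse_semi_structured(model)  # swap in semi-structured sparse kernels

out = model(torch.rand(128, 128, device='cuda', dtype=torch.half))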