[GPT-fast] Support running a specific model or micro-benchmark (pytorch#143607)
yanboliang authored and pytorchmergebot committed Dec 20, 2024
1 parent 94737e8 commit 792e618
Showing 3 changed files with 53 additions and 55 deletions.
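In short, the hard-coded experiment list in benchmark.py is replaced by a small registry (benchmarks/gpt_fast/common.py), and a new --only flag lets a single model or micro-benchmark be run by its registered name, e.g. --only layer_norm or --only llama2_7b_bf16. A minimal usage sketch, assuming the script is run from benchmarks/gpt_fast and using names taken from the @register_experiment decorators in the diff below:

    # CLI: python benchmark.py --only layer_norm
    # Programmatic equivalent via the updated main() signature:
    from benchmark import main

    main(output_file="gpt_fast_benchmark.csv", only_model="layer_norm")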
64 changes: 23 additions & 41 deletions benchmarks/gpt_fast/benchmark.py
@@ -4,12 +4,8 @@
import json
import os

from generate import (
get_arch_name,
run_llama2_7b_bf16,
run_llama2_7b_int8,
run_mixtral_8x7b_int8,
)
from common import all_experiments, Experiment, register_experiment
from generate import get_arch_name

import torch
import torch.nn as nn
@@ -22,18 +18,6 @@
A100_40G_BF16_TFLOPS = 312


@dataclasses.dataclass
class Experiment:
name: str
metric: str
target: float
actual: float
dtype: str
device: str
arch: str # GPU name for CUDA or CPU arch for CPU
is_model: bool = False


class SimpleMLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, dtype):
super().__init__()
@@ -52,6 +36,7 @@ def forward(self, x):
return x


@register_experiment(name="mlp_layer_norm_gelu")
def run_mlp_layer_norm_gelu(device: str = "cuda"):
dtype_flops_utilization_map = {
torch.bfloat16: "0.8",
@@ -102,6 +87,7 @@ def run_mlp_layer_norm_gelu(device: str = "cuda"):
return results


@register_experiment(name="layer_norm")
def run_layer_norm(device: str = "cuda"):
dtype_memory_bandwidth_map = {
torch.bfloat16: "950",
@@ -145,6 +131,7 @@ def run_layer_norm(device: str = "cuda"):
return results


@register_experiment(name="gather_gemv")
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gather_gemv(device: str = "cuda"):
E = 8
@@ -194,6 +181,7 @@ def gather_gemv(W, score_idxs, x):
return results


@register_experiment(name="gemv")
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gemv(device: str = "cuda"):
dtype_memory_bandwidth_map = {
@@ -297,30 +285,20 @@ def output_json(output_file, headers, row):

DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"

all_experiments = {
# A list of GPT models: LlaMa, Mixtral, etc.
# waiting for A100-80G machine to be available in CI
# https://github.com/pytorch/pytorch/actions/runs/12018005803/job/33503683582?pr=140627
# before we can turn on autoquant
# or alternatively, we can save the model after autoquant and just load here to track
# the performance
# run_llama2_7b_autoquant,
run_llama2_7b_bf16,
run_llama2_7b_int8,
run_mixtral_8x7b_int8,
# run_mixtral_8x7b_autoquant,
# A list of micro-benchmarks.
run_mlp_layer_norm_gelu,
run_layer_norm,
run_gather_gemv,
run_gemv,
}


def main(output_file=DEFAULT_OUTPUT_FILE):

def main(output_file=DEFAULT_OUTPUT_FILE, only_model=None):
results = []

for func in all_experiments:
    if not only_model:
        experiments = all_experiments.values()
    else:
        if only_model not in all_experiments:
            print(
                f"Unknown model: {only_model}, all available models: {list(all_experiments.keys())}"
            )
            return
        # only run the specified model or micro-benchmark
        experiments = [all_experiments[only_model]]
    for func in experiments:
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        except AssertionError:
@@ -347,6 +325,10 @@ def main(output_file=DEFAULT_OUTPUT_FILE):
default=DEFAULT_OUTPUT_FILE,
help="Set the output CSV file to save the benchmark results",
)
parser.add_argument(
"--only",
help="Specify a model or micro-benchmark name to run exclusively",
)
args = parser.parse_args()

main(output_file=args.output)
main(output_file=args.output, only_model=args.only)
26 changes: 26 additions & 0 deletions benchmarks/gpt_fast/common.py
@@ -0,0 +1,26 @@
import dataclasses
from typing import Callable, Dict, Optional


all_experiments: Dict[str, Callable] = {}


@dataclasses.dataclass
class Experiment:
    name: str
    metric: str
    target: float
    actual: float
    dtype: str
    device: str
    arch: str  # GPU name for CUDA or CPU arch for CPU
    is_model: bool = False


def register_experiment(name: Optional[str] = None):
    def decorator(func):
        # Register under the explicit name when given, otherwise the function's own name.
        key = name or func.__name__
        all_experiments[key] = func
        return func

    return decorator
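For orientation, a short sketch of how the new registry is meant to be consumed; Experiment's fields and register_experiment's behavior match the file above, while the experiment name and the metric/target values below are purely hypothetical placeholders:

    from common import Experiment, all_experiments, register_experiment


    @register_experiment(name="toy_benchmark")  # hypothetical name, not part of this commit
    def run_toy_benchmark(device: str = "cuda"):
        # Each registered function returns a list of Experiment records,
        # which benchmark.py collects and writes out.
        return [
            Experiment(
                name="toy_benchmark",
                metric="memory_bandwidth(GB/s)",  # placeholder metric label
                target=1.0,                       # placeholder target
                actual=1.0,                       # placeholder measurement
                dtype="bfloat16",
                device=device,
                arch="A100",                      # placeholder; the real code derives this from get_arch_name()
                is_model=False,
            )
        ]


    # The decorator keys the registry by the given name, so `--only toy_benchmark` would find it.
    assert all_experiments["toy_benchmark"] is run_toy_benchmark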
18 changes: 4 additions & 14 deletions benchmarks/gpt_fast/generate.py
@@ -5,6 +5,7 @@
from typing import Optional, Tuple

import torchao
from common import Experiment, register_experiment
from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE
from mixtral_moe_quantize import (
ConditionalFeedForwardInt8,
@@ -295,9 +296,8 @@ def run_experiment(


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
@register_experiment(name="llama2_7b_bf16")
def run_llama2_7b_bf16(device: str = "cuda"):
from benchmark import Experiment

model = GPTModelConfig(
"Llama-2-7b-chat-hf",
LLaMA,
@@ -345,9 +345,8 @@ def run_llama2_7b_bf16(device: str = "cuda"):


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
@register_experiment(name="llama2_7b_int8")
def run_llama2_7b_int8(device: str = "cuda"):
from benchmark import Experiment

model = GPTModelConfig(
"Llama-2-7b-chat-hf",
LLaMA,
@@ -395,9 +394,8 @@ def run_llama2_7b_int8(device: str = "cuda"):


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
@register_experiment(name="mixtral_8x7b_int8")
def run_mixtral_8x7b_int8(device: str = "cuda"):
from benchmark import Experiment

# We reduced the original number of layers from 32 to 16 to adapt to CI memory limitations.
model = GPTModelConfig(
"Mixtral-8x7B-v0.1",
@@ -447,8 +445,6 @@ def run_mixtral_8x7b_int8(device: str = "cuda"):

# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_autoquant(device: str = "cuda"):
from benchmark import Experiment

model = GPTModelConfig(
"Llama-2-7b-chat-hf",
LLaMA,
@@ -497,8 +493,6 @@ def run_llama2_7b_autoquant(device: str = "cuda"):

# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_autoquant(device: str = "cuda"):
from benchmark import Experiment

# We reduced the original number of layers from 32 to 16 to adapt to CI memory limitations.
model = GPTModelConfig(
"Mixtral-8x7B-v0.1",
@@ -548,8 +542,6 @@ def run_mixtral_8x7b_autoquant(device: str = "cuda"):

# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_autoquant_v2(device: str = "cuda"):
from benchmark import Experiment

model = GPTModelConfig(
"Llama-2-7b-chat-hf",
LLaMA,
@@ -599,8 +591,6 @@ def run_llama2_7b_autoquant_v2(device: str = "cuda"):

# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_autoquant_v2(device: str = "cuda"):
from benchmark import Experiment

# We reduced the original number of layers from 32 to 16 to adapt to CI memory limitations.
model = GPTModelConfig(
"Mixtral-8x7B-v0.1",
