Updated for LLM shapes
jainapurva committed Nov 13, 2024
1 parent 7b904c1 commit e56abbc
Showing 2 changed files with 98 additions and 39 deletions.
115 changes: 76 additions & 39 deletions benchmarks/float8/float8_inference_roofline.py
@@ -24,13 +24,15 @@
import torchao
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.profiler import profile, record_function, ProfilerActivity
from torchao.quantization.quant_api import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
import copy
from utils import (
get_name_to_shapes_iter,
get_llm_mm_shapes,
get_diffusion_mm_shapes
)
import tqdm
from tabulate import tabulate
@@ -79,24 +81,14 @@ def get_gpu_kernel_times(profiler_chrome_trace, gpu_op_name):
gpu_overhead_time += event[1]
return gpu_op_time, gpu_overhead_time

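Note: only the tail of get_gpu_kernel_times is visible in this hunk. A minimal sketch of the full helper, assuming the chrome trace's kernel events are flattened into (name, duration_us) pairs as the visible event[1] access implies, could look like the following (hypothetical reconstruction, not the repository's exact code):

import json

def get_gpu_kernel_times(profiler_chrome_trace, gpu_op_name):
    # Sum GPU kernel durations: kernels whose name contains gpu_op_name
    # (e.g. 'gemm') count as op time, everything else as overhead.
    with open(profiler_chrome_trace) as f:
        trace = json.load(f)
    # Assumption: GPU kernel events carry cat == "kernel" in the exported trace
    kernel_events = [
        (e.get("name", ""), e.get("dur", 0))
        for e in trace.get("traceEvents", [])
        if e.get("cat") == "kernel"
    ]
    gpu_op_time, gpu_overhead_time = 0.0, 0.0
    for event in kernel_events:
        if gpu_op_name in event[0]:
            gpu_op_time += event[1]
        else:
            gpu_overhead_time += event[1]
    return gpu_op_time, gpu_overhead_time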
def run_gemm_benchmarks(name_to_shapes, float8_dtype=torch.float8_e4m3fn, other_dtype=torch.bfloat16, quantization_technique=float8_weight_only):
# Dictionary to store performance data
performance_data = {
'Input Size': [],
'float8 Op Kernel Times (ms)': [],
'bf16 Op Kernel Times (ms)': [],
'float8 Overhead Kernel Times (ms)': [],
'bf16 Overhead Kernel Times (ms)': [],
'float8 Total Kernel Times (ms)': [],
'bf16 Total Kernel Times (ms)': [],
}
def run_gemm_benchmarks(performance_data, name_to_shapes, float8_dtype=torch.float8_e4m3fn, other_dtype=torch.bfloat16, quantization_technique=float8_weight_only, batch_size=1):
# Run benchmarks for each input size
for idx, (name, (m, k, n)) in enumerate(tqdm.tqdm(name_to_shapes)):
print(f"Profiling model with input size: {m, k, n} for quantization technique: {quantization_technique}, dtype: {float8_dtype} vs {other_dtype}")
print(f"Profiling model with input size: {batch_size, m, k, n} for quantization technique: {quantization_technique}, dtype: {float8_dtype} vs {other_dtype}")

# Initialize the model with the specified dimensions
model = ToyLinearModel(m, k, n).eval().to(device)
example_inputs = model.example_inputs(m)
model = ToyLinearModel(batch_size*m, k, n).eval().to(device)
example_inputs = model.example_inputs(batch_size*m)
model_bf16 = copy.deepcopy(model).to(device) # Copy of the model for the bf16 baseline
model_bf16 = torch.compile(model_bf16) # Compile the model

@@ -117,51 +109,96 @@ def run_gemm_benchmarks(name_to_shapes, float8_dtype=torch.float8_e4m3fn, other_
float8_gpu_op_time, float8_gpu_overhead_time = get_gpu_kernel_times(prof_float8, 'gemm')
bf16_gpu_op_time, bf16_gpu_overhead_time = get_gpu_kernel_times(prof_bf16, 'gemm')

# # Print profiling details
# print(f"bfloat16_gpu_overhead_time: {bf16_gpu_overhead_time} gpu_op_time: {bf16_gpu_op_time}")
# print(f"float8_gpu_overhead_time: {float8_gpu_overhead_time} float8_gpu_op_time: {float8_gpu_op_time}")

# Add the performance data to the dictionary
# time/1000 -> Convert from microseconds to milliseconds
performance_data['Input Size'].append(f"{tuple(example_inputs[0].shape)}")
performance_data['Input Size'].append(f"{(m, k, n)}")
performance_data['float8 Total Kernel Times (ms)'].append((float8_gpu_op_time + float8_gpu_overhead_time) / 1000)
performance_data['bf16 Total Kernel Times (ms)'].append((bf16_gpu_op_time + bf16_gpu_overhead_time) / 1000)
performance_data['float8 Op Kernel Times (ms)'].append(float8_gpu_op_time / 1000)
performance_data['bf16 Op Kernel Times (ms)'].append(bf16_gpu_op_time / 1000)
performance_data['float8 Overhead Kernel Times (ms)'].append(float8_gpu_overhead_time / 1000)
performance_data['bf16 Overhead Kernel Times (ms)'].append(bf16_gpu_overhead_time / 1000)
performance_data['Batch Size'].append(batch_size)

return performance_data


def plot_performance_data(performance_data):
def plot_performance_data(performance_data, x_col, plot_name = 'model_evaluation_gpu_kernel_performance'):
# Plotting the results
plt.figure(figsize=(10, 6))
plt.plot(performance_data['Input Size'], performance_data['float8 Total Kernel Times (ms)'], marker='o', label='float8')
plt.plot(performance_data['Input Size'], performance_data['bf16 Total Kernel Times (ms)'], marker='s', label='bf16')
plt.xlabel('Batch Size')
plt.plot(performance_data[x_col], performance_data['float8 Total Kernel Times (ms)'], marker='o', label='float8')
plt.plot(performance_data[x_col], performance_data['bf16 Total Kernel Times (ms)'], marker='s', label='bf16')
plt.xlabel(x_col)
plt.ylabel('Kernel Time (ms)')
plt.title('Model Evaluation GPU Kernel Performance: float8 vs bf16')
plt.title(plot_name+' performance: float8 vs bf16')
plt.legend()
plt.grid(True)
plt.savefig('model_evaluation_gpu_kernel_performance.png')
plt.savefig(plot_name+'.png')


if __name__ == '__main__':

# llm_model_names = ["bert-base-uncased", "gpt2", "t5-small", "meta-llama/Llama-3.2-3B", "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"]
# name_to_shapes = get_name_to_shapes_iter("llama", None, None, None)
name_to_shapes = get_llm_mm_shapes("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", None, None, None)

print('Shapes:', name_to_shapes)
float8_dtype = torch.float8_e4m3fn # Change to the float8 dtype you want to use
bf16_dtype = torch.bfloat16 # Change to the comparing dtype you want to use
def plot_llm_performance_data_hf_model(model_name, quantization_dtype=torch.float8_e4m3fn, quantization_technique=float8_weight_only, baseline_dtype=torch.bfloat16, batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096,]):
# Dictionary to store performance data
performance_data = {
'Input Size': [],
'float8 Op Kernel Times (ms)': [],
'bf16 Op Kernel Times (ms)': [],
'float8 Overhead Kernel Times (ms)': [],
'bf16 Overhead Kernel Times (ms)': [],
'float8 Total Kernel Times (ms)': [],
'bf16 Total Kernel Times (ms)': [],
'Batch Size': []
}
name_to_shapes = get_llm_mm_shapes(model_name, seq_len=128)
print(f'For Model: {model_name}, and Shapes: {name_to_shapes}')
for batch_size in batch_sizes:
performance_data = run_gemm_benchmarks(
performance_data=performance_data,
name_to_shapes=name_to_shapes,
float8_dtype=quantization_dtype,
other_dtype=baseline_dtype,
quantization_technique=quantization_technique,
batch_size=batch_size,
)
df_performance_data = pd.DataFrame(performance_data)
df_grouped = df_performance_data.groupby('Input Size')
for name, group in df_grouped:
print(f"Group: {name}")
# print(group)
plot_performance_data(group, 'Batch Size', plot_name=f'{model_name.split("/")[-1]}_input_size_{name}_quant_{quantization_technique.__name__}')


if __name__ == '__main__':

# Run benchmarks for LLMs
llm_model_names = ["bert-base-uncased", "gpt2", "t5-small", "meta-llama/Llama-3.2-3B", "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"]
for model_name in llm_model_names:
plot_llm_performance_data_hf_model(
model_name,
quantization_dtype=torch.float8_e4m3fn,
quantization_technique=float8_weight_only,
baseline_dtype=torch.bfloat16,
)

# Run benchmarks for different_matrix_shapes (m, k, n)
name_to_shapes = get_name_to_shapes_iter("square", None, None, None)
# Dictionary to store performance data
performance_data = {
'Input Size': [],
'float8 Op Kernel Times (ms)': [],
'bf16 Op Kernel Times (ms)': [],
'float8 Overhead Kernel Times (ms)': [],
'bf16 Overhead Kernel Times (ms)': [],
'float8 Total Kernel Times (ms)': [],
'bf16 Total Kernel Times (ms)': [],
'Batch Size': []
}
performance_data = run_gemm_benchmarks(
performance_data=performance_data,
name_to_shapes=name_to_shapes,
float8_dtype=float8_dtype,
other_dtype=bf16_dtype,
quantization_technique=quantization_technique
)
print('Performance data: \n', tabulate(performance_data, headers=performance_data.keys()))
float8_dtype=torch.float8_e4m3fn,
other_dtype=torch.bfloat16,
quantization_technique=float8_weight_only,
)
plot_performance_data(performance_data, 'Input Size', plot_name='different_matrix_shapes')
# print('Performance data: \n', tabulate(performance_data, headers=performance_data.keys()))
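For reference, ToyLinearModel is used above but defined outside the diff hunks. A minimal sketch consistent with how it is called here — ToyLinearModel(m, k, n) plus example_inputs(m) producing an (m, k) activation — might read as follows (hypothetical; the actual definition in the file may differ):

import torch

class ToyLinearModel(torch.nn.Module):
    # Hypothetical sketch: one k -> n linear layer, so each forward pass
    # runs a single (m, k) @ (k, n) GEMM that the profiler can isolate.
    def __init__(self, m, k, n):
        super().__init__()
        self.k = k
        self.linear = torch.nn.Linear(k, n, bias=False)

    def example_inputs(self, m, dtype=torch.bfloat16, device="cuda"):
        # One (m, k) activation tensor matching the linear layer's input dim
        return (torch.randn(m, self.k, dtype=dtype, device=device),)

    def forward(self, x):
        return self.linear(x)

The (m, k, n) split mirrors the (m, k) @ (k, n) GEMM that the roofline comparison targets, which is why the benchmark folds batch_size into m.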
22 changes: 22 additions & 0 deletions benchmarks/float8/utils.py
@@ -10,6 +10,8 @@
from typing import Optional

from torch.profiler import profile, ProfilerActivity, record_function
from transformers import AutoConfig
from diffusers import DiffusionPipeline

def profiler_output_to_filtered_time_by_kernel_name(
prof,
@@ -202,6 +204,7 @@ def get_name_to_shapes_iter(

raise AssertionError(f'unknown shape_gen_name {shape_gen_name}')


# copy-pasta from https://github.com/vkuzo/pytorch_scripts/blob/main/add_inductor_metadata_to_perf_trace.py
def update_triton_kernels_in_prof_chome_trace_with_torch_logs(
perf_trace_file: str,
@@ -348,3 +351,22 @@ def get_gpu_kernel_gemm_time_s(f, *args, **kwargs):
return data["aten::_scaled_mm"] / 1e6 / n_iter
else:
raise AssertionError("unexpected format of data")


def get_llm_mm_shapes(model_name, seq_len=512):
"""Extracts matrix shapes for matrix multiplications in attention and feed-forward layers for an LLM model."""
config = AutoConfig.from_pretrained(model_name)

hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
intermediate_size = getattr(config, "intermediate_size", hidden_size * 4) # Typically 4x hidden size

d_head = hidden_size // num_attention_heads

matrix_shapes = {
"Attention mm": (seq_len, seq_len, d_head), # Attention score matrix per head
"Input -> Intermediate": (seq_len, hidden_size, intermediate_size), # Feed-forward layer matrix multiplication shapes
"Intermediate -> Output": (seq_len, intermediate_size, hidden_size), # Feed-forward layer matrix multiplication shapes
}

return matrix_shapes.items()

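As a quick usage illustration of the new helper (hedged: assumes GPT-2's config resolves hidden_size=768 and num_attention_heads=12 via transformers' attribute map, with no explicit intermediate_size, so the 4x fallback applies):

from utils import get_llm_mm_shapes

# For "gpt2": d_head = 768 // 12 = 64, intermediate_size = 4 * 768 = 3072
for name, (m, k, n) in get_llm_mm_shapes("gpt2", seq_len=128):
    print(f"{name}: m={m}, k={k}, n={n}")

# Expected output (assuming the config defaults above):
#   Attention mm: m=128, k=128, n=64
#   Input -> Intermediate: m=128, k=768, n=3072
#   Intermediate -> Output: m=128, k=3072, n=768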