forked from linkedin/Liger-Kernel
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbenchmark_rope.py
178 lines (153 loc) · 5.6 KB
/
benchmark_rope.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
from typing import List, Optional

import torch
import triton
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)

from utils import (
    _print_memory_banner,
    _print_speed_banner,
    _test_memory,
    get_current_file_directory,
)
from liger_kernel.transformers.rope import liger_rotary_pos_emb
def _get_perf_configs(
    target: str,
    ylabel: str,
    modes: Optional[List[str]] = None,
):
    """Build the ``triton.testing.Benchmark`` configs for the RoPE benchmarks.

    For every mode two sweeps are produced, both comparing the "liger" and
    "huggingface" providers on bfloat16 inputs:

    * a sweep over ``total_hidden_size`` (512, 2048, 8192) at a fixed
      ``seq_len`` of 2048, and
    * a sweep over ``seq_len`` (1024 .. 16384) at a fixed
      ``total_hidden_size`` of 8192.

    Args:
        target: Label for what is measured (e.g. "speed" or "memory"); it is
            used only to build the plot names.
        ylabel: Y-axis label for the generated plots.
        modes: Benchmark modes to generate configs for. Each mode is passed to
            the benchmarked function as its ``mode`` argument. Defaults to
            ``["full"]`` when omitted.

    Returns:
        A list holding two ``triton.testing.Benchmark`` objects per mode.
    """
    # Use a None sentinel instead of a mutable list as the default value.
    if modes is None:
        modes = ["full"]

    perf_configs = []
    for mode in modes:
        perf_configs.append(
            triton.testing.Benchmark(
                x_names=["total_hidden_size"],
                # 32 * 2**{4, 6, 8} -> 512, 2048, 8192
                x_vals=[32 * (2**i) for i in range(4, 10, 2)],
                line_arg="provider",
                line_vals=["liger", "huggingface"],
                line_names=["Liger", "Hugging Face"],
                styles=[("blue", "solid"), ("orange", "solid")],
                ylabel=ylabel,
                plot_name=f"rope-{mode}-{target}-benchmark-seq-2048",
                args={"dtype": torch.bfloat16, "mode": mode, "seq_len": 2048},
            )
        )
        perf_configs.append(
            triton.testing.Benchmark(
                x_names=["seq_len"],
                # 2**10 .. 2**14 -> 1024, 2048, 4096, 8192, 16384
                x_vals=[2**i for i in range(10, 15)],
                line_arg="provider",
                line_vals=["liger", "huggingface"],
                line_names=["Liger", "Hugging Face"],
                styles=[("blue", "solid"), ("orange", "solid")],
                ylabel=ylabel,
                plot_name=f"rope-{mode}-{target}-benchmark-total_dim_8192",
                args={
                    "dtype": torch.bfloat16,
                    "mode": mode,
                    "total_hidden_size": 8192,
                },
            )
        )
    return perf_configs
@triton.testing.perf_report(
    _get_perf_configs(
        target="speed", ylabel="time (ms)", modes=["forward", "backward", "full"]
    )
)
def bench_speed_rope(total_hidden_size, seq_len, provider, mode, dtype):
    """Time the RoPE forward, backward, or combined pass for one provider.

    Args:
        total_hidden_size: Query hidden size. It is split evenly across 32
            query heads to obtain ``head_dim``.
        seq_len: Sequence length of the random inputs (batch size is 1).
        provider: ``"liger"`` or ``"huggingface"``.
        mode: ``"forward"``, ``"backward"`` or ``"full"`` (forward + backward).
        dtype: dtype of the random q/k tensors and their upstream gradients.

    Returns:
        ``(median, p20, p80)`` timings in milliseconds. This is the
        ``(y, y_min, y_max)`` order that ``triton.testing.perf_report``
        unpacks.

    Raises:
        ValueError: If ``provider`` or ``mode`` is not recognized.
    """
    num_q_heads = 32
    num_kv_heads = 8
    head_dim = total_hidden_size // num_q_heads
    rotary_emb = LlamaRotaryEmbedding(head_dim, device="cuda")

    # Allocated as (batch, seq, heads, head_dim), then transposed to
    # (batch, heads, seq, head_dim), so q/k are non-leaf transposed views.
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device="cuda",
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device="cuda",
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    # Upstream gradients for the backward pass, matching q/k in shape and dtype.
    dq = torch.randn_like(q, device="cuda", dtype=dtype)
    dk = torch.randn_like(k, device="cuda", dtype=dtype)
    pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0)
    cos, sin = rotary_emb(k, pos_ids)

    # do_bench returns the requested quantiles in this order: p50, p20, p80.
    quantiles = [0.5, 0.2, 0.8]

    def fwd():
        if provider == "liger":
            return liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
        elif provider == "huggingface":
            return apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
        else:
            raise ValueError(f"Invalid provider: {provider} for RoPE embedding")

    if mode == "forward":
        ms, min_ms, max_ms = triton.testing.do_bench(
            fwd, quantiles=quantiles, grad_to_none=[q, k], rep=400
        )
    elif mode == "backward":
        # Build the graph once and time only the gradient computation;
        # retain_graph=True lets the same graph be differentiated repeatedly.
        q_out, k_out = fwd()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.autograd.grad(
                (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
            ),
            quantiles=quantiles,
            grad_to_none=[q, k],
            rep=400,
        )
    elif mode == "full":

        def full():
            # Rebuild the graph on every call so forward + backward are timed.
            q_out, k_out = fwd()
            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

        ms, min_ms, max_ms = triton.testing.do_bench(
            full, quantiles=quantiles, grad_to_none=[q, k], rep=400
        )
    else:
        # Without this branch an unknown mode would surface as UnboundLocalError.
        raise ValueError(f"Invalid mode: {mode} for RoPE benchmark")

    # perf_report unpacks (y, y_min, y_max); the original returned min and max swapped.
    return ms, min_ms, max_ms
def benchmark_speed_rope_wrapper():
    """Print the speed banner, then run every RoPE speed benchmark.

    Results are printed. Plots and data are saved into a ``rope_speed``
    sub-directory of ``get_current_file_directory()``, which is created if
    missing.
    """
    _print_speed_banner()
    save_dir = os.path.join(get_current_file_directory(), "rope_speed")
    os.makedirs(save_dir, exist_ok=True)
    bench_speed_rope.run(print_data=True, save_path=save_dir)
@triton.testing.perf_report(
    benchmarks=_get_perf_configs(target="memory", ylabel="GPU memory usage (MB)")
)
def bench_memory_rope(total_hidden_size, seq_len, provider, mode, dtype):
    """Measure memory of a combined RoPE forward + backward pass for one provider.

    Args:
        total_hidden_size: Query hidden size. It is split evenly across 32
            query heads to obtain ``head_dim``.
        seq_len: Sequence length of the random inputs (batch size is 1).
        provider: ``"liger"`` or ``"huggingface"``.
        mode: Accepted because the shared benchmark configs pass it. It is not
            read here; this function always measures forward + backward.
        dtype: dtype of the random q/k tensors and their upstream gradients.

    Returns:
        The value reported by ``_test_memory`` divided by ``2**20``
        (presumably bytes -> MB; confirm against ``utils._test_memory``).

    Raises:
        ValueError: If ``provider`` is not recognized. This matches
            ``bench_speed_rope``, instead of silently benchmarking the
            Hugging Face path.
    """
    # Validate before allocating any GPU tensors.
    if provider not in ("liger", "huggingface"):
        raise ValueError(f"Invalid provider: {provider} for RoPE embedding")

    num_q_heads = 32
    num_kv_heads = 8
    head_dim = total_hidden_size // num_q_heads
    rotary_emb = LlamaRotaryEmbedding(head_dim, device="cuda")

    # Allocated as (batch, seq, heads, head_dim), then transposed to
    # (batch, heads, seq, head_dim).
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device="cuda",
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device="cuda",
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    # Upstream gradients for the backward pass, matching q/k in shape and dtype.
    dq = torch.randn_like(q, device="cuda", dtype=dtype)
    dk = torch.randn_like(k, device="cuda", dtype=dtype)
    pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0)
    cos, sin = rotary_emb(k, pos_ids)

    def full():
        if provider == "liger":
            q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
        else:
            q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
        # NOTE(review): the graph is rebuilt on every call, so retain_graph is
        # not required here (bench_speed_rope's full() omits it). It is kept so
        # the reported memory stays comparable with previous runs.
        torch.autograd.grad(
            (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
        )

    mem = _test_memory(full)
    return mem / 2**20
def benchmark_memory_rope_wrapper():
    """Print the memory banner, then run every RoPE memory benchmark.

    Results are printed. Plots and data are saved into a ``rope_memory``
    sub-directory of ``get_current_file_directory()``, which is created if
    missing.
    """
    _print_memory_banner()
    base_dir = get_current_file_directory()
    memory_dir = os.path.join(base_dir, "rope_memory")
    os.makedirs(memory_dir, exist_ok=True)
    bench_memory_rope.run(print_data=True, save_path=memory_dir)
if __name__ == "__main__":
    # Script entry point: run the speed sweep first, then the memory sweep.
    for _run_benchmark in (
        benchmark_speed_rope_wrapper,
        benchmark_memory_rope_wrapper,
    ):
        _run_benchmark()