forked from linkedin/Liger-Kernel
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbenchmark_rms_norm.py
181 lines (146 loc) · 5.49 KB
/
benchmark_rms_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import os
import torch
import torch.nn as nn
import triton
from utils import _print_memory_banner, _print_speed_banner, _test_memory
from liger_kernel.transformers.rms_norm import LigerRMSNorm
class LlamaRMSNorm(nn.Module):
    """Reference RMSNorm used as the "huggingface" baseline in these benchmarks.

    Computes ``weight * (x * rsqrt(mean(x**2, dim=-1) + eps))``, doing the
    normalization in float32. LlamaRMSNorm is equivalent to T5LayerNorm.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Per-feature scale over the last dimension, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize over the last dimension, then rescale by ``self.weight``.

        Statistics are computed in float32. The normalized tensor is cast back
        to the input dtype *before* the weight multiply, so the output dtype is
        whatever PyTorch type promotion gives for ``weight * normalized``.
        """
        orig_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_sq = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_sq + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)
@triton.testing.perf_report(
    [
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[2**i for i in range(10, 16)],
            xlabel="hidden size",
            line_arg="provider",
            line_vals=["liger", "huggingface"],
            line_names=["Liger", "Hugging Face"],
            styles=[("blue", "solid"), ("orange", "solid")],
            ylabel="time (ms)",
            plot_name=f"rmsnorm-{tag}-speed-benchmark",
            args={"M": 2048, "dtype": torch.bfloat16, "mode": mode},
        )
        # One benchmark (and therefore one plot / CSV) per timed pass.
        for tag, mode in (("fwd", "forward"), ("bwd", "backward"), ("full", "full"))
    ]
)
def bench_speed_rms_norm(M, N, dtype, provider, mode, eps=1e-5, device="cuda"):
    """Time one RMSNorm implementation on an (M, N) input.

    Args:
        M: number of rows of the input tensor.
        N: hidden size, i.e. the last dimension being normalized.
        dtype: dtype of the input tensor.
        provider: "liger" (LigerRMSNorm) or "huggingface" (LlamaRMSNorm).
        mode: "forward", "backward", or "full" (forward + backward).
        eps: epsilon passed to the norm module.
        device: device on which the module and tensors are created.

    Returns:
        ``(median, 20th percentile, 80th percentile)`` time in ms -- the
        ``(y, y_min, y_max)`` order that ``triton.testing.perf_report``
        unpacks for the plotted line and its band.

    Raises:
        ValueError: if ``provider`` or ``mode`` is not recognised.
    """
    # Only build the module that is actually being benchmarked.
    if provider == "liger":
        rms_norm = LigerRMSNorm(hidden_size=N, eps=eps).to(device)
    elif provider == "huggingface":
        rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).to(device)
    else:
        raise ValueError(f"unknown provider: {provider!r}")

    x = torch.randn((M, N), dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        return rms_norm(x)

    if mode == "forward":
        bench_fn = y_fwd
    elif mode == "backward":
        # Build the graph once; retain_graph=True lets do_bench replay the
        # backward pass on that same graph for every iteration.
        y = y_fwd()

        def bench_fn():
            y.backward(dy, retain_graph=True)

    elif mode == "full":

        def bench_fn():
            y_fwd().backward(dy, retain_graph=True)

    else:
        raise ValueError(f"unknown mode: {mode!r}")

    # do_bench returns the requested quantiles in the order they are listed;
    # grad_to_none resets x.grad between runs so gradients do not accumulate.
    ms_50, ms_20, ms_80 = triton.testing.do_bench(
        bench_fn, quantiles=[0.5, 0.2, 0.8], grad_to_none=[x], rep=500
    )
    # Previously returned (ms, p80, p20), which swapped the min/max band.
    return ms_50, ms_20, ms_80
def benchmark_speed_rms_norm_wrapper():
    """Run every RMSNorm speed benchmark, saving results to ``rms_norm_speed/`` beside this script."""
    _print_speed_banner()
    output_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "rms_norm_speed"
    )
    os.makedirs(output_dir, exist_ok=True)
    bench_speed_rms_norm.run(print_data=True, save_path=output_dir)
@triton.testing.perf_report(
    [
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[2**i for i in range(10, 16)],
            xlabel="hidden size",
            line_arg="provider",
            line_vals=["liger", "huggingface"],
            line_names=["Liger", "Hugging Face"],
            styles=[("blue", "solid"), ("orange", "solid")],
            ylabel="GPU memory usage (MB)",
            plot_name="rmsnorm-full-memory-benchmark",
            args={"M": 2048, "dtype": torch.bfloat16, "mode": "full"},
        )
    ]
)
def bench_memory_rms_norm(M, N, dtype, provider, mode, eps=1e-5, device="cuda"):
    """Measure the memory of one forward + backward pass of an RMSNorm.

    Args:
        M: number of rows of the input tensor.
        N: hidden size, i.e. the last dimension being normalized.
        dtype: dtype of the input tensor.
        provider: "liger" (LigerRMSNorm) or "huggingface" (LlamaRMSNorm).
        mode: must be "full"; the only pass this benchmark implements.
        eps: epsilon passed to the norm module.
        device: device on which the module and tensors are created.

    Returns:
        ``utils._test_memory(full)`` divided by ``2**20`` (presumably bytes
        to MiB -- confirm against ``_test_memory``).

    Raises:
        ValueError: if ``provider`` is unknown or ``mode`` is not "full".
    """
    if mode != "full":
        raise ValueError(f"unsupported mode for memory benchmark: {mode!r}")

    # Only build the module that is actually being measured.
    if provider == "liger":
        rms_norm = LigerRMSNorm(hidden_size=N, eps=eps).to(device)
    elif provider == "huggingface":
        rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).to(device)
    else:
        raise ValueError(f"unknown provider: {provider!r}")

    x = torch.randn((M, N), dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def full():
        y = rms_norm(x)
        y.backward(dy, retain_graph=True)

    mem = _test_memory(full)
    return mem / 2**20
def benchmark_memory_rms_norm_wrapper():
    """Run the RMSNorm memory benchmark, saving results to ``rms_norm_memory/`` beside this script."""
    _print_memory_banner()
    output_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "rms_norm_memory"
    )
    os.makedirs(output_dir, exist_ok=True)
    # TODO: make precision configurable in generated csv
    bench_memory_rms_norm.run(print_data=True, save_path=output_dir)
if __name__ == "__main__":
    # Speed suite first, then memory suite.
    for _run_suite in (
        benchmark_speed_rms_norm_wrapper,
        benchmark_memory_rms_norm_wrapper,
    ):
        _run_suite()