-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathevaluater.py
161 lines (152 loc) · 7.18 KB
/
evaluater.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import torch
import numpy as np
from tqdm import tqdm
import time
import itertools
from utils.data_utils import get_test_data
import os
import sys
current_path = os.path.dirname(os.path.abspath(__file__))
parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(current_path)
@torch.no_grad()
def ppl_eval(model, tokenizer, datasets=['wikitext2', 'ptb', 'c4'], model_seq_len=2048, batch_size=32, device="cuda"):
model.to(device)
model.eval()
ppls = {}
for dataset in datasets:
test_loader = get_test_data(dataset, tokenizer, seq_len=model_seq_len, batch_size = batch_size)
nlls = []
for batch in tqdm(test_loader):
batch = batch.to(device)
output = model(batch, use_cache=False)
lm_logits = output.logits
if torch.isfinite(lm_logits).all():
shift_logits = lm_logits[:, :-1, :].contiguous()
shift_labels = batch[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.view(-1))
nlls.append(loss)
ppl = np.exp(torch.cat(nlls, dim=-1).mean().item())
ppls[dataset] = ppl
print("PPL after pruning: {}".format(ppls))
print("Weight Memory: {} MiB\n".format(torch.cuda.memory_allocated()/1024/1024))
# only call this function when for 65b or more model
@torch.no_grad()
def ppl_eval_large(model, tokenizer, datasets=['wikitext2', 'ptb', 'c4'], seq_len=2048, batch_size=32, device="cuda"):
import torch.nn as nn
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size=model.config.hidden_size, eps=model.config.rms_norm_eps):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
norm = LlamaRMSNorm().half().cuda()
lm_head = model.lm_head.cuda()
model.eval()
ppls = {}
layers = model.model.layers
for dataset in datasets:
test_loader = get_test_data(dataset, tokenizer, seq_len=seq_len, batch_size = batch_size)
nlls = []
for batch in tqdm(test_loader):
model.model.embed_tokens = model.model.embed_tokens.cuda()
model.model.norm = model.model.norm.cuda()
layers[0] = layers[0].cuda()
dtype = next(iter(model.parameters())).dtype
inps = torch.zeros(
(batch.shape[0], model.seqlen, model.config.hidden_size), dtype=dtype, device="cuda"
)
cache = {'i': 0, 'attention_mask': None, "position_ids": None}
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, inp, **kwargs):
inps[cache['i']] = inp
cache['i'] += 1
if cache['attention_mask'] is None:
cache['attention_mask'] = kwargs['attention_mask']
cache['position_ids'] = kwargs['position_ids']
else:
cache['attention_mask'] = torch.cat((cache['attention_mask'], kwargs['attention_mask']), dim=0)
cache['position_ids'] = torch.cat((cache['position_ids'], kwargs['position_ids']), dim=0)
raise ValueError
layers[0] = Catcher(layers[0])
for j in range(batch.shape[0]):
try:
model(batch[j].unsqueeze(0).cuda())
except ValueError:
pass
layers[0] = layers[0].module
layers[0] = layers[0].cpu()
model.model.embed_tokens = model.model.embed_tokens.cpu()
model.model.norm = model.model.norm.cpu()
torch.cuda.empty_cache()
attention_masks = cache['attention_mask']
position_ids = cache['position_ids']
for i in range(len(layers)):
layer = layers[i].cuda()
outs = layer(inps, attention_mask=attention_masks, position_ids=position_ids)[0]
layers[i] = layer.cpu()
inps = outs
torch.cuda.empty_cache()
hidden_states = norm(outs)
lm_logits = lm_head(hidden_states)
if torch.isfinite(lm_logits).all():
shift_logits = lm_logits[:, :-1, :].contiguous()
shift_labels = batch[:, 1:].contiguous().cuda()
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.view(-1))
nlls.append(loss)
else:
print("warning: nan or inf in lm_logits")
ppl = np.exp(torch.cat(nlls, dim=-1).mean().item())
ppls[dataset] = ppl
print("PPL after pruning: {}".format(ppls))
print("Weight Memory: {} MiB\n".format(torch.cuda.memory_allocated()/1024/1024))
@torch.no_grad()
def eff_eval(model, tokenizer, dataset='wikitext2', original_len=4, generated_len=128, batch_size=1, device="cuda"):
model.eval()
throughput = 0
token_num = 0
end_memory = 0
num_batches_to_fetch = 10
test_loader = get_test_data(dataset, tokenizer, seq_len=original_len, batch_size = batch_size)
weight_memory = torch.cuda.memory_allocated()
for batch_idx, batch_data in enumerate(itertools.islice(test_loader, num_batches_to_fetch)):
batch = batch_data.to(device)
token_num += batch.shape[0] * generated_len
torch.cuda.empty_cache()
start_memory = torch.cuda.memory_allocated()
torch.cuda.reset_peak_memory_stats(0)
torch.cuda.synchronize()
start_time = time.time()
generation_output = model.generate(
input_ids=batch,
pad_token_id=tokenizer.eos_token_id,
do_sample=True,
use_cache=True,
top_k=50,
max_length=original_len+generated_len,
top_p=0.95,
temperature=1,
)
torch.cuda.synchronize()
end_time = time.time()
end_memory = max(torch.cuda.max_memory_allocated(0), end_memory)
if torch.isfinite(generation_output[0]).all(): # check if the generation is successful since fp16 may cause nan
throughput += end_time - start_time
print("time: {}".format(end_time - start_time))
print("Total Memory: {} GB".format(end_memory/(1024 ** 3)))
print("Weight Memory: {} GB".format(weight_memory/(1024 ** 3)))
print("Activation Memory: {} GB".format((end_memory - start_memory)/(1024 ** 3)))
print("Throughput: {} tokens/sec".format(token_num / throughput))