distillBERT.py
import numpy as np
import torch
import transformers


def chunks(l, n):
    """Yield successive n-sized chunks from l."""
    for i in range(0, len(l), n):
        yield l[i:i + n]

def fetch_vectors(string_list, batch_size=8):
    # Inspired by https://jalammar.github.io/a-visual-guide-to-using-bert-for-the-first-time/
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    bert_model = transformers.DistilBertModel.from_pretrained("distilbert-base-uncased")
    bert_model.to(DEVICE)
    bert_model.eval()  # disable dropout so inference is deterministic

    fin_features = []
    max_len = 64
    for data in chunks(string_list, batch_size):
        tokenized = []
        for x in data:
            # Keep at most the first 300 whitespace-separated words,
            # then truncate the token ids to max_len.
            x = " ".join(x.strip().split()[:300])
            tok = tokenizer.encode(x, add_special_tokens=True)
            tokenized.append(tok[:max_len])

        # Right-pad each sequence to max_len with the pad id (0) and
        # build a mask so the model's attention ignores the padding.
        padded = np.array([i + [0] * (max_len - len(i)) for i in tokenized])
        attention_mask = np.where(padded != 0, 1, 0)

        input_ids = torch.tensor(padded).to(DEVICE)
        attention_mask = torch.tensor(attention_mask).to(DEVICE)

        with torch.no_grad():
            last_hidden_states = bert_model(input_ids, attention_mask=attention_mask)

        # Take the hidden state of the leading [CLS] token as the
        # fixed-size embedding for each input string.
        features = last_hidden_states[0][:, 0, :].cpu().numpy()
        fin_features.append(features)

    return np.vstack(fin_features)
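

# Minimal usage sketch (not part of the original script; the sample texts and
# the output filename below are hypothetical). It embeds two strings and saves
# the resulting feature matrix, one 768-dimensional row per input for
# distilbert-base-uncased.
if __name__ == "__main__":
    sample_texts = [
        "How do I use DistilBERT to embed short texts?",
        "What does the [CLS] vector represent?",
    ]
    features = fetch_vectors(sample_texts, batch_size=2)
    print(features.shape)  # expected: (2, 768)
    np.save("distilbert_cls_features.npy", features)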