forked from fastnlp/fastNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_lstm_att.py
61 lines (43 loc) · 1.94 KB
/
train_lstm_att.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# The following paths must first be added to the environment variables.
# fastNLP resources are currently open to internal testing only, so the
# paths have to be declared manually here.
# NOTE(review): presumably these must be set before the fastNLP imports
# below run -- confirm before reordering.
import os
# NOTE(review): hard-coded internal address; likely unreachable outside the
# original author's network -- confirm.
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
# NOTE(review): hard-coded, user-specific cache directory -- adjust per machine.
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
from fastNLP.io.data_loader import IMDBLoader
from fastNLP.embeddings import StaticEmbedding
from model.lstm_self_attention import BiLSTM_SELF_ATTENTION
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP import Trainer
from torch.optim import Adam
class Config:
    """Hyper-parameters and file locations for the IMDB training run.

    All values are plain class attributes; the script instantiates the class
    once and reads the attributes from that instance.
    """

    # Passed to the Trainer as ``n_epochs``.
    train_epoch = 10
    # Learning rate handed to the Adam optimizer.
    lr = 0.001

    # The values below are forwarded unchanged to the BiLSTM_SELF_ATTENTION
    # constructor under the same keyword names.
    num_classes = 2
    hidden_dim = 256
    num_layers = 1
    attention_unit = 256
    attention_hops = 1
    nfc = 128

    # Not read anywhere in the visible part of this script.
    task_name = "IMDB"
    # Mapping of split name -> CSV path, passed to IMDBLoader.process().
    datapath = {"train": "IMDB_data/train.csv", "test": "IMDB_data/test.csv"}
    # Passed to the Trainer as ``save_path``.
    save_model_path = "./result_IMDB_test/"
opt = Config()

# Load data: process the train/test CSV files listed in the config.
dataloader = IMDBLoader()
datainfo = dataloader.process(opt.datapath)

# Define model: embedding built over the 'words' vocabulary, then the classifier.
vocab = datainfo.vocabs['words']
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
model = BiLSTM_SELF_ATTENTION(
    init_embed=embed,
    num_classes=opt.num_classes,
    hidden_dim=opt.hidden_dim,
    num_layers=opt.num_layers,
    attention_unit=opt.attention_unit,
    attention_hops=opt.attention_hops,
    nfc=opt.nfc,
)

# Define loss function and metrics.
loss = CrossEntropyLoss()
metrics = AccuracyMetric()

# Only parameters that require gradients are handed to the optimizer.
optimizer = Adam(
    [param for param in model.parameters() if param.requires_grad],
    lr=opt.lr,
)
def train(datainfo, model, optimizer, loss, metrics, opt):
    """Build a fastNLP ``Trainer`` and run training.

    Args:
        datainfo: Object whose ``datasets`` mapping provides the ``'train'``
            and ``'test'`` datasets.
        model: Model to train.
        optimizer: Optimizer passed through to the trainer.
        loss: Loss object passed through to the trainer.
        metrics: Metric object passed through to the trainer.
        opt: Configuration object; ``train_epoch`` and ``save_model_path``
            are read from it.

    Returns:
        None. The result of ``trainer.train()`` is discarded.
    """
    # NOTE(review): the 'test' split is used as dev_data, so any model
    # selection the trainer performs is done on the test set -- confirm this
    # is intended.
    trainer = Trainer(
        datainfo.datasets['train'],
        model,
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        dev_data=datainfo.datasets['test'],
        device=0,  # hard-coded device index -- presumably the first GPU
        check_code_level=-1,  # presumably disables the pre-training check; confirm in fastNLP docs
        n_epochs=opt.train_epoch,
        save_path=opt.save_model_path,
    )
    trainer.train()
# Start training only when the file is executed as a script.
if __name__ == "__main__":
    train(datainfo, model, optimizer, loss, metrics, opt)