Merge pull request #1039 from zhangxin86/main
Week 7 homework
Showing 9 changed files with 16,949 additions and 0 deletions.
Large diffs are not rendered by default.
config.py
@@ -0,0 +1,27 @@
"""
Configuration parameters.
"""

Config = {
    # path settings
    "model_path": "output",
    "original_data_path": "文本分类练习.csv",
    "train_data_path": "data/new_train_tag_news.json",
    "valid_data_path": "data/new_valid_tag_news.json",
    "vocab_path": "chars.txt",
    "pretrain_model_path": r"E:\nlp22\week06\bert-base-chinese",
    # label settings
    "index_to_label": {1: "好评", 0: "差评"},
    # model settings
    "model_type": "bert",
    "max_length": 30,
    "hidden_size": 256,
    "kernel_size": 3,
    "num_layers": 2,
    "epoch": 10,
    "batch_size": 128,
    "pooling_style": "max",
    "optimizer": "adam",
    "learning_rate": 1e-3,
    "seed": 987
}
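For reference, a minimal sketch of how downstream code consumes these settings — loader.py, for instance, inverts index_to_label to build its label_to_index map (the print is for illustration only):

# sketch: invert the label map the way loader.py does
from config import Config

index_to_label = Config["index_to_label"]                      # {1: "好评", 0: "差评"}
label_to_index = {label: idx for idx, label in index_to_label.items()}
print(label_to_index)                                          # {'好评': 1, '差评': 0}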
evaluate.py
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
import torch
from loader import load_data


class Evaluator:
    """
    Evaluates model performance on the validation set.
    """

    def __init__(self, config, model, logger):
        self.config = config
        self.model = model
        self.logger = logger
        self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)
        # stores the evaluation results
        self.stats_dict = {"correct": 0, "wrong": 0}

    def eval(self, epoch):
        self.logger.info(f"Evaluating model performance after epoch {epoch}:")
        self.model.eval()
        # reset the previous round's results
        self.stats_dict = {"correct": 0, "wrong": 0}
        for index, batch_data in enumerate(self.valid_data):
            if torch.cuda.is_available():
                batch_data = [d.cuda() for d in batch_data]
            input_ids, labels = batch_data
            with torch.no_grad():
                pred_results = self.model(input_ids)
            self.write_stats(labels, pred_results)
        acc = self.show_stats()
        return acc

    def write_stats(self, labels, pred_results):
        assert len(labels) == len(pred_results)
        for true_label, pred_label in zip(labels, pred_results):
            pred_label = torch.argmax(pred_label)
            if int(pred_label) == int(true_label):
                self.stats_dict["correct"] += 1
            else:
                self.stats_dict["wrong"] += 1
        return

    def show_stats(self):
        correct = self.stats_dict["correct"]
        wrong = self.stats_dict["wrong"]
        self.logger.info("Total evaluated samples: %d" % (correct + wrong))
        self.logger.info("Correct predictions: %d, wrong predictions: %d" % (correct, wrong))
        self.logger.info("Accuracy: %f" % (correct / (correct + wrong)))
        self.logger.info("--------------------")
        return correct / (correct + wrong)
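A minimal standalone sketch of how Evaluator is wired up (main.py does the same during training). It assumes the data files and the pretrained BERT path in Config exist; note that load_data must run before TorchModel is built so that Config gains the class_num and vocab_size keys the model needs:

# sketch: standalone evaluation run (untrained weights, illustration only)
import logging
from config import Config
from loader import load_data
from model import TorchModel
from evaluate import Evaluator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_data(Config["valid_data_path"], Config, shuffle=False)   # populates class_num / vocab_size
model = TorchModel(Config)
evaluator = Evaluator(Config, model, logger)
acc = evaluator.eval(epoch=1)                                 # logs stats and returns accuracy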
loader.py
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-

import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer


class DataGenerator:
    """
    Loads the dataset.
    """

    def __init__(self, data_path, config):
        self.data = None
        self.config = config
        self.path = data_path
        self.index_to_label = config["index_to_label"]
        self.label_to_index = dict((y, x) for x, y in self.index_to_label.items())
        self.config["class_num"] = len(self.index_to_label)
        if self.config["model_type"] == "bert":
            self.tokenizer = BertTokenizer.from_pretrained(config["pretrain_model_path"])
        self.vocab = load_vocab(config["vocab_path"])
        self.config["vocab_size"] = len(self.vocab)
        self.load()

    def load(self):
        self.data = []
        with open(self.path, encoding="utf8") as f:
            for line in f:
                # each line is "<label>,<text>"; skip anything else
                if line.startswith("0,"):
                    label = 0
                elif line.startswith("1,"):
                    label = 1
                else:
                    continue
                title = line[2:].strip()
                if self.config["model_type"] == "bert":
                    title_index = self.tokenizer.encode(title, max_length=self.config["max_length"],
                                                        padding="max_length", truncation=True)
                else:
                    title_index = self.encode_sentence(title)
                title_index = torch.LongTensor(title_index)
                label_index = torch.LongTensor([label])
                self.data.append([title_index, label_index])
        return

    def encode_sentence(self, text):
        input_id = []
        for char in text:
            input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))
        input_id = self.padding(input_id)
        return input_id

    def padding(self, input_id):
        # truncate to max_length, then pad with 0 (the padding index)
        input_id = input_id[:self.config["max_length"]]
        input_id += [0] * (self.config["max_length"] - len(input_id))
        return input_id


def load_vocab(vocab_path):
    """
    Loads the vocabulary as a {token: index} dict.
    """
    token_dict = {}
    with open(vocab_path, encoding="utf8") as f:
        for index, line in enumerate(f):
            token = line.strip()
            # index 0 is reserved for padding, so indices start at 1
            token_dict[token] = index + 1
    return token_dict


def load_data(data_path, config, shuffle=True):
    """
    Wraps the data with torch's built-in DataLoader.
    """
    dg = DataGenerator(data_path, config)
    return DataLoader(dg.data, batch_size=config["batch_size"], shuffle=shuffle)


if __name__ == '__main__':
    from config import Config

    print(DataGenerator(Config['valid_data_path'], Config).data)
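A quick shape check (a sketch, assuming the data files and tokenizer path exist) showing what each batch from load_data yields — the default collate function stacks the per-sample tensors:

# sketch: inspect one batch from the DataLoader
from config import Config
from loader import load_data

valid_data = load_data(Config["valid_data_path"], Config, shuffle=False)
input_ids, labels = next(iter(valid_data))
print(input_ids.shape)   # torch.Size([batch_size, max_length]), e.g. [128, 30]
print(labels.shape)      # torch.Size([batch_size, 1])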
main.py
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
import torch
import random
import os
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data

"""
Main training script.
"""

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

seed = Config["seed"]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


def main(config):
    # create the directory for saving models
    if not os.path.isdir(config["model_path"]):
        os.mkdir(config["model_path"])
    # load the training data
    train_data = load_data(config["train_data_path"], config)
    # build the model
    model = TorchModel(config)
    # flag whether a GPU is available
    cuda_flag = torch.cuda.is_available()
    if cuda_flag:
        logger.info("GPU available, moving model to GPU")
        model = model.cuda()
    # build the optimizer
    optimizer = choose_optimizer(config, model)
    # build the evaluator
    evaluator = Evaluator(config, model, logger)
    # train
    for epoch in range(1, config["epoch"] + 1):
        model.train()
        logger.info(f"epoch {epoch} begin")
        train_loss = []
        for index, batch_data in enumerate(train_data):
            if cuda_flag:
                batch_data = [d.cuda() for d in batch_data]
            optimizer.zero_grad()
            input_ids, labels = batch_data
            loss = model(input_ids, labels)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            # if index % int(len(train_data) / 2) == 0:
            #     logger.info("batch loss %f" % loss)
        logger.info("epoch average loss: %f" % np.mean(train_loss))
        acc = evaluator.eval(epoch)
    return acc


if __name__ == '__main__':
    main(Config)

    # compare all models via hyper-parameter grid search
    for model in ["bert", "lstm"]:
        Config["model_type"] = model
        for lr in [1e-3, 1e-4]:
            Config["learning_rate"] = lr
            for hidden_size in [128]:
                Config["hidden_size"] = hidden_size
                for batch_size in [64, 128]:
                    Config["batch_size"] = batch_size
                    for pooling_style in ["avg", "max"]:
                        Config["pooling_style"] = pooling_style
                        logger.info(f"final-epoch accuracy: {main(Config)}, current config: {Config}")
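Since the grid search only logs each result, one option is to also persist them for side-by-side comparison. A small sketch with a hypothetical record_result helper (not part of the original code), which could be called in the innermost loop as record_result(Config, main(Config)):

# sketch: append each grid-search result to a CSV (hypothetical helper)
import csv

def record_result(config, acc, path="grid_results.csv"):
    # one row per hyper-parameter combination
    with open(path, "a", newline="", encoding="utf8") as f:
        csv.writer(f).writerow([config["model_type"], config["learning_rate"],
                                config["hidden_size"], config["batch_size"],
                                config["pooling_style"], acc])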
model.py
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from transformers import BertModel


class TorchModel(nn.Module):
    """
    Defines the network structure.
    """

    def __init__(self, config):
        super(TorchModel, self).__init__()
        hidden_size = config["hidden_size"]
        # +1 because load_vocab indexes tokens from 1, reserving 0 for padding
        vocab_size = config["vocab_size"] + 1
        class_num = config["class_num"]
        model_type = config["model_type"]
        num_layers = config["num_layers"]
        self.use_bert = False
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        if model_type == "fast_text":
            self.encoder = lambda x: x
        elif model_type == "lstm":
            self.encoder = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers, batch_first=True)
        elif model_type == "gru":
            self.encoder = nn.GRU(hidden_size, hidden_size, num_layers=num_layers, batch_first=True)
        elif model_type == "rnn":
            self.encoder = nn.RNN(hidden_size, hidden_size, num_layers=num_layers, batch_first=True)
        elif model_type == "bert":
            self.use_bert = True
            self.encoder = BertModel.from_pretrained(config["pretrain_model_path"], return_dict=False)
            hidden_size = self.encoder.config.hidden_size
        self.classify = nn.Linear(hidden_size, class_num)
        self.pooling_style = config["pooling_style"]
        # use cross-entropy as the loss
        self.loss = nn.functional.cross_entropy

    def forward(self, x, target=None):
        if self.use_bert:
            x = self.encoder(x)
        else:
            x = self.embedding(x)
            x = self.encoder(x)

        # RNN-style encoders and BERT (with return_dict=False) return tuples;
        # keep only the sequence output
        if isinstance(x, tuple):
            x = x[0]

        # pool over the sequence dimension: (batch, seq, hidden) -> (batch, hidden)
        if self.pooling_style == "max":
            self.pooling_layer = nn.MaxPool1d(x.shape[1])
        else:
            self.pooling_layer = nn.AvgPool1d(x.shape[1])
        x = self.pooling_layer(x.transpose(1, 2)).squeeze(-1)
        predict = self.classify(x)
        if target is not None:
            return self.loss(predict, target.squeeze(-1))
        else:
            return predict


def choose_optimizer(config, model):
    optimizer = config["optimizer"]
    learning_rate = config["learning_rate"]
    if optimizer == "adam":
        return Adam(model.parameters(), lr=learning_rate)
    elif optimizer == "sgd":
        return SGD(model.parameters(), lr=learning_rate)
    raise ValueError(f"unknown optimizer: {optimizer}")
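A forward-pass smoke test (a sketch; it switches Config to the lstm branch and fakes the vocab_size/class_num keys the loader would normally set, so no pretrained weights are required):

# sketch: verify forward-pass shapes with a dummy batch
import torch
from config import Config
from model import TorchModel

Config["model_type"] = "lstm"
Config["vocab_size"] = 20          # normally set by DataGenerator
Config["class_num"] = 2            # normally set by DataGenerator
model = TorchModel(Config)

x = torch.randint(1, 20, (4, Config["max_length"]))   # 4 fake sentences of token ids
logits = model(x)                                     # no target -> returns predictions
print(logits.shape)                                   # torch.Size([4, 2])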
@@ -0,0 +1 @@

split.py
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
import random
from config import Config

"""
Splits the data into training and validation sets.
"""


def split_file(path):
    # skip the CSV header row
    with open(path, 'r', encoding='utf8') as f:
        lines = f.readlines()[1:]
    random.seed(Config["seed"])  # make the split reproducible
    random.shuffle(lines)
    total_lines_num = len(lines)
    train_lines_num = int(0.8 * total_lines_num)

    train_lines = lines[:train_lines_num]
    valid_lines = lines[train_lines_num:]

    with open(Config['train_data_path'], 'w', encoding='utf8') as f_train:
        f_train.writelines(train_lines)

    with open(Config['valid_data_path'], 'w', encoding='utf8') as f_valid:
        f_valid.writelines(valid_lines)


if __name__ == '__main__':
    split_file(Config['original_data_path'])
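A quick sanity check of the resulting split (a sketch; it just recounts the written files and expects a ratio near 0.8):

# sketch: verify the 80/20 split
from config import Config

with open(Config['train_data_path'], encoding='utf8') as f:
    n_train = sum(1 for _ in f)
with open(Config['valid_data_path'], encoding='utf8') as f:
    n_valid = sum(1 for _ in f)
print(n_train, n_valid, n_train / (n_train + n_valid))   # expect ratio ≈ 0.8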