Skip to content

Neural models for Text Classification in Tensorflow, such as cnn, dpcnn, fasttext, bert ...

Notifications You must be signed in to change notification settings

liyibo/text-classification-demos

Repository files navigation

Text classification demos

Tensorflow 环境下,不同的神经网络模型对中文文本进行分类,本文中的 demo 都是字符级别的文本分类(增加了word-based 的统计结果),简化了文本分类的流程,字符级别的分类在有些任务上的效果可能不好,需要结合实际情况添加自定义的分词模块。

数据集

下载地址: https://pan.baidu.com/s/1hugrfRu 密码: qfud

使用 THUCNews 的一个子集进行训练与测试,使用了其中的 10 个分类,每个分类 6500 条数据。

类别如下:

体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐

数据集划分如下:

训练集: 5000 * 10
验证集: 500 * 10
测试集: 1000 * 10

具体介绍请参考:text-classification-cnn-rnn

分类效果

  • char-based
model fasttext cnn rnn rcnn han dpcnn bert
val_acc 92.92 93.56 93.56 94.36 93.94 93.70 97.84
test_acc 93.15 94.57 94.37 95.53 93.65 94.87 96.93
  • word-based
model fasttext cnn rnn rcnn han dpcnn bert
val_acc 95.52 95.28 93.10 95.60 95.10 95.68 -
test_acc 95.34 95.77 94.05 96.36 95.66 95.97 -

模型介绍

1、FastText

fasttext_model.py 文件为训练和测试 fasttext 模型的代码

图1 FastText 模型结构图

本代码简化了 fasttext 模型的结构,模型结构非常简单,运行速度简直飞快,模型准确率也不错,可根据实际需要优化模型结构

2、TextCNN

cnn_model.py 文件为训练和测试 TextCNN 模型的代码

图2 TextCNN 模型结构图

本代码实现了 TextCNN 模型的结构,通过 3 个不同大小的卷积核,对输入文本进一维卷积,分别 pooling 三个卷积之后的 feature, 拼接到一起,然后进行 dense 操作,最终输出模型结果。可实现速度和精度之间较好的折中。

3、RNN

rnn_model.py 文件为训练和测试 TextCNN 模型的代码

图8 TextRNN 模型结构图

本代码实现了 TextRNN 模型的结构,对输入序列进行embedding,然后输入两层的 rnn_cell中学习序列特征,取最后一个 word 的 state 作为进行后续的 fc 操作,最终输出模型结果。

4、RCNN

rcnn_model.py 文件为训练和测试 RCNN 模型的代码

图3 RCNN 模型结构图

Recurrent Convolutional Neural Network for Text Classification, 在学习 word representations 时候,同时采用了 rnn 结构来学习 word 的上下文,虽然模型名称为 RCNN,但并没有显式的存在卷积操作。

1、采用双向lstm学习 word 的上下文

   c_left = tf.concat([tf.zeros(shape), output_fw[:, :-1]], axis=1, name="context_left")
   c_right = tf.concat([output_bw[:, 1:], tf.zeros(shape)], axis=1, name="context_right")
   word_representation = tf.concat([c_left, embedding_inputs, c_right], axis=2, name="last")

2、pooling + softmax

word_representation 的维度是 batch_size * seq_length * 2 * context_dim + embedding_dim

在 seq_length 维度进行 max pooling,然后进行 fc 操作就可以进行分类了,可以将该网络看成是 fasttext 的改进版本

5、HAN

han_model.py 文件为训练和测试 HAN 模型的代码

图4 HAN 模型结构图

HAN 为 Hierarchical Attention Networks,将待分类文本,分为一定数量的句子,分别在 word level 和 sentence level 进行 encoder 和 attention 操作,从而实现对较长文本的分类。

本文是按照句子长度将文本分句的,实际操作中可按照标点符号等进行分句,理论上效果能好一点。

  • 1、对文本进行分句

    对每个句子进行双向lstm编码

    batch_size = 64, seq_length = 600, sent_num = 10, emb_size = 128, lstm_hid_dim = 256

    数据维度变化:64 * 600 * 128 --- (64*10) * 60 * 128 --- (64*10) * 60 * 512

  • 2、word level attention

图4 attention

(1) 将输入的lstm编码结果做一次非线性变换,可以看做是输入编码的hidden representation, shape = (64*10) * 60 * 256

(2) 将 hidden representation 与一个学习得到的 word level context vector 的相似性进行 softmax,得到每个单词在句子中的权重

(3) 对输入的lstm 编码进行加权求和,得到句子的向量表示

数据维度变化:(64*10) * 60 * 512 --- (64*10) * 512

  • 3、得到每个句子的向量表示

  • 4、sentence level attention

    与 word level attention 过程一样,只是该层是句子级别的attention

    数据维度变化:64 * 10 * 512 --- 64 * 512

  • 5、得到 document 的向量表示

  • 6、dence + softmax

6、DPCNN

dpcnn_model.py 文件为训练和测试 DPCNN 模型的代码

图5 DPCNN 模型结构图

DPCNN 通过卷积和残差连接增加了以往用于文本分类 CNN 网络的深度,可以有效提取文本中的远程关系特征,并且复杂度不高,实验表名,效果比以往的 CNN 结构要好一点。

  • region_embedding: word_embedding 之后进行的 ngram 卷积结果

7、BERT

bert_model.py 文件为训练和测试 BERT 模型的代码

google官方提供用于文本分类的demo写的比较抽象,所以本文基于 google 提供的代码和初始化模型,重写了文本分类模型的训练和测试代码,bert 分类模型在小数据集下效果很好,通过较少的迭代次数就能得到很好的效果,但是训练和测试速度较慢,这点不如基于 CNN 的网络结构。

bert_model.py 将训练数据和验证数据存储为 tfrecord 文件,然后进行训练

由于 bert 提供的预训练模型较大,需要自己去 google-research/bert 中下载预训练好的模型,本实验采用的是 "BERT-Base, Chinese" 模型。

图6 BERT 输入数据格式

图7 BERT 下游任务介绍

参考

About

Neural models for Text Classification in Tensorflow, such as cnn, dpcnn, fasttext, bert ...

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages