TextCNN.py
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextCNN(nn.Module):
    def __init__(self, config):
        super(TextCNN, self).__init__()
        self.kernel_num = config.kernel_num          # Number of conv kernels per kernel size
        self.embedding_length = config.wordvec_dim   # Word-vector dimensionality
        self.Ks = config.kernel_sizes                # Kernel heights (n-gram window sizes)
        # Embedding layer loaded from pretrained word vectors (vocabulary size is
        # taken from config.weight) and fine-tuned during training
        self.word_embeddings = nn.Embedding.from_pretrained(config.weight, freeze=False)
        # One Conv2d per kernel size; each kernel spans the full embedding dimension
        self.convs = nn.ModuleList([nn.Conv2d(1, config.kernel_num, (K, config.wordvec_dim), bias=True) for K in self.Ks])
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(len(self.Ks) * self.kernel_num, config.label_num)
        self.bn1 = nn.BatchNorm1d(len(self.Ks) * self.kernel_num)
        # Optional deeper classifier head (unused):
        # self.fc1 = nn.Linear(len(self.Ks) * self.kernel_num, fc1out)
        # self.bn2 = nn.BatchNorm1d(fc1out)
        # self.fc2 = nn.Linear(fc1out, config.label_num)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # x: (batch_size, seq_len) of word indices
        # Embedding lookup
        x = self.word_embeddings(x)      # (batch_size, seq_len, embedding_len)
        x = x.unsqueeze(1)               # (batch_size, 1, seq_len, embedding_len)
        # Convolution + ReLU for each kernel size
        x = [self.relu(conv(x)).squeeze(3) for conv in self.convs]  # [(batch_size, kernel_num, seq_len-K+1), ...] * len(Ks)
        # Max-over-time pooling: keep the strongest feature per kernel
        x = [F.max_pool1d(xi, xi.size(2)).squeeze(2) for xi in x]   # [(batch_size, kernel_num), ...] * len(Ks)
        x = torch.cat(x, dim=1)          # (batch_size, len(Ks) * kernel_num)
        x = self.bn1(x)                  # Batch normalization
        x = self.dropout(x)              # (batch_size, len(Ks) * kernel_num)
        x = self.fc(x)                   # (batch_size, label_num)
        # Optional deeper classifier head (unused):
        # x = self.fc1(x)
        # x = self.bn2(x)
        # x = self.fc2(x)
        logit = self.softmax(x)          # Class probabilities
        return logit


if __name__ == '__main__':
    pass
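    # A minimal smoke test (a sketch, not part of the original repository):
    # it assumes a namespace-style config exposing the attributes used above
    # (kernel_num, wordvec_dim, kernel_sizes, weight, label_num); all values
    # below are illustrative stand-ins, including the random "pretrained" weights.
    from types import SimpleNamespace

    vocab_size, wordvec_dim = 1000, 50
    cfg = SimpleNamespace(
        kernel_num=100,
        wordvec_dim=wordvec_dim,
        kernel_sizes=[3, 4, 5],
        weight=torch.randn(vocab_size, wordvec_dim),  # stand-in for pretrained embeddings
        label_num=2,
    )
    model = TextCNN(cfg)
    model.eval()                                      # use running stats / disable dropout
    dummy = torch.randint(0, vocab_size, (4, 20))     # (batch_size=4, seq_len=20)
    with torch.no_grad():
        probs = model(dummy)
    print(probs.shape)                                # expected: torch.Size([4, 2])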