model.py
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import BertModel, BertPreTrainedModel


class TextRCNN_Bert(BertPreTrainedModel):
    def __init__(self, config):
        super(TextRCNN_Bert, self).__init__(config)
        self.bert = BertModel(config)
        # Bidirectional LSTM over BERT token embeddings: 768 -> 2 * 64 per token.
        self.lstm = nn.LSTM(768, 64, bidirectional=True, batch_first=True)
        # Binary classifier over the concatenated BERT + BiLSTM features.
        self.fc = nn.Linear(64 * 2 + 768, 2)

    def forward(self, context, mask, softmax=True):
        outputs = self.bert(context, attention_mask=mask, output_hidden_states=True)
        # last_hidden_state = outputs[0]              # [batch_size, seq_len, 768]
        # outputs[2] holds all hidden states: the embedding layer plus one per encoder layer.
        hidden_states = outputs[2]
        second_to_last_layer = hidden_states[-2]       # [batch_size, seq_len, 768]
        out, _ = self.lstm(second_to_last_layer)       # [batch_size, seq_len, 64 * 2]
        # RCNN: concatenate each token's recurrent context with its original representation.
        out = torch.cat((second_to_last_layer, out), 2)  # [batch_size, seq_len, 64 * 2 + 768]
        out = F.relu(out)                              # [batch_size, seq_len, 64 * 2 + 768]
        out = out.permute(0, 2, 1)                     # [batch_size, 64 * 2 + 768, seq_len]
        # Max-pool over the sequence dimension to get one fixed-size vector per example.
        out = F.max_pool1d(out, out.size(2)).squeeze(2)  # [batch_size, 64 * 2 + 768]
        if softmax:
            logit = F.log_softmax(self.fc(out), dim=1)   # [batch_size, 2]
        else:
            logit = self.fc(out)
        return logit
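

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how this class might be driven. The checkpoint name
# "bert-base-chinese" and the sample sentence are assumptions for illustration;
# the repository does not specify them here. Because TextRCNN_Bert subclasses
# BertPreTrainedModel, from_pretrained() loads the BERT weights and randomly
# initializes the extra lstm/fc layers (transformers will warn about this),
# so real use would fine-tune before inference.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")  # assumed checkpoint
    model = TextRCNN_Bert.from_pretrained("bert-base-chinese")
    model.eval()

    encoded = tokenizer("An example sentence.", return_tensors="pt")
    with torch.no_grad():
        log_probs = model(encoded["input_ids"], encoded["attention_mask"])  # [1, 2]
    print(log_probs.argmax(dim=1))  # predicted class index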