-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCNN.py
44 lines (40 loc) · 1.45 KB
/
CNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 14 23:03:08 2019
@author: Thomas
"""
import torch.nn as nn
from torchcrf import CRF
class SimpleCNN(nn.Module):
    """Sequence-labelling CNN that produces outputs for every position of a 1-D sequence.

    ``forward`` appends a trailing width axis of size 1, so each ``nn.Conv2d``
    with a ``(k, 1)`` kernel acts as a 1-D convolution along the sequence
    axis. Every kernel length ``k`` is odd and padded by ``k // 2``, so the
    sequence length is preserved end to end.

    Layer structure (attribute names and ``nn.Sequential`` indices are kept
    stable so existing ``state_dict`` checkpoints still load):

    * ``compLayer``: 1x1 conv ``inp -> hidden[0]``, LeakyReLU, Dropout.
    * ``layer1``: ``(kernel_sizes[0], 1)`` conv ``hidden[0] -> hidden[1]``,
      LeakyReLU, Dropout.
    * ``layer2``: ``(kernel_sizes[1], 1)`` conv ``hidden[1] -> outp`` with no
      activation, so the output is raw (unnormalised) values.

    Args:
        num_tags: Unused; kept for interface compatibility (it only fed a
            CRF layer that is currently disabled).
        inp: Number of input feature channels.
        outp: Number of output channels per position (presumably one score
            per tag -- confirm against the caller).
        dev: Unused; kept for interface compatibility (it was the device the
            disabled CRF layer was moved to).
        hidden: Channel widths of the two hidden convolutions.
        kernel_sizes: Odd kernel lengths of ``layer1`` and ``layer2``.
        dropout: Dropout probability applied after each hidden activation.

    Raises:
        ValueError: If any entry of ``kernel_sizes`` is even (an even kernel
            with symmetric padding would change the sequence length).
    """

    def __init__(self, num_tags, inp, outp, dev, *, hidden=(128, 64),
                 kernel_sizes=(9, 5), dropout=0.3):
        super().__init__()
        mid1, mid2 = hidden
        k1, k2 = kernel_sizes
        if k1 % 2 == 0 or k2 % 2 == 0:
            raise ValueError(
                "kernel_sizes must be odd to preserve the sequence length, "
                f"got {tuple(kernel_sizes)}"
            )
        # Pointwise (1x1) projection of the input features.
        self.compLayer = nn.Sequential(
            nn.Conv2d(inp, mid1, kernel_size=(1, 1), padding=(0, 0)),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        )
        # Local context along the sequence axis only; the width axis stays 1.
        self.layer1 = nn.Sequential(
            nn.Conv2d(mid1, mid2, kernel_size=(k1, 1), padding=(k1 // 2, 0)),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        )
        # Final projection to the output channels; no activation.
        self.layer2 = nn.Sequential(
            nn.Conv2d(mid2, outp, kernel_size=(k2, 1), padding=(k2 // 2, 0)),
        )

    def forward(self, x):
        """Map ``x`` of shape ``(batch, inp, seq)`` to ``(batch, outp, seq)``.

        An unbatched ``(inp, seq)`` input is also accepted and yields
        ``(outp, seq)``, relying on ``nn.Conv2d`` supporting unbatched input.
        """
        # Add a width axis of size 1 so Conv2d with (k, 1) kernels is a 1-D conv.
        out = self.compLayer(x.unsqueeze(-1))
        out = self.layer1(out)
        out = self.layer2(out)
        # Drop the width axis again; it is always 1 here.
        return out.squeeze(-1)