model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ResBlock(nn.Module):
    """Residual block that halves the spatial resolution while changing the channel count."""

    def __init__(self, input_channels, output_channels):
        super(ResBlock, self).__init__()
        # Main path: a stride-2 conv downsamples, a second conv refines at the new resolution.
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, padding=1, stride=2)
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.conv2 = nn.Conv2d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, padding=1, stride=1)
        self.bn2 = nn.BatchNorm2d(output_channels)
        # Projection shortcut: a 1x1 stride-2 conv keeps the skip connection shape-compatible.
        self.res_conv = nn.Conv2d(in_channels=input_channels, out_channels=output_channels, kernel_size=1, padding=0, stride=2)
        self.bn_res = nn.BatchNorm2d(output_channels)

    def forward(self, x):
        res = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        # Add the projected shortcut before the final activation.
        x = F.relu(x.add(self.bn_res(self.res_conv(res))))
        return x
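
# Shape note (illustrative example, not taken from the original file): with a
# stride-2 main path and a stride-2 projection shortcut, ResBlock(64, 128) maps
# an input of shape (N, 64, 48, 96) to (N, 128, 24, 48); the three blocks used
# below therefore reduce the spatial resolution by roughly a factor of 8.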

class StringNet(nn.Module):
    """CNN + bidirectional LSTM that maps an image to per-time-step class log-probabilities."""

    def __init__(self, n_classes: int, seq_length: int, batch_size: int,
                 lstm_hidden_dim: int = 100, bidirectional: bool = True, lstm_layers: int = 2,
                 lstm_dropout: float = 0.5, fc2_dim: int = 100):
        super(StringNet, self).__init__()
        self.n_classes = n_classes
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.lstm_hidden_dim = lstm_hidden_dim
        self.bidirectional = bidirectional
        self.lstm_layers = lstm_layers

        # Convolutional stem.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=1, stride=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Three downsampling residual blocks; each halves the spatial resolution.
        self.res_block1 = ResBlock(64, 128)
        self.res_block2 = ResBlock(128, 256)
        self.res_block3 = ResBlock(256, 512)

        # Recurrent head: each column of the final feature map is flattened to a
        # 3072-dimensional vector (512 channels times the final feature-map height)
        # and fed to the LSTM as one time step.
        self.lstm = nn.LSTM(3072, self.lstm_hidden_dim, num_layers=self.lstm_layers, bias=True,
                            dropout=lstm_dropout, bidirectional=self.bidirectional)

        self.fc1 = nn.Linear(self.lstm_hidden_dim, fc2_dim)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(fc2_dim, n_classes)

    def init_hidden(self, batch_size):
        # The axes semantics are (num_layers * num_directions, minibatch_size, hidden_dim).
        return (torch.zeros(self.lstm_layers * self.directions, batch_size, self.lstm_hidden_dim).to(device),
                torch.zeros(self.lstm_layers * self.directions, batch_size, self.lstm_hidden_dim).to(device))

    def forward(self, x):
        current_batch_size = x.shape[0]

        # Convolutional stem with a small residual connection around conv2/conv3.
        x = F.relu(self.bn1(self.conv1(x)))
        x = res1 = self.pool1(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = F.relu(x.add(res1))

        x = self.res_block1(x)
        x = self.res_block2(x)
        x = self.res_block3(x)

        # Treat each column of the final feature map as one time step:
        # (batch, channels, height, width) -> (width, batch, channels * height).
        features = x.permute(3, 0, 1, 2).view(self.seq_length, current_batch_size, -1)

        hidden = self.init_hidden(current_batch_size)
        outs, _ = self.lstm(features, hidden)

        if self.bidirectional:
            # Split the forward and backward directions and sum them.
            outs = outs.view(self.seq_length, current_batch_size, 2, self.lstm_hidden_dim)
            output_forward, output_backward = torch.chunk(outs, 2, 2)
            output_forward = output_forward.view(self.seq_length, current_batch_size, -1)
            output_backward = output_backward.view(self.seq_length, current_batch_size, -1)
            outs = output_forward.add(output_backward)

        # Classification head applied at every time step.
        outs = self.fc1(outs)
        outs = self.dropout(outs)
        outs = self.fc2(outs)
        outs = F.log_softmax(outs, 2)
        return outs

    @property
    def directions(self) -> int:
        return 2 if self.bidirectional else 1
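
if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original file). The input size
    # 3 x 52 x 100, n_classes=11 and seq_length=12 are assumptions chosen so that
    # the final feature map is (512, 6, 12): 512 * 6 = 3072 matches the hard-coded
    # LSTM input size, and the final width of 12 matches seq_length.
    model = StringNet(n_classes=11, seq_length=12, batch_size=2).to(device)
    dummy_images = torch.randn(2, 3, 52, 100, device=device)
    log_probs = model(dummy_images)
    print(log_probs.shape)  # expected: torch.Size([12, 2, 11])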