models.py
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import models


def conv3x3(in_, out):
    """3x3 convolution that preserves spatial size (padding=1)."""
    return nn.Conv2d(in_, out, 3, padding=1)


class ConvRelu(nn.Module):
    """3x3 convolution followed by an in-place ReLU."""

    def __init__(self, in_: int, out: int):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        return x


class DecoderBlockLinkNet(nn.Module):
    """LinkNet decoder block: 1x1 channel reduction, 2x upsampling, 1x1 expansion."""

    def __init__(self, in_channels, n_filters):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        # B, C, H, W -> B, C/4, H, W
        self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
        self.norm1 = nn.BatchNorm2d(in_channels // 4)

        # B, C/4, H, W -> B, C/4, 2*H, 2*W
        self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, kernel_size=4,
                                          stride=2, padding=1, output_padding=0)
        self.norm2 = nn.BatchNorm2d(in_channels // 4)

        # B, C/4, 2*H, 2*W -> B, n_filters, 2*H, 2*W
        self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.deconv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.relu(x)
        return x
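

# Shape sketch for the decoder block (an illustration, assuming a 256x256 input to
# LinkNet34 below): decoder4 maps the encoder4 output of shape (B, 512, 8, 8) to
# (B, 256, 16, 16), which is then added element-wise to the encoder3 skip connection
# of the same shape.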


class LinkNet34(nn.Module):
    """LinkNet-style segmentation network with a ResNet-34 encoder."""

    def __init__(self, num_classes=1, num_channels=3, pretrained=True):
        super().__init__()
        assert num_channels == 3, "the ResNet-34 encoder expects 3-channel input"
        self.num_classes = num_classes
        filters = [64, 128, 256, 512]
        resnet = models.resnet34(pretrained=pretrained)

        # Encoder: reuse the ResNet-34 stem and residual stages
        self.firstconv = resnet.conv1
        self.firstbn = resnet.bn1
        self.firstrelu = resnet.relu
        self.firstmaxpool = resnet.maxpool
        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4

        # Decoder
        self.decoder4 = DecoderBlockLinkNet(filters[3], filters[2])
        self.decoder3 = DecoderBlockLinkNet(filters[2], filters[1])
        self.decoder2 = DecoderBlockLinkNet(filters[1], filters[0])
        self.decoder1 = DecoderBlockLinkNet(filters[0], filters[0])

        # Final classifier
        self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2)
        self.finalrelu1 = nn.ReLU(inplace=True)
        self.finalconv2 = nn.Conv2d(32, 32, 3)
        self.finalrelu2 = nn.ReLU(inplace=True)
        self.finalconv3 = nn.Conv2d(32, num_classes, 2, padding=1)

    # noinspection PyCallingNonCallable
    def forward(self, x):
        # Encoder
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x = self.firstmaxpool(x)
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Decoder with skip connections
        d4 = self.decoder4(e4) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)

        # Final classification
        f1 = self.finaldeconv1(d1)
        f2 = self.finalrelu1(f1)
        f3 = self.finalconv2(f2)
        f4 = self.finalrelu2(f3)
        f5 = self.finalconv3(f4)

        # Log-probabilities over classes for multi-class output, raw logits for binary
        if self.num_classes > 1:
            x_out = F.log_softmax(f5, dim=1)
        else:
            x_out = f5
        return x_out
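

# A minimal usage sketch (not part of the original file): build the model without
# pretrained weights and run a dummy forward pass. With a 256x256 input, the final
# classifier (deconv 3/stride 2, conv 3, conv 2 with padding 1) recovers the input
# resolution, so the output should be (1, num_classes, 256, 256).
if __name__ == "__main__":
    model = LinkNet34(num_classes=1, pretrained=False)
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 256, 256)
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 1, 256, 256])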