-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLFUNet.py
178 lines (159 loc) · 9.62 KB
/
LFUNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# https://www.ingentaconnect.com/content/ben/cmir/2023/00000019/00000004/art00006
from torch import nn
from torch import cat
import numpy as np
import cfg
class LFUNet(nn.Module):
    """Lightweight nested encoder-decoder segmentation network (LFUNet).

    Nodes are named ``X(i, j)`` (attribute ``stage{i}{j}``): ``i`` is the depth
    level, running at spatial size ``H / 2**i``, and ``j`` is the column.

    * Encoder column ``X(i, 0)``, ``i = 0..4``: two 3x3 conv-BN-ReLU layers
      with ``base_channels * 2**i`` output channels; level ``i`` is fed by a
      2x2 max-pool of level ``i - 1``.
    * Decoder nodes ``X(i, j)``, ``j >= 1``: two 3x3 conv-BN-ReLU layers with
      ``base_channels`` output channels. Each node concatenates, after a 3x3
      projection of every input to ``base_channels`` channels:
        - the shallower encoder outputs ``X(k, 0)``, ``k < i``, max-pooled
          down to level ``i``;
        - the same-level nodes ``X(i, 0..j-1)``;
        - the deeper nodes ``X(i+k, j-k)``, ``k = 1..j``, bilinearly
          upsampled to level ``i``.
    * Output: one shared 1x1 conv head applied to ``X(0, 1..4)``; the four
      resulting score maps are averaged.

    Most 3x3 convolutions are grouped (``groups``) to keep the model small.
    ``stagep`` is the same module object as ``stage0p`` and ``stagezp`` the
    same as ``stage0zp``, so those weights are shared by many inputs and
    appear under both names in ``state_dict()``.

    Args:
        base_channels: Width of level 0 and of every decoder node (default 8).
        groups: Group count of the grouped 3x3 convolutions (default 8);
            must divide ``base_channels``.
        in_channels: Number of input image channels (default 3).
        num_classes: Number of output channels; ``None`` uses
            ``cfg.num_class``.

    Raises:
        ValueError: if ``base_channels`` is not a positive multiple of
            ``groups``.

    Input is ``(B, in_channels, H, W)`` with ``H`` and ``W`` divisible by 16:
    the four 2x poolings and the up-to-16x upsamplings must produce matching
    sizes for the concatenations. Output is ``(B, num_classes, H, W)`` raw
    (un-activated) scores.
    """

    def __init__(self, base_channels: int = 8, groups: int = 8,
                 in_channels: int = 3, num_classes: "int | None" = None):
        super().__init__()
        # Every channel count below is a multiple of base_channels, so this is
        # the only divisibility constraint the grouped convs impose.
        if base_channels <= 0 or groups <= 0 or base_channels % groups != 0:
            raise ValueError(
                f"base_channels ({base_channels}) must be a positive multiple "
                f"of groups ({groups})"
            )
        if num_classes is None:
            num_classes = cfg.num_class
        n = base_channels
        g = groups
        block = self._conv_block
        proj = self._projection

        # NOTE: modules are registered in the same order as the original
        # hard-coded version so state_dict key order and seeded init match.

        # Encoder column X(i,0): out channels n * 2**i (8/16/32/64/128 by default).
        # The stem's first conv is ungrouped so any in_channels is accepted.
        self.stage00 = block(in_channels, n, g, first_groups=1)
        self.stage10 = block(n, 2 * n, g)
        self.stage20 = block(2 * n, 4 * n, g)
        self.stage30 = block(4 * n, 8 * n, g)
        self.stage40 = block(8 * n, 16 * n, g)

        # Decoder nodes X(i,j): input width is n times the number of
        # projected tensors concatenated in forward().
        self.stage01 = block(2 * n, n, g)
        self.stage11 = block(3 * n, n, g)
        self.stage02 = block(4 * n, n, g)
        self.stage21 = block(4 * n, n, g)
        self.stage12 = block(5 * n, n, g)
        self.stage03 = block(6 * n, n, g)
        self.stage31 = block(5 * n, n, g)
        self.stage22 = block(6 * n, n, g)
        self.stage13 = block(7 * n, n, g)
        self.stage04 = block(8 * n, n, g)

        # Projections to n channels applied to every tensor before concatenation.
        self.stage0p = proj(n, n, g)
        self.stage1p = proj(2 * n, n, g)
        self.stage2p = proj(4 * n, n, g)
        self.stage3p = proj(8 * n, n, g)
        self.stage4p = proj(16 * n, n, g)
        self.stagep = self.stage0p  # alias: shares weights with stage0p
        self.stage0zp = proj(n, n, g)
        self.stage1zp = proj(2 * n, n, g)
        self.stage2zp = proj(4 * n, n, g)
        self.stage3zp = proj(8 * n, n, g)
        self.stage4zp = proj(16 * n, n, g)
        self.stagezp = self.stage0zp  # alias: shares weights with stage0zp

        # Down-sampling (parameter-free max pooling).
        self.stage2d = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
        self.stage2zd = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
        self.stage4zd = nn.Sequential(nn.MaxPool2d(kernel_size=4, stride=4))
        self.stage8zd = nn.Sequential(nn.MaxPool2d(kernel_size=8, stride=8))

        # Up-sampling (parameter-free bilinear interpolation).
        self.stage2u = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
        self.stage4u = nn.Sequential(nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True))
        self.stage8u = nn.Sequential(nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True))
        self.stage16u = nn.Sequential(nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True))

        # Shared 1x1 output head applied to each full-resolution decoder node.
        self.stageout = nn.Sequential(
            nn.Conv2d(n, num_classes, 1, groups=1)
        )

    @staticmethod
    def _conv_block(in_ch, out_ch, groups, first_groups=None):
        """Return ``(3x3 conv -> BatchNorm -> ReLU) x 2`` as one Sequential.

        ``first_groups`` overrides the group count of the first conv only;
        it defaults to ``groups``.
        """
        if first_groups is None:
            first_groups = groups
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, groups=first_groups),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, groups=groups),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _projection(in_ch, out_ch, groups):
        """Return one grouped 3x3 conv, wrapped in Sequential to keep the ``.0.`` key."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, groups=groups)
        )

    def forward(self, x):
        """Run the network and return the average of the four output heads."""
        # Encoder column X(i,0); each level halves the spatial size.
        s00 = self.stage00(x)                    # n channels,    H
        s10 = self.stage10(self.stage2d(s00))    # 2n channels,   H/2
        s20 = self.stage20(self.stage2d(s10))    # 4n channels,   H/4
        s30 = self.stage30(self.stage2d(s20))    # 8n channels,   H/8
        s40 = self.stage40(self.stage2d(s30))    # 16n channels,  H/16

        # Projections consumed by several decoder nodes. They are plain convs
        # (no normalisation, no randomness), so computing each once is
        # equivalent to recomputing it for every consumer.
        p00 = self.stage0p(s00)                       # level 0
        p00_d2 = self.stage0zp(self.stage2zd(s00))    # level 1
        p00_d4 = self.stage0zp(self.stage4zd(s00))    # level 2
        p10 = self.stage1p(s10)                       # level 1
        p10_d2 = self.stage1zp(self.stage2zd(s10))    # level 2
        p20 = self.stage2p(s20)                       # level 2

        # Decoder nodes, in dependency order. Each output has n channels.
        s01 = self.stage01(cat([p00,
                                self.stage1p(self.stage2u(s10))], 1))
        s11 = self.stage11(cat([p00_d2,
                                p10,
                                self.stage2p(self.stage2u(s20))], 1))
        z01 = self.stagezp(s01)
        s02 = self.stage02(cat([p00, z01,
                                self.stage2zp(self.stage4u(s20)),
                                self.stagep(self.stage2u(s11))], 1))
        s21 = self.stage21(cat([p00_d4, p10_d2,
                                p20,
                                self.stage3p(self.stage2u(s30))], 1))
        z11 = self.stagezp(s11)
        s12 = self.stage12(cat([p00_d2,
                                p10, z11,
                                self.stage3zp(self.stage4u(s30)),
                                self.stagep(self.stage2u(s21))], 1))
        z02 = self.stagezp(s02)
        s03 = self.stage03(cat([p00, z01, z02,
                                self.stage3zp(self.stage8u(s30)),
                                self.stagezp(self.stage4u(s21)),
                                self.stagep(self.stage2u(s12))], 1))
        s31 = self.stage31(cat([self.stage0zp(self.stage8zd(s00)),
                                self.stage1zp(self.stage4zd(s10)),
                                self.stage2zp(self.stage2zd(s20)),
                                self.stage3p(s30),
                                self.stage4p(self.stage2u(s40))], 1))
        s22 = self.stage22(cat([p00_d4, p10_d2,
                                p20, self.stagezp(s21),
                                self.stage4zp(self.stage4u(s40)),
                                self.stagep(self.stage2u(s31))], 1))
        s13 = self.stage13(cat([p00_d2,
                                p10, z11, self.stagezp(s12),
                                self.stage4zp(self.stage8u(s40)),
                                self.stagezp(self.stage4u(s31)),
                                self.stagep(self.stage2u(s22))], 1))
        s04 = self.stage04(cat([p00, z01, z02, self.stagezp(s03),
                                self.stage4zp(self.stage16u(s40)),
                                self.stagezp(self.stage8u(s31)),
                                self.stagezp(self.stage4u(s22)),
                                self.stagep(self.stage2u(s13))], 1))

        # Shared head on every full-resolution decoder node, then average.
        out1 = self.stageout(s01)
        out2 = self.stageout(s02)
        out3 = self.stageout(s03)
        out4 = self.stageout(s04)
        return (out1 + out2 + out3 + out4) / 4