-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathmodels.py
483 lines (428 loc) · 15.5 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import learn2learn as l2l
def conv_block(in_channels, out_channels, **kwargs):
    """Build one conv -> batch-norm -> ReLU -> 2x2 max-pool stage.

    Parameters
    ----------
    in_channels : int
        Number of channels entering the convolution.
    out_channels : int
        Number of channels produced by the convolution (and normalized by
        the batch-norm layer).
    **kwargs
        Forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
        ``stride``, ``padding``, ``bias``).

    Returns
    -------
    nn.Sequential
        Sub-modules named ``conv``, ``norm``, ``relu`` and ``pool``, in that
        order.
    """
    stages = OrderedDict()
    stages["conv"] = nn.Conv2d(in_channels, out_channels, **kwargs)
    # track_running_stats=False: no running averages are kept, so the layer
    # normalizes with the statistics of the current batch in every mode.
    stages["norm"] = nn.BatchNorm2d(
        out_channels, momentum=1.0, track_running_stats=False
    )
    stages["relu"] = nn.ReLU()
    stages["pool"] = nn.MaxPool2d(2)
    return nn.Sequential(stages)
class ConvModel(nn.Module):
    """4-layer Convolutional Neural Network architecture from [1].

    Four identical conv stages (3x3 conv, batch-norm, ReLU, 2x2 max-pool)
    followed by a linear classifier on the flattened feature maps.

    Parameters
    ----------
    in_channels : int
        Number of channels for the input images.
    out_features : int
        Number of classes (output of the model).
    hidden_size : int (default: 64)
        Number of channels in the intermediate representations.
    feature_size : int (default: 64)
        Number of features returned by the convolutional head, i.e. the
        flattened size of the output of ``features``; it is the input size
        of the classifier.

    References
    ----------
    .. [1] Finn C., Abbeel P., and Levine, S. (2017). Model-Agnostic Meta-Learning
           for Fast Adaptation of Deep Networks. International Conference on
           Machine Learning (ICML) (https://arxiv.org/abs/1703.03400)
    """

    def __init__(self, in_channels, out_features, hidden_size=64, feature_size=64):
        super(ConvModel, self).__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size
        self.feature_size = feature_size

        # (input channels, output channels) of each stage; only the first
        # stage sees the raw image channels.
        channel_plan = [(in_channels, hidden_size)] + [(hidden_size, hidden_size)] * 3
        stages = OrderedDict()
        for position, (c_in, c_out) in enumerate(channel_plan, start=1):
            stages["layer" + str(position)] = conv_block(
                c_in, c_out, kernel_size=3, stride=1, padding=1, bias=True
            )
        self.features = nn.Sequential(stages)
        self.classifier = nn.Linear(feature_size, out_features, bias=True)

    def forward(self, inputs):
        """Run the conv head, flatten per sample, and return the logits."""
        feature_maps = self.features(inputs)
        flattened = feature_maps.view((feature_maps.size(0), -1))
        return self.classifier(flattened)
def ConvOmniglot(out_features, hidden_size=64):
    """Build the 4-layer ``ConvModel`` for single-channel (Omniglot) images.

    The classifier input size is ``hidden_size``, which assumes the four
    pooling stages reduce the image to a 1x1 map (e.g. 28x28 inputs).
    """
    return ConvModel(
        1,
        out_features,
        hidden_size=hidden_size,
        feature_size=hidden_size,
    )
def ConvMiniImagenet(out_features, hidden_size=64):
    """Build the 4-layer ``ConvModel`` for RGB (mini-ImageNet) images.

    The classifier input size is ``5 * 5 * hidden_size``, which assumes the
    four pooling stages leave a 5x5 map (e.g. 84x84 inputs).
    """
    flattened_size = 5 * 5 * hidden_size
    return ConvModel(
        3,
        out_features,
        hidden_size=hidden_size,
        feature_size=flattened_size,
    )
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with 1-pixel zero padding.

    With ``stride=1`` the spatial size of the input is preserved.
    """
    options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)
class BasicBlock(nn.Module):
    """Three-convolution residual block used by ResNet-12.

    (conv3x3 -> BN -> LeakyReLU) twice, then conv3x3 -> BN, add the
    (optionally projected) input, LeakyReLU, max-pool with kernel ``stride``,
    and finally plain dropout or DropBlock when ``drop_rate > 0``.

    Parameters
    ----------
    inplanes : int
        Channels of the block input.
    planes : int
        Channels produced by every convolution of the block.
    stride : int
        Kernel size (and stride) of the trailing max-pool; 1 keeps the
        spatial size.
    downsample : nn.Module or None
        Applied to the input to form the residual; ``None`` uses the input.
    drop_rate : float
        Dropout / DropBlock rate; regularization is skipped unless > 0.
    drop_block : bool
        Use DropBlock instead of plain (in-place) dropout.
    block_size : int
        DropBlock square size.
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        drop_rate=0.0,
        drop_block=False,
        block_size=1,
    ):
        super(BasicBlock, self).__init__()

        def _batch_norm():
            # Batch statistics only; no running averages are tracked.
            return nn.BatchNorm2d(planes, momentum=1.0, track_running_stats=False)

        self.conv1 = conv3x3(inplanes, planes)
        self.bn1 = _batch_norm()
        self.relu = nn.LeakyReLU(0.1)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = _batch_norm()
        self.conv3 = conv3x3(planes, planes)
        self.bn3 = _batch_norm()
        self.maxpool = nn.MaxPool2d(stride)
        self.downsample = downsample
        self.stride = stride
        self.drop_rate = drop_rate
        # Counts calls to forward(); drives the DropBlock keep-rate schedule.
        self.num_batches_tracked = 0
        self.drop_block = drop_block
        self.block_size = block_size
        self.DropBlock = DropBlock(block_size=self.block_size)

    def forward(self, x):
        self.num_batches_tracked += 1

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        out = self.maxpool(self.relu(out))

        if not self.drop_rate > 0:
            return out
        if not self.drop_block:
            return F.dropout(
                out,
                p=self.drop_rate,
                training=self.training,
                inplace=True,
            )

        # Only the height is read; assumes square feature maps.
        feat_size = out.shape[2]
        # Keep rate decreases linearly from 1.0 and bottoms out at
        # 1 - drop_rate after 40000 calls to forward().
        keep_rate = max(
            1.0 - self.drop_rate / 40000 * self.num_batches_tracked,
            1.0 - self.drop_rate,
        )
        # Seed probability: the drop fraction spread over the block_size**2
        # cells of a block and over the positions where a block fits.
        gamma = (1 - keep_rate) / self.block_size**2
        gamma = gamma * feat_size**2 / (feat_size - self.block_size + 1) ** 2
        return self.DropBlock(out, gamma=gamma)
class DropBlock(nn.Module):
    """DropBlock regularizer: zeroes square ``block_size`` x ``block_size`` regions.

    Active only in training mode; in eval mode the input is returned as is.
    Surviving activations are rescaled by
    (number of mask elements / number of kept elements).

    Parameters
    ----------
    block_size : int
        Side length of every dropped square.
    """

    def __init__(self, block_size):
        super(DropBlock, self).__init__()
        self.block_size = block_size

    def forward(self, x, gamma):
        """Apply DropBlock to ``x`` (training only).

        Parameters
        ----------
        x : torch.Tensor
            4-D input, unpacked as (batch, channels, height, width).
        gamma : float
            Bernoulli probability that a valid position seeds a dropped block.

        Returns
        -------
        torch.Tensor
            ``x`` itself in eval mode; otherwise ``x`` with the dropped blocks
            zeroed and the rest rescaled. If every position is dropped the
            result is all zeros (not NaN).
        """
        if not self.training:
            return x
        batch_size, channels, height, width = x.shape
        bernoulli = torch.distributions.Bernoulli(gamma)
        # Seeds are sampled only where a whole block fits inside the map.
        mask = bernoulli.sample(
            (
                batch_size,
                channels,
                height - (self.block_size - 1),
                width - (self.block_size - 1),
            )
        ).to(x.device)
        block_mask = self._compute_block_mask(mask)
        count_all = block_mask.numel()
        # Clamp so a fully dropped map yields 0 instead of 0 * inf = NaN.
        count_ones = block_mask.sum().clamp(min=1.0)
        return block_mask * x * (count_all / count_ones)

    def _compute_block_mask(self, mask):
        """Grow every seed of ``mask`` into a square and return the keep-mask.

        ``mask`` is 4-D with 1 at block seeds and 0 elsewhere. It is padded by
        ``block_size - 1`` in total along each spatial axis, every seed is
        expanded into a ``block_size`` x ``block_size`` square, and the result
        is inverted: 0 inside dropped blocks, 1 elsewhere.
        """
        block_size = self.block_size
        left_padding = int((block_size - 1) / 2)
        right_padding = int(block_size / 2)
        padded_mask = F.pad(
            mask, (left_padding, right_padding, left_padding, right_padding)
        )

        seed_idxs = mask.nonzero(as_tuple=False)
        nr_blocks = seed_idxs.shape[0]
        if nr_blocks > 0:
            # Offsets (0, 0, dy, dx) for every cell of one square.
            rows = (
                torch.arange(block_size)
                .view(-1, 1)
                .expand(block_size, block_size)
                .reshape(-1)
            )
            cols = torch.arange(block_size).repeat(block_size)
            offsets = torch.cat(
                (
                    torch.zeros(block_size**2, 2, dtype=torch.long),
                    torch.stack([rows, cols]).t().long(),
                ),
                dim=1,
            ).to(mask.device)
            # Pair every seed with every offset: each seed is repeated
            # block_size**2 times consecutively while the offsets are tiled
            # nr_blocks times. Tiling both tensors only yields all pairs when
            # nr_blocks and block_size**2 are coprime, leaving blocks
            # partially filled otherwise.
            block_idxs = seed_idxs.repeat_interleave(
                block_size**2, dim=0
            ) + offsets.repeat(nr_blocks, 1)
            padded_mask[
                block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]
            ] = 1.0
        return 1 - padded_mask
class ResNet12Backbone(nn.Module):
    """ResNet-12 feature extractor: four residual stages, pool, flatten, dropout.

    Parameters
    ----------
    hidden_size : int (default: 64)
        Number of filters of the first stage; later stages use multiples.
    avg_pool : bool (default: True)
        Average-pool the final map with a 5x5 window (stride 1). Set to False
        to flatten the map as-is (16000-dim embeddings).
    wider : bool (default: True)
        True uses (1, 2.5, 5, 10) x hidden_size filters (MetaOptNet);
        False uses (1, 2, 4, 8) x hidden_size filters (TADAM).
    embedding_dropout : float (default: 0.0)
        Dropout rate applied to the flattened embedding.
    dropblock_dropout : float (default: 0.1)
        Regularization rate passed to every residual block. Stages 3-4 use
        DropBlock; stages 1-2 use plain dropout at this rate.
    dropblock_size : int (default: 5)
        DropBlock square size of stages 3-4.
    channels : int (default: 3)
        Number of channels of the input images.
    """

    def __init__(
        self,
        hidden_size=64,
        avg_pool=True,  # Set to False for 16000-dim embeddings
        wider=True,  # True mimics MetaOptNet, False mimics TADAM
        embedding_dropout=0.0,  # dropout for embedding
        dropblock_dropout=0.1,  # dropout for residual layers
        dropblock_size=5,
        channels=3,
    ):
        super(ResNet12Backbone, self).__init__()
        # Channel count of the next stage's input; updated by _make_layer.
        self.inplanes = channels
        block = BasicBlock

        num_filters = (
            [hidden_size * 1, int(hidden_size * 2.5), hidden_size * 5, hidden_size * 10]
            if wider
            else [hidden_size * 1, hidden_size * 2, hidden_size * 4, hidden_size * 8]
        )

        plain_stage = dict(stride=2, dropblock_dropout=dropblock_dropout)
        dropblock_stage = dict(plain_stage, drop_block=True, block_size=dropblock_size)
        self.layer1 = self._make_layer(block, num_filters[0], **plain_stage)
        self.layer2 = self._make_layer(block, num_filters[1], **plain_stage)
        self.layer3 = self._make_layer(block, num_filters[2], **dropblock_stage)
        self.layer4 = self._make_layer(block, num_filters[3], **dropblock_stage)

        self.avgpool = (
            nn.AvgPool2d(5, stride=1) if avg_pool else l2l.nn.Lambda(lambda t: t)
        )
        self.flatten = l2l.nn.Flatten()
        self.embedding_dropout = embedding_dropout
        self.keep_avg_pool = avg_pool
        self.dropout = nn.Dropout(p=self.embedding_dropout, inplace=False)
        self.dropblock_dropout = dropblock_dropout

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="leaky_relu"
                )
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(
        self,
        block,
        planes,
        stride=1,
        dropblock_dropout=0.0,
        drop_block=False,
        block_size=1,
    ):
        """Build one stage of three blocks; only the first may downsample."""
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # 1x1 projection of the residual. It has no stride because the
            # block max-pools after the residual addition.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes, out_planes, kernel_size=1, stride=1, bias=False
                ),
                nn.BatchNorm2d(out_planes, momentum=1.0, track_running_stats=False),
            )

        layers = [
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                dropblock_dropout,
                drop_block,
                block_size,
            )
        ]
        layers.extend(
            block(planes, planes, 1, None, dropblock_dropout, drop_block, block_size)
            for _ in range(2)
        )
        self.inplanes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.dropout(self.flatten(self.avgpool(x)))
class ResNet12(nn.Module):
    """
    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models/resnet12.py)

    **Description**

    The 12-layer residual network from Mishra et al, 2017.

    The code is adapted from [Lee et al, 2019](https://github.com/kjunelee/MetaOptNet/)
    who share it under the Apache 2 license.

    Instantiate `ResNet12Backbone` if you only need the feature extractor.

    List of changes:

    * Rename ResNet to ResNet12.
    * Small API modifications.
    * Fix code style to be compatible with PEP8.
    * Support multiple devices in DropBlock
    * Size the classifier from the actual embedding, so `avg_pool=False` works.

    **References**

    1. Mishra et al. 2017. “A Simple Neural Attentive Meta-Learner.” ICLR 18.
    2. Lee et al. 2019. “Meta-Learning with Differentiable Convex Optimization.” CVPR 19.
    3. Lee et al's code: [https://github.com/kjunelee/MetaOptNet/](https://github.com/kjunelee/MetaOptNet/)
    4. Oreshkin et al. 2018. “TADAM: Task Dependent Adaptive Metric for Improved Few-Shot Learning.” NeurIPS 18.

    **Arguments**

    * **output_size** (int) - The dimensionality of the output (eg, number of classes).
    * **hidden_size** (int, *optional*, default=64) - Number of filters in the first residual
        stage; the last stage has 10x (wider) or 8x (not wider) as many channels.
    * **avg_pool** (bool, *optional*, default=True) - Set to False for the 16k-dim embeddings of Lee et al, 2019.
    * **wider** (bool, *optional*, default=True) - True uses (64, 160, 320, 640) filters akin to Lee et al, 2019.
        False uses (64, 128, 256, 512) filters, akin to Oreshkin et al, 2018 (for hidden_size=64).
    * **embedding_dropout** (float, *optional*, default=0.0) - Dropout rate on the flattened embedding layer.
    * **dropblock_dropout** (float, *optional*, default=0.0) - Dropout rate for the residual layers.
    * **dropblock_size** (int, *optional*, default=5) - Size of drop blocks.
    * **channels** (int, *optional*, default=3) - Number of channels of the input images.

    The classifier assumes the backbone's last feature map is 5x5 (e.g. 84x84
    inputs), the size its 5x5 average pool is built for.

    **Example**

    ~~~python
    model = ResNet12(output_size=ways, avg_pool=False)  # 16000-dim embeddings
    ~~~
    """

    def __init__(
        self,
        output_size,
        hidden_size=64,  # filters of the first stage (base width)
        avg_pool=True,  # Set to False for 16000-dim embeddings
        wider=True,  # True mimics MetaOptNet, False mimics TADAM
        embedding_dropout=0.0,  # dropout for embedding
        dropblock_dropout=0.0,  # dropout for residual layers
        dropblock_size=5,
        channels=3,
    ):
        super(ResNet12, self).__init__()
        self.features = ResNet12Backbone(
            hidden_size=hidden_size,
            avg_pool=avg_pool,
            wider=wider,
            embedding_dropout=embedding_dropout,
            dropblock_dropout=dropblock_dropout,
            dropblock_size=dropblock_size,
            channels=channels,
        )
        # Channels of the last residual stage: 10x (wider) or 8x hidden_size.
        scale_factor = 10 if wider else 8
        embedding_size = hidden_size * scale_factor
        if not avg_pool:
            # Without average pooling the (assumed 5x5) final map is flattened
            # as-is, e.g. 640 * 5 * 5 = 16000 features for hidden_size=64.
            embedding_size *= 5 * 5
        self.classifier = torch.nn.Linear(embedding_size, output_size)

    def forward(self, x):
        """Embed ``x`` with the backbone and return the classifier outputs."""
        x = self.features(x)
        x = self.classifier(x)
        return x