forked from LiYingwei/ShapeTextureDebiasedTraining
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaux_bn.py
48 lines (39 loc) · 1.66 KB
/
aux_bn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
The implementation of auxiliary batch normalization.
Proposed by Xie et al. Adversarial Examples Improve Image Recognition. CVPR 2020
"""
from functools import partial
import torch
from torch import nn
class MixBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d with an auxiliary BN branch, selected by ``batch_type``.

    The layer owns two independent normalizations:

    * the inherited ``nn.BatchNorm2d`` state (used for 'clean' inputs), and
    * ``self.aux_bn``, a second ``nn.BatchNorm2d`` built with the same
      arguments (used for 'adv' inputs).

    ``self.batch_type`` chooses how ``forward`` routes the batch:

    * ``'clean'`` (default): the whole batch goes through the main BN.
    * ``'adv'``: the whole batch goes through ``aux_bn``.
    * ``'mix'``: the first ``batch_size // 2`` samples go through the main BN
      and the remaining samples through ``aux_bn``; the outputs are
      concatenated back along dim 0. For an odd batch size the second part
      gets the extra sample. Presumably callers stack clean samples first and
      adversarial samples second -- confirm against the training loop.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine,
                         track_running_stats)
        # Second BN with its own running statistics (and, when affine, its
        # own weight/bias), configured with exactly the same arguments.
        self.aux_bn = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum,
                                     affine=affine,
                                     track_running_stats=track_running_stats)
        # Routing mode: 'clean', 'adv' or 'mix' (switched via to_status).
        self.batch_type = 'clean'

    def forward(self, input):
        """Normalize ``input`` with the BN(s) selected by ``self.batch_type``.

        Args:
            input: batch tensor with the batch along dim 0 (nn.BatchNorm2d
                expects (N, C, H, W)).

        Returns:
            Normalized tensor with the same shape as ``input``.

        Raises:
            ValueError: if ``self.batch_type`` is not 'clean', 'adv' or 'mix'.
        """
        if self.batch_type == 'adv':
            return self.aux_bn(input)
        if self.batch_type == 'clean':
            return super().forward(input)
        if self.batch_type != 'mix':
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would silently treat a bad mode as 'mix'.
            raise ValueError(
                "batch_type must be 'clean', 'adv' or 'mix', "
                f"got {self.batch_type!r}")
        split = input.shape[0] // 2
        # Leading part -> main BN, trailing part -> auxiliary BN; each part
        # is normalized with its own batch statistics.
        clean_out = super().forward(input[:split])
        adv_out = self.aux_bn(input[split:])
        return torch.cat((clean_out, adv_out), 0)
def to_status(m, status):
    """Switch a mixed batch-norm layer to the given routing mode.

    If ``m`` exposes a ``batch_type`` attribute, it is overwritten with
    ``status``; any other object is left untouched. Nothing is returned.
    ``status`` is expected to be one of 'clean', 'adv' or 'mix' (it is not
    validated here). The single-module signature suggests use with
    ``nn.Module.apply`` -- presumably ``model.apply(to_clean_status)``;
    confirm with callers.
    """
    if not hasattr(m, 'batch_type'):
        return
    setattr(m, 'batch_type', status)
# One-argument callbacks that each pin `status` of to_status to a fixed mode:
# to_clean_status -> 'clean', to_adv_status -> 'adv', to_mix_status -> 'mix'.
to_clean_status, to_adv_status, to_mix_status = (
    partial(to_status, status=mode) for mode in ('clean', 'adv', 'mix')
)