-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlab2.py
53 lines (39 loc) · 1.33 KB
/
lab2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn as nn
class CosineConv2D(nn.Module):
    """2-D convolution whose output is the cosine similarity between each
    sliding input patch and each filter, rather than the raw dot product.

    For every window the output is
        dot(w, patch) / (||w||_2 * ||patch||_2)
    so values lie in [-1, 1]. The patch norm is obtained by convolving the
    squared input with a fixed all-ones kernel (sum of squares per window).

    Args:
        in_ch: number of input channels.
        out_ch: number of output filters.
        k: kernel size (square kernel, int).
        stride, padding, dilation, groups: forwarded to nn.Conv2d.
            NOTE(review): groups > 1 looks unsupported — conv1 has
            out_channels=1, which must be divisible by groups; inherited
            from the original design, confirm before relying on it.
        eps: small constant added under the patch-norm sqrt so an all-zero
            patch cannot produce a zero denominator (NaN output) or a NaN
            gradient through sqrt(0). Default 1e-12 is numerically
            negligible for non-zero patches.
    """

    def __init__(self, in_ch, out_ch, k,
                 stride=1, padding=0, dilation=1, groups=1, eps=1e-12):
        super().__init__()
        self.eps = eps
        # Learnable filters; bias is omitted because a bias term would
        # break the cosine interpretation.
        self.conv = nn.Conv2d(in_ch, out_ch, k, stride, padding,
                              dilation, groups, bias=False)
        # Fixed (non-trainable) all-ones kernel: convolving the squared
        # input with it yields the per-window sum of squares.
        self.conv1 = nn.Conv2d(in_ch, 1, k, stride, padding,
                               dilation, groups, bias=False)
        self.conv1.weight = nn.Parameter(torch.ones((1, in_ch, k, k)))
        self.conv1.weight.requires_grad = False

    def get_wnorm(self):
        """Return the L2 norm of each filter, shape [Co]."""
        # conv.weight: [Co, Ci, k, k] -> flatten per filter -> norm over dim 1
        out_ch = self.conv.weight.shape[0]
        flat = self.conv.weight.view(out_ch, -1)
        return torch.norm(flat, p=2, dim=1)

    def get_xnorm(self, x):
        """Return the L2 norm of each input patch, shape [B, 1, Ho, Wo]."""
        # sqrt(sum of squares + eps); eps keeps the denominator non-zero
        # and the sqrt gradient finite for all-zero patches.
        return torch.sqrt(self.conv1(torch.square(x)) + self.eps)

    def forward(self, x):
        # [B, Co, Ho, Wo] raw dot products between filters and patches
        dot = self.conv(x)
        # [Co] -> [1, Co, 1, 1] so the filter norms broadcast over B, Ho, Wo
        w_norm = self.get_wnorm().view(1, -1, 1, 1)
        # [B, 1, Ho, Wo] patch norms broadcast over Co
        x_norm = self.get_xnorm(x)
        return dot / (w_norm * x_norm)
# Smoke test: guard it so importing this module has no side effects.
if __name__ == "__main__":
    # 3 input channels, 5 cosine filters, 3x3 kernel.
    conv = CosineConv2D(3, 5, 3)
    x = torch.ones((1, 3, 5, 5))
    t = conv(x)
    # No padding: a 3x3 kernel over a 5x5 input leaves 3x3 spatial output.
    print(t.shape)  # expected: torch.Size([1, 5, 3, 3])