-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation.py
47 lines (37 loc) · 1.8 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn.functional as F
def rgb_to_y(x):
    """Convert an RGB image batch to a single luma (Y) channel.

    The three channel weights sum to ~0.8588, so a full-scale input of 255
    maps to 235 after the +16 offset; this matches the ITU-R BT.601
    "studio swing" luma convention (presumably the intended one -- confirm
    against the caller's colour pipeline).

    Args:
        x: Tensor with the colour channels on dim 1 (expected layout
            ``(N, 3, H, W)``), with values assumed to lie in ``[0, 255]``.
            Integer tensors are accepted and promoted to the default
            floating dtype.

    Returns:
        Tensor with dim 1 reduced to size 1 (``(N, 1, H, W)`` for the
        expected layout) holding the luma values.
    """
    # The weights are fractional: building them in an integer dtype (e.g. for
    # a uint8 image) would destroy them, so promote integer inputs first.
    if not torch.is_floating_point(x):
        x = x.to(torch.get_default_dtype())
    rgb_to_grey = torch.tensor(
        [0.256789, 0.504129, 0.097906], dtype=x.dtype, device=x.device
    ).view(1, -1, 1, 1)
    # Weighted sum over the channel dim, then shift by the black-level offset.
    return torch.sum(x * rgb_to_grey, dim=1, keepdim=True).add(16.0)
def psnr(x, y, data_range=255.0):
    """Return the peak signal-to-noise ratio between ``x`` and ``y`` in dB.

    Both inputs are divided by ``data_range`` so the peak value is 1, and the
    squared error is averaged over every element of the tensors (all dims at
    once), giving a single scalar tensor.

    Args:
        x: First tensor.
        y: Second tensor, broadcast-compatible with ``x``.
        data_range: Value span of the inputs (255.0 by default).

    Returns:
        Scalar tensor ``-10 * log10(MSE)``; it is ``inf`` when the inputs
        are identical because the MSE is then zero.
    """
    residual = x / data_range - y / data_range
    mean_sq_err = residual.pow(2).mean()
    # With the peak normalised to 1 the usual 10*log10(peak^2 / mse)
    # collapses to -10*log10(mse).
    return torch.log10(mean_sq_err).mul(-10)
def ssim(x, y, kernel_size=11, kernel_sigma=1.5, data_range=255.0, k1=0.01, k2=0.03):
    """Return the mean structural similarity (SSIM) index of ``x`` and ``y``.

    Inputs are scaled by ``data_range``, optionally average-pooled when the
    image is large, and compared through Gaussian-weighted local statistics
    computed independently per channel.

    Args:
        x: Image tensor with channels on dim 1 (expected ``(N, C, H, W)``).
        y: Image tensor of the same shape as ``x``.
        kernel_size: Side length of the square Gaussian window.
        kernel_sigma: Standard deviation of the Gaussian window.
        data_range: Value span of the inputs (255.0 by default).
        k1: Stabilising constant for the luminance term.
        k2: Stabilising constant for the contrast/structure term.

    Returns:
        Scalar tensor: the SSIM map averaged over batch, channels and all
        valid (un-padded) window positions.
    """
    x = x / data_range
    y = y / data_range

    # Average-pool the images when the shorter spatial side is large enough
    # that its ratio to 256 rounds to more than 1.
    pool = max(1, round(min(x.size()[-2:]) / 256))
    if pool > 1:
        x = F.avg_pool2d(x, kernel_size=pool)
        y = F.avg_pool2d(y, kernel_size=pool)

    n_channels = x.size(1)

    # Normalised 2-D Gaussian window, centred on the middle of the kernel.
    offsets = torch.arange(kernel_size, dtype=x.dtype, device=x.device)
    offsets -= (kernel_size - 1) / 2.0
    sq_dist = offsets ** 2
    window = torch.exp(
        -(sq_dist.unsqueeze(0) + sq_dist.unsqueeze(1)) / (2 * kernel_sigma ** 2)
    )
    window /= window.sum()
    # One identical filter per channel for the grouped (depthwise) convolution.
    kernel = window.unsqueeze(0).repeat(n_channels, 1, 1, 1)

    def blur(t):
        # Gaussian-weighted local mean, valid positions only (no padding).
        return F.conv2d(t, weight=kernel, stride=1, padding=0, groups=n_channels)

    c1 = k1 ** 2
    c2 = k2 ** 2

    # Local first moments.
    mu_x = blur(x)
    mu_y = blur(y)
    mu_xx = mu_x ** 2
    mu_yy = mu_y ** 2
    mu_xy = mu_x * mu_y

    # Local (co)variances via E[ab] - E[a]E[b].
    sigma_xx = blur(x ** 2) - mu_xx
    sigma_yy = blur(y ** 2) - mu_yy
    sigma_xy = blur(x * y) - mu_xy

    # Contrast sensitivity (CS) with alpha = beta = gamma = 1.
    cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
    # Structural similarity (SSIM) = luminance term * CS.
    luminance = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1)
    ssim_map = luminance * cs
    return ssim_map.mean()