-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathFrame_test.py
65 lines (52 loc) · 2.51 KB
/
Frame_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from components import networks3D
from components.performance_metric import mean_absolute_error, peak_signal_to_noise_ratio, structural_similarity_index
from options.test_options import TestOptions
from utils.UnpairedDataset import UnpairedDataset
def evaluate_generator(generator, test_loader, netG):
    """Evaluate a generator on a test set.

    The generator is put in eval mode and run without gradients. For every batch
    of ``test_loader`` it synthesises a PET volume from the MRI input. MAE, PSNR
    and SSIM are then computed against the real PET volume, and the per-batch
    scores are averaged.

    Parameters:
        generator - - : (nn.Module) neural network generating PET images from MRI
        test_loader - - : (DataLoader) the testing loader; ``batch[0]`` is used as
                          the MRI input and ``batch[1]`` as the real PET target
        netG - - : (str) generator architecture name; when it is ``'ShareSynNet'``
                   the generator is called with ``alpha=0.``
    Returns:
        df - - : (DataFrame) a single column ``'Test set'`` indexed by
                 ``'MAE'``, ``'PSNR'``, ``'SSIM'`` holding the mean of each metric
                 over all batches (NaN if ``test_loader`` yields no batches)
    """
    columns = ['MAE', 'PSNR', 'SSIM']
    # Inputs go to the (current) GPU when one is available and are cast to
    # float32; this matches the legacy `.type(torch.cuda.FloatTensor)` idiom.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    res_test = []

    generator.eval()
    with torch.no_grad():
        for batch in tqdm(test_loader):
            # Inputs MRI and PET
            real_mri = batch[0].to(device=device, dtype=torch.float32)
            real_pet = batch[1].to(device=device, dtype=torch.float32)
            if netG == 'ShareSynNet':
                # This architecture takes an extra `alpha` argument; evaluation
                # always uses alpha=0.
                fake_pet = generator(real_mri, alpha=0.)
            else:
                fake_pet = generator(real_mri)
            # `.item()` requires each metric helper to return a single-element
            # tensor, i.e. one score per batch.
            res_test.append([
                mean_absolute_error(real_pet, fake_pet).item(),
                peak_signal_to_noise_ratio(real_pet, fake_pet).item(),
                structural_similarity_index(real_pet, fake_pet).item(),
            ])

    # The explicit float dtype keeps the means numeric (NaN) even when no
    # batches were produced.
    means = pd.DataFrame(res_test, columns=columns, dtype=float).mean()
    return means.to_frame(name='Test set')
if __name__ == '__main__':
    opt = TestOptions().parse()

    # Build the generator from the command-line architecture options.
    netG = networks3D.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                               not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)
    print('netG:\ttype:{}\tload_path:{}'.format(opt.netG, opt.load_path))

    # NOTE(review): data_list=['0', '1'] is passed straight through to
    # UnpairedDataset; its meaning is defined there — confirm.
    test_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="test",
                               load_size=opt.load_size, crop_size=opt.crop_size)
    # One volume per batch. No shuffling, so samples are evaluated in a fixed,
    # reproducible order; the averaged metrics do not depend on order anyway.
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=opt.workers,
                             pin_memory=True)
    print('length test list:', len(test_set))

    # map_location='cpu' lets a checkpoint saved on GPU be opened on a CPU-only
    # machine; load_state_dict then copies the weights into netG's own parameters.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(opt.load_path, map_location='cpu')
    netG.load_state_dict(state_dict)

    test_df = evaluate_generator(netG, test_loader, opt.netG)
    print(test_df)