-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
66 lines (59 loc) · 2.32 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
'''
Test one trained model: run it on the test set and
store the predictions in a CSV file.
Usage: python test.py --model xxx --file-name xxx.pkl
'''
import argparse
import os
import torch
import timm
from torch.utils.data import DataLoader
from dataset import dataset
import pandas as pd
import torch.nn.functional as F
def parse_args(argv=None):
    """Parse the command-line options of the test script.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model``, ``gpu``, ``batch_size``,
        ``load_dir``, ``file_name`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='resnet18', help='model')
    parser.add_argument('--gpu', default=0, type=int, help='gpu')
    parser.add_argument('--batch-size', default=32, type=int, help='batch-size')
    parser.add_argument('--load-dir', default='checkpoints', type=str, help='where to load model')
    parser.add_argument('--file-name', default='kfold_0.pkl', type=str, help='file name')
    parser.add_argument('--save-dir', default='results', type=str, help='where to save csv file')
    return parser.parse_args(argv)


def build_model(name, device, num_classes=3):
    """Create timm backbone ``name`` with a ``num_classes``-way head on ``device``.

    Any timm model name is accepted (e.g. 'resnet18', 'resnet50d').
    ``pretrained=False`` because every weight is replaced by the strict
    ``load_state_dict`` in ``load_checkpoint``; downloading ImageNet
    weights first would be wasted work and fails without network access.
    """
    return timm.create_model(name, pretrained=False, num_classes=num_classes).to(device)


def load_checkpoint(net, loadpath):
    """Load the weights stored at ``loadpath`` into ``net`` and switch it to eval mode.

    The checkpoint must be a dict with 'state' (a state_dict), 'kappa' and
    'epoch' keys; ``epoch == -1`` is reported as a cross-validation result.

    Returns:
        ``net`` (modified in place).
    """
    # map_location='cpu' lets a checkpoint saved on any GPU (or on CPU) be
    # read here; load_state_dict then copies the tensors onto net's device.
    # NOTE(review): torch>=2.6 defaults to weights_only=True; if 'kappa' was
    # saved as a NumPy scalar this load will fail - confirm checkpoint contents.
    ckpt = torch.load(loadpath, map_location='cpu')
    state_dict = ckpt['state']
    kappa = ckpt['kappa']
    epoch = ckpt['epoch']
    if epoch == -1:
        print(f'Loading model from {loadpath} which reaches kappa {kappa} in cross validation')
    else:
        print(f'Loading model from {loadpath} which reaches kappa {kappa} at epoch {epoch}...')
    net.load_state_dict(state_dict)
    net.eval()
    return net


def result_csv_path(save_dir, model, file_name):
    """Return ``save_dir/model/<stem of file_name>.csv``.

    ``os.path.splitext`` handles any checkpoint extension ('.pkl', '.pt',
    '.pth' or none) instead of assuming a 3-character one.
    """
    stem = os.path.splitext(file_name)[0]
    return os.path.join(save_dir, model, stem + '.csv')


def batch_to_frame(names, logits):
    """Convert one batch of class scores into rows of the result CSV.

    Args:
        names: per-sample case identifiers, as yielded by the DataLoader.
        logits: CPU tensor of raw scores, shape (batch, 3), without grad.

    Returns:
        DataFrame with columns case, class (argmax of logits) and
        P0/P1/P2 (softmax probabilities).
    """
    probs = F.softmax(logits, dim=1)
    return pd.DataFrame({
        'case': names,
        'class': logits.argmax(dim=1).numpy(),
        'P0': probs[:, 0].numpy(),
        'P1': probs[:, 1].numpy(),
        'P2': probs[:, 2].numpy(),
    })


def run_inference(net, loader, device):
    """Run ``net`` over every batch of ``loader`` and collect the predictions.

    ``loader`` must yield ``(image, label, name)`` triples; the label is
    ignored. Returns one DataFrame (header-only if the loader is empty).
    """
    frames = []
    with torch.no_grad():
        for img, _label, name in loader:
            # non_blocking pairs with pin_memory=True on the DataLoader.
            logits = net(img.to(device, non_blocking=True)).cpu()
            frames.append(batch_to_frame(name, logits))
    if not frames:
        return pd.DataFrame(columns=['case', 'class', 'P0', 'P1', 'P2'])
    return pd.concat(frames, ignore_index=True)


def main(argv=None):
    """Evaluate one checkpoint on the test set and write predictions to CSV."""
    args = parse_args(argv)
    device = torch.device('cuda', args.gpu)

    net = build_model(args.model, device)
    loadpath = os.path.join(args.load_dir, args.model, args.file_name)
    load_checkpoint(net, loadpath)

    testset = dataset(test=True)
    testloader = DataLoader(testset, shuffle=False, batch_size=args.batch_size,
                            num_workers=4, pin_memory=True)

    savepath = result_csv_path(args.save_dir, args.model, args.file_name)
    out_dir = os.path.dirname(savepath)
    if out_dir:
        # to_csv does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)

    print('testing...')
    results = run_inference(net, testloader, device)
    results.to_csv(savepath, header=True, index=False)
    print(f'Saving results to {savepath}')


# The guard is required: DataLoader workers (num_workers > 0) re-import this
# module under the spawn start method, which must not re-run the script.
if __name__ == '__main__':
    main()