-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
executable file
·105 lines (90 loc) · 3.86 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy as np
import os
import torch
from data.dataloader import DataLoader
import matplotlib.pyplot as plt
from pandas import *
from tqdm import tqdm
from model.models_factory import ModelsFactory
from options.test_options import TestOptions
class Test(object):
    """Evaluate a saved checkpoint on the validation data loader.

    Constructing an instance parses the test options, builds the model,
    loads the checkpoint weights, creates the ``'val'`` data loader and then
    immediately runs :meth:`test`, which prints a confusion matrix and the
    overall accuracy.
    """

    def __init__(self):
        self._opt = TestOptions().parse()
        # NOTE(review): data_dir is stored but not read anywhere in this class.
        self.data_dir = "./dataset/AffectNet/"
        self.model_path = './checkpoints/OAENet_ck+_ckpt.pth'
        self.phase = 'val'
        self.class_num = 7
        # Choose the device before loading so the checkpoint can be mapped
        # onto it; without map_location a checkpoint saved from a GPU cannot
        # be loaded on a CPU-only machine.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = ModelsFactory().get_by_name(self._opt.model_name, self._opt).model
        self.model.load_state_dict(
            torch.load(self.model_path, map_location=self.device))
        self.dataloader = DataLoader(self._opt, train=False).load_data()['val']
        self.model.to(self.device)
        self.test()

    @staticmethod
    def _batch_confusion(labels, preds, class_num):
        """Count (true label, predicted label) pairs for one batch.

        Args:
            labels: 1-D array-like of ground-truth class values.
            preds: 1-D array-like of predicted class values, same length.
            class_num: Number of classes; only values equal to one of
                ``0 .. class_num - 1`` are counted.

        Returns:
            A ``(class_num, class_num)`` float array where entry ``[i, j]``
            is the number of samples with label ``i`` and prediction ``j``.
        """
        labels = np.asarray(labels).reshape(-1)
        preds = np.asarray(preds).reshape(-1)
        classes = np.arange(class_num)
        # Membership masks of shape (batch, class_num); a value matching no
        # class yields an all-False row and therefore contributes nothing.
        is_label = (labels[:, None] == classes).astype(np.int64)
        is_pred = (preds[:, None] == classes).astype(np.int64)
        return (is_label.T @ is_pred).astype(float)

    @staticmethod
    def _row_percentages(matrix):
        """Scale every row of ``matrix`` so that it sums to 100.

        Rows whose total is zero (a class with no samples) are returned as
        all zeros instead of NaN.

        Args:
            matrix: 2-D array-like of counts.

        Returns:
            A new float array of the same shape holding row-wise percentages.
        """
        matrix = np.asarray(matrix, dtype=float)
        row_totals = matrix.sum(axis=1, keepdims=True)
        percentages = np.zeros_like(matrix)
        np.divide(matrix, row_totals, out=percentages, where=row_totals != 0)
        percentages *= 100
        return percentages

    def test(self):
        """Run the model over the data loader and print the metrics.

        Prints the checkpoint path, the dataset size, the number of batches,
        the raw confusion matrix (rows: true label, columns: prediction), the
        same matrix as row-wise percentages rounded to two decimals, and the
        overall accuracy.
        """
        print(self.model_path)
        self.model.eval()
        # Moving the model once is enough; it does not change between batches.
        self.model.to(self.device)
        data_num = len(self.dataloader.dataset)
        print("data_num: ", data_num)

        step_num = 0
        running_corrects = 0
        result_confu = np.zeros((self.class_num, self.class_num))
        # Gradients are never used during evaluation, so skip tracking them.
        with torch.no_grad():
            # mask_seg is yielded by the loader but not consumed by the model
            # call below, so it is not transferred to the device.
            for inputs, w_mask, mask_seg, labels in tqdm(self.dataloader):
                step_num += 1
                inputs = inputs.to(self.device)
                w_mask = w_mask.to(self.device)
                labels = labels.to(self.device)
                outputs = self.model((inputs, w_mask))
                preds = torch.max(outputs, 1)[1]
                running_corrects += int(torch.sum(preds == labels).item())
                # One device-to-host copy per tensor per batch.
                result_confu += self._batch_confusion(
                    labels.cpu().numpy(), preds.cpu().numpy(), self.class_num)

        print("confus_list:" + str(step_num) + '\n')
        print(result_confu)
        print_result_confu = np.around(
            self._row_percentages(result_confu), decimals=2)
        print(print_result_confu)
        # Guard against an empty dataset instead of dividing by zero.
        epoch_acc = running_corrects / data_num if data_num else 0.0
        print('{} Acc: {:.4f}'.format(self.phase, epoch_acc))

    @staticmethod
    def show_batch(inputs, labels):
        """Plot up to the first eight images of a batch and save the figure.

        The images are drawn on a 2 x 4 grid and written to ``image.jpg``;
        the shape of ``inputs``, its first element and each plotted label
        are printed.

        Args:
            inputs: Tensor of images; each image is transposed with
                ``[1, 2, 0]`` before display (assumes channel-first layout
                - TODO confirm).
            labels: Tensor of labels aligned with ``inputs``.
        """
        plt.figure()
        print('input data shape: ', inputs.shape)
        print("inputs data: ", inputs[0])
        input_1 = inputs.detach().cpu().numpy()
        label_1 = labels.detach().cpu().numpy()
        # Batches smaller than eight are plotted as far as they go.
        shown = min(8, len(input_1), len(label_1))
        for i in range(shown):
            plt.subplot(2, 4, i + 1)
            plt.imshow(input_1[i].transpose([1, 2, 0]))
            print(label_1[i])
        plt.savefig('image.jpg')

    @staticmethod
    def precise_visual(confus_list, eva_dict, table_name):
        """Render a 7 x 7 matrix as a labelled table image.

        Args:
            confus_list: 7 x 7 matrix of values; rows and columns are
                labelled with the class names listed below.
            eva_dict: Unused; kept so existing callers keep working.
            table_name: Path the figure is saved to.
        """
        idx = ['Neutral', 'Happy', 'Sad', 'Surprise', 'Fear', 'Disgust', 'Anger']
        df = DataFrame(confus_list, index=idx, columns=idx)
        vals = np.around(df.values, 2)
        fig = plt.figure(1, figsize=(15, 5))
        # Frameless, tick-less axes so that only the table is visible.
        fig.add_subplot(111, frameon=False, xticks=[], yticks=[])
        the_table = plt.table(cellText=vals, rowLabels=df.index, colLabels=df.columns,
                              colWidths=[0.1] * vals.shape[1], loc='center', cellLoc='center')
        the_table.set_fontsize(20)
        the_table.scale(2, 2.2)
        plt.savefig(table_name)
if __name__ == '__main__':
    # Print a 50-dot separator line, then construct Test; its constructor
    # runs the evaluation.
    banner = '.' * 50
    print(banner)
    Test()