-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidation.py
50 lines (39 loc) · 1.38 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
This module validates and calculates the accuracy of a model on MNIST validation data.
"""
import time
import torch
from tqdm import tqdm
from data import load_data
def validate(model, device, n_total=2000):
    """
    Validate the model on the validation data.

    Images are fed one at a time (the loaders are built with
    ``batch_size=1``) and at most ``n_total`` images are evaluated; fewer are
    used if the validation split is smaller.

    :param model: Model to validate. It is switched to eval mode and moved
        to ``device`` in place.
    :param device: cuda or cpu.
    :param n_total: the maximum number of images to validate.
    :return: Tuple of (accuracy, mean cross-entropy loss per image,
        average wall-clock time per image in ms). The time covers the whole
        evaluation loop (data loading, forward pass, metric bookkeeping),
        not only the forward pass.
    :raises ValueError: if no image was evaluated (``n_total <= 0`` or an
        empty validation split).
    """
    dataloaders = load_data(batch_size=1, num_workers=0)
    model.eval()
    model.to(device)
    # Sum per-sample losses so that dividing by the sample count yields a
    # true per-sample mean for any batch size; built once, not per image.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    correct = 0
    total = 0
    running_loss = 0.0
    start_time = time.perf_counter()
    with torch.no_grad():
        for batch in tqdm(dataloaders['val'], total=n_total,
                          desc='Validating model', unit=' image'):
            # Stop before processing once the budget is reached, so exactly
            # n_total images are evaluated (a post-increment `>` check would
            # evaluate n_total + 1).
            if total >= n_total:
                break
            images, labels = batch[0].to(device), batch[1].to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += criterion(outputs, labels).item()
    elapsed_time = time.perf_counter() - start_time
    if total == 0:
        raise ValueError(
            'No validation images were evaluated; '
            'check n_total and the validation data.'
        )
    accuracy = correct / total
    loss = running_loss / total
    avg_inference_time = elapsed_time / total
    return accuracy, loss, avg_inference_time * 1000