-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
65 lines (53 loc) · 1.76 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import hydra
import numpy as np
import torch
from utils.utils import (
OneToManyMetrics,
OneToOneMetrics,
calculate_metrics,
dataset_setup,
model_setup,
one_to_many_comparison,
one_to_one_comparison,
)
from utils.utils_plots import plot_results
def _build_metrics(metrics_cls, thresholds, impostor_distances, genuine_distances):
    """Build a metrics container for one comparison mode.

    The values returned by ``calculate_metrics`` are passed positionally to
    ``metrics_cls``, followed by the raw impostor and genuine distances.
    The one-to-one and one-to-many results are both built this way.

    Args:
        metrics_cls: Container class to instantiate (``OneToOneMetrics`` or
            ``OneToManyMetrics``).
        thresholds: Decision thresholds forwarded to ``calculate_metrics``.
        impostor_distances: Impostor-pair distances for this comparison mode.
        genuine_distances: Genuine-pair distances for this comparison mode.

    Returns:
        An instance of ``metrics_cls``.
    """
    return metrics_cls(
        *calculate_metrics(thresholds, impostor_distances, genuine_distances),
        impostor_distances,
        genuine_distances,
    )


@hydra.main(version_base="1.1", config_path="config", config_name="config_test")
def main(cfg):
    """Evaluate a saved model on the test split and plot verification metrics.

    Hydra composes ``cfg`` from ``config/config_test``. This function reads
    ``cfg.save_path``, which is handed to ``model_setup``, and ``cfg.data_dir``,
    which is handed to ``dataset_setup``.

    Args:
        cfg: Configuration object injected by the ``@hydra.main`` decorator.
    """
    # Use the first CUDA device when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = model_setup(device, cfg.save_path)
    # eval() only switches layers such as dropout/batch-norm to inference
    # behaviour; it does NOT disable gradient tracking.
    model.eval()

    # Dataset setup receives the model and device, so it may run forward
    # passes. Disable autograd for the whole inference stage to avoid
    # building computation graphs that are never used.
    with torch.no_grad():
        test_dataloader = dataset_setup(
            cfg.data_dir, model=model, device=device, train=False
        )
        # Each comparison returns (genuine_distances, impostor_distances).
        one_to_many_genuine, one_to_many_impostor = one_to_many_comparison(
            test_dataloader
        )
        one_to_one_genuine, one_to_one_impostor = one_to_one_comparison(
            test_dataloader
        )

    # 10 log-spaced thresholds from 1e-7 to 1, shared by both comparison modes.
    thresholds = np.logspace(-7, 0, 10)

    one_to_one_metrics = _build_metrics(
        OneToOneMetrics, thresholds, one_to_one_impostor, one_to_one_genuine
    )
    one_to_many_metrics = _build_metrics(
        OneToManyMetrics, thresholds, one_to_many_impostor, one_to_many_genuine
    )

    plot_results(thresholds, one_to_one_metrics, one_to_many_metrics)
if __name__ == "__main__":
    # main() takes `cfg`, but the @hydra.main decorator supplies it,
    # so the entry point is invoked without arguments.
    main()