
Commit a02fd7a

Improve test_lfw_vox_celeb_models.py (#73)

mamaheux authored Feb 28, 2023
1 parent 20732f7 commit a02fd7a
Showing 3 changed files with 60 additions and 23 deletions.
```diff
@@ -13,7 +13,7 @@
 
 class AudioDescriptorEvaluation(RocDistancesThresholdsEvaluation):
     def __init__(self, model, device, transforms, dataset_root, output_path):
-        super(AudioDescriptorEvaluation, self).__init__(output_path, thresholds=np.arange(0, 2, 0.001))
+        super(AudioDescriptorEvaluation, self).__init__(output_path, thresholds=np.arange(0, 2, 0.00001))
 
         self._model = model
         self._device = device
```
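The only change here is the threshold step, from 0.001 to 0.00001. `RocDistancesThresholdsEvaluation` sweeps every threshold in this array when it builds the ROC curve and picks the best pair-classification accuracy, so a finer step reduces quantization in the reported best threshold. A minimal standalone sketch of that sweep, assuming a hypothetical helper and toy data (the base class internals are not shown in this diff):

```python
import numpy as np

def sweep_thresholds(distances, is_same, thresholds):
    # Predict 'same person' when the descriptor distance is below the
    # threshold, then keep the threshold with the best pair accuracy.
    predictions = distances[None, :] < thresholds[:, None]        # shape (T, N)
    accuracies = (predictions == is_same[None, :]).mean(axis=1)   # shape (T,)
    best = accuracies.argmax()
    return accuracies[best], thresholds[best]

# Hypothetical distances and same-person labels for four pairs.
distances = np.array([0.3, 1.2, 0.4, 1.6])
is_same = np.array([True, False, True, False])

# A 0.00001 step gives 200,000 candidate thresholds in [0, 2) instead of
# 2,000 with the old 0.001 step, so the best threshold is less quantized.
best_accuracy, best_threshold = sweep_thresholds(distances, is_same, np.arange(0, 2, 0.00001))
print(best_accuracy, best_threshold)
```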
12 changes: 6 additions & 6 deletions tools/dnn_training/common/metrics/roc_evaluation.py
```diff
@@ -38,7 +38,7 @@ def _calculate_eer(self, true_positive_rate_curve, false_positive_rate_curve):
 
         return (false_negative_rate_curve[index] + false_positive_rate_curve[index]) / 2
 
-    def _save_roc_curve(self, true_positive_rate_curve, false_positive_rate_curve):
+    def _save_roc_curve(self, true_positive_rate_curve, false_positive_rate_curve, prefix=''):
         fig = plt.figure(figsize=(5, 5), dpi=300)
         ax1 = fig.add_subplot(111)
 
```

```diff
@@ -48,20 +48,20 @@ def _save_roc_curve(self, true_positive_rate_curve, false_positive_rate_curve):
         ax1.set_xlabel(u'False positive rate')
         ax1.set_ylabel(u'True positive rate')
 
-        fig.savefig(os.path.join(self._output_path, 'roc_curve.png'))
+        fig.savefig(os.path.join(self._output_path, prefix + 'roc_curve.png'))
         plt.close(fig)
 
-    def _save_roc_curve_data(self, true_positive_rate_curve, false_positive_rate_curve, thresholds):
-        with open(os.path.join(self._output_path, 'roc_curve.json'), 'w') as file:
+    def _save_roc_curve_data(self, true_positive_rate_curve, false_positive_rate_curve, thresholds, prefix=''):
+        with open(os.path.join(self._output_path, prefix + 'roc_curve.json'), 'w') as file:
             data = {
                 'true_positive_rate_curve': true_positive_rate_curve.tolist(),
                 'false_positive_rate_curve': false_positive_rate_curve.tolist(),
                 'thresholds': thresholds.tolist()
             }
             json.dump(data, file, indent=4, sort_keys=True)
 
-    def _save_performances(self, values_by_name):
-        with open(os.path.join(self._output_path, 'performance.json'), 'w') as file:
+    def _save_performances(self, values_by_name, prefix=''):
+        with open(os.path.join(self._output_path, prefix + 'performance.json'), 'w') as file:
             json.dump(values_by_name, file, indent=4, sort_keys=True)
 
 
```
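The new `prefix` parameter lets a single evaluation run write separate `roc_curve.png`, `roc_curve.json`, and `performance.json` files per modality instead of overwriting one set. A self-contained sketch of the same pattern (the directory name and metric values are hypothetical):

```python
import json
import os

def save_performances(output_path, values_by_name, prefix=''):
    # Same pattern as the new _save_performances: the prefix is prepended
    # to the file name, so successive calls do not overwrite each other.
    with open(os.path.join(output_path, prefix + 'performance.json'), 'w') as file:
        json.dump(values_by_name, file, indent=4, sort_keys=True)

os.makedirs('out', exist_ok=True)
save_performances('out', {'auc': 0.98}, prefix='face_')        # out/face_performance.json
save_performances('out', {'auc': 0.91}, prefix='voice_')       # out/voice_performance.json
save_performances('out', {'auc': 0.99}, prefix='face_voice_')  # out/face_voice_performance.json
```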
69 changes: 53 additions & 16 deletions tools/dnn_training/test_lfw_vox_celeb_models.py
```diff
@@ -21,10 +21,10 @@
 
 
 class LfwVoxCelebEvaluation(RocDistancesThresholdsEvaluation):
-    def __init__(self, lfw_dataset_root, vox_celeb_dataset_root, pairs_file, output_path,
+    def __init__(self, device, lfw_dataset_root, vox_celeb_dataset_root, pairs_file, output_path,
                  face_model, face_transforms, voice_model, voice_transforms):
-        super().__init__(output_path, thresholds=np.arange(0, 10, 0.0001))
-
+        super().__init__(output_path, thresholds=np.arange(0, 10, 0.00001))
+        self._device = device
         self._lfw_dataset_root = lfw_dataset_root
         self._vox_celeb_dataset_root = vox_celeb_dataset_root
         self._pairs_file = pairs_file
```
```diff
@@ -76,25 +76,58 @@ def _read_pairs(self):
 
         return pairs
 
+    def evaluate(self):
+        print('Calculate distances')
+        face_distances, voice_distances, face_voice_distances = self._calculate_distances()
+        is_same_person_target = self._get_is_same_person_target()
+
+        self._evaluate(face_distances, is_same_person_target, 'face_')
+        self._evaluate(voice_distances, is_same_person_target, 'voice_')
+        self._evaluate(face_voice_distances, is_same_person_target, 'face_voice_')
+
+    def _evaluate(self, distances, is_same_person_target, prefix):
+        best_accuracy, best_threshold, true_positive_rate_curve, false_positive_rate_curve, thresholds = \
+            self._calculate_accuracy_true_positive_rate_false_positive_rate(distances, is_same_person_target)
+        auc = self._calculate_auc(true_positive_rate_curve, false_positive_rate_curve)
+        eer = self._calculate_eer(true_positive_rate_curve, false_positive_rate_curve)
+
+        print(prefix)
+        print('Best accuracy: {}, threshold: {}, AUC: {}, EER: {}'.format(best_accuracy, best_threshold, auc, eer))
+        print()
+
+        self._save_roc_curve(true_positive_rate_curve, false_positive_rate_curve, prefix=prefix)
+        self._save_roc_curve_data(true_positive_rate_curve, false_positive_rate_curve, thresholds, prefix=prefix)
+        self._save_performances({
+            'best_accuracy': best_accuracy,
+            'best_threshold': best_threshold,
+            'auc': auc,
+            'eer': eer
+        }, prefix=prefix)
+
     def _calculate_distances(self):
-        distances = []
+        face_distances = []
+        voice_distances = []
+        face_voice_distances = []
 
         for voice_path_0, voice_path_1, face_path_0, face_path_1, _ in tqdm(self._pairs):
-            voice_sound_0 = self._load_voice_sound(voice_path_0)
-            voice_sound_1 = self._load_voice_sound(voice_path_1)
-            face_image_0 = self._load_face_image(face_path_0)
-            face_image_1 = self._load_face_image(face_path_1)
+            voice_sound_0 = self._load_voice_sound(voice_path_0).to(self._device)
+            voice_sound_1 = self._load_voice_sound(voice_path_1).to(self._device)
+            face_image_0 = self._load_face_image(face_path_0).to(self._device)
+            face_image_1 = self._load_face_image(face_path_1).to(self._device)
 
             voice_descriptor_0 = self._voice_model(voice_sound_0)
             voice_descriptor_1 = self._voice_model(voice_sound_1)
             face_descriptors = self._face_model(torch.stack((face_image_0, face_image_1)))
 
-            descriptor_0 = torch.cat((voice_descriptor_0[0], face_descriptors[0]))
-            descriptor_1 = torch.cat((voice_descriptor_1[0], face_descriptors[1]))
-            distance = torch.dist(descriptor_0, descriptor_1, p=2).item()
-            distances.append(distance)
+            face_distance = torch.dist(face_descriptors[0], face_descriptors[1], p=2).item()
+            voice_distance = torch.dist(voice_descriptor_0[0], voice_descriptor_1[0], p=2).item()
+            face_voice_distance = torch.dist(torch.cat((voice_descriptor_0[0], face_descriptors[0])),
+                                             torch.cat((voice_descriptor_1[0], face_descriptors[1])), p=2).item()
+            face_distances.append(face_distance)
+            voice_distances.append(voice_distance)
+            face_voice_distances.append(face_voice_distance)
 
-        return torch.tensor(distances)
+        return torch.tensor(face_distances), torch.tensor(voice_distances), torch.tensor(face_voice_distances)
 
     def _load_voice_sound(self, path):
         waveform, sample_rate = torchaudio.load(path)
```
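The face-plus-voice distance above is the L2 distance between concatenated descriptors, which decomposes into the per-modality distances: d_face_voice = sqrt(d_face² + d_voice²). A small sketch verifying that identity (the descriptor sizes are hypothetical):

```python
import torch

# Hypothetical face (128-d) and voice (256-d) descriptors for one pair.
face_0, face_1 = torch.randn(128), torch.randn(128)
voice_0, voice_1 = torch.randn(256), torch.randn(256)

d_face = torch.dist(face_0, face_1, p=2)
d_voice = torch.dist(voice_0, voice_1, p=2)
d_fused = torch.dist(torch.cat((voice_0, face_0)),
                     torch.cat((voice_1, face_1)), p=2)

# Concatenation before the L2 distance combines the per-modality distances.
assert torch.allclose(d_fused, torch.sqrt(d_face ** 2 + d_voice ** 2))
```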
```diff
@@ -122,6 +155,7 @@ def _get_is_same_person_target(self):
 
 def main():
     parser = argparse.ArgumentParser(description='Test exported face descriptor extractor')
+    parser.add_argument('--use_gpu', action='store_true')
 
     parser.add_argument('--lfw_dataset_root', type=str, help='Choose the lfw dataset root path', required=True)
     parser.add_argument('--vox_celeb_dataset_root', type=str, help='Choose the vox celeb dataset root path',
```
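The new flag is a plain argparse boolean: `action='store_true'` makes `--use_gpu` default to False and flip to True only when the flag is passed. A tiny sketch of that behavior (the parser description is hypothetical):

```python
import argparse

parser = argparse.ArgumentParser(description='Flag demo')
parser.add_argument('--use_gpu', action='store_true')  # False unless the flag is given

assert parser.parse_args([]).use_gpu is False
assert parser.parse_args(['--use_gpu']).use_gpu is True
```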
```diff
@@ -152,21 +186,24 @@ def main():
 
     args = parser.parse_args()
 
-    face_model = create_face_model(args.face_embedding_size)
+    device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu')
+
+    face_model = create_face_model(args.face_embedding_size).to(device)
     load_checkpoint(face_model, args.face_model_checkpoint, keys_to_remove=['_classifier._weight'])
     face_model.eval()
     face_transforms = create_validation_image_transform()
 
     voice_model = create_voice_model(args.voice_backbone_type, args.voice_n_features, args.voice_embedding_size,
-                                     pooling_layer=args.voice_pooling_layer)
+                                     pooling_layer=args.voice_pooling_layer).to(device)
     load_checkpoint(voice_model, args.voice_model_checkpoint, keys_to_remove=['_classifier._weight'])
     voice_model.eval()
     voice_transforms = AudioDescriptorTestTransforms(waveform_size=args.voice_waveform_size,
                                                      n_features=args.voice_n_features,
                                                      n_fft=args.voice_n_fft,
                                                      audio_transform_type=args.voice_audio_transform_type)
 
-    evaluation = LfwVoxCelebEvaluation(args.lfw_dataset_root,
+    evaluation = LfwVoxCelebEvaluation(device,
+                                       args.lfw_dataset_root,
                                        args.vox_celeb_dataset_root,
                                        args.pairs_file,
                                        args.output_path,
```
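With the models moved via `.to(device)` here and the inputs moved in `_calculate_distances`, every tensor ends up on the same device; mixing devices would raise a runtime error in PyTorch. A minimal sketch of the same guarded device selection (the tiny model is hypothetical):

```python
import torch

def select_device(use_gpu):
    # Mirrors main(): request CUDA only when the flag is set AND a GPU
    # exists, so --use_gpu is harmless on CPU-only machines.
    return torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')

device = select_device(use_gpu=True)
model = torch.nn.Linear(8, 4).to(device)  # hypothetical model
x = torch.randn(1, 8).to(device)          # inputs must live on the same device
with torch.no_grad():
    y = model(x)
print(y.shape, device)
```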
