Skip to content

Commit

Permalink
adding tests for embedding and similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
golubevtanya committed Oct 27, 2023
1 parent 81cef76 commit a5ecbd1
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
7 changes: 7 additions & 0 deletions src/backend/api/matching_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
import torch
from sklearn.metrics.pairwise import cosine_similarity


Expand Down Expand Up @@ -27,3 +28,9 @@ def get_free_text_match(
return 0

return cosine_similarity(candidate_embeddings, job_embeddings)[0][0]

if __name__ == "__main__":
    # Smoke check: exactly opposite vectors should have cosine similarity -1.
    # The original wrote int(...) == -1, but int() truncates toward zero, so a
    # round-off value like -0.9999997 would truncate to 0 and the check would
    # fail spuriously; compare with a tolerance instead.
    import math

    similarity = get_free_text_match(
        torch.tensor([[1, 0, 0]]),
        torch.tensor([[-1, 0, 0]]),
    )
    print(math.isclose(float(similarity), -1.0, rel_tol=1e-6))
4 changes: 3 additions & 1 deletion src/backend/api/tokenization_n_embedding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from typing import List
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity


MODEL_NAME = "bert-base-uncased"
Expand Down Expand Up @@ -46,3 +45,6 @@ def generate_embeddings(text: str, model_name: str=MODEL_NAME) -> List[List]:
text_embeddings = text_outputs.last_hidden_state.mean(dim=1)
return text_embeddings

if __name__ == "__main__":
    # Smoke check: the first component of the empty string's embedding should
    # match a known reference value for MODEL_NAME ("bert-base-uncased").
    # NOTE(review): the reference value depends on the exact published model
    # weights; exact float equality (the original `==`) is brittle across
    # library/model versions and hardware, so compare with a tolerance.
    import math

    embedding = generate_embeddings("")
    print(math.isclose(embedding[0][0].item(), -0.00922924280166626, rel_tol=1e-5))
35 changes: 35 additions & 0 deletions tests/test_embedding_similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
import sys

import torch

# Make the repository root importable regardless of the directory pytest is
# invoked from.  The original used os.path.realpath('tokenization_n_embedding.py'),
# which resolves a *relative filename against the current working directory*
# (the file does not even have to exist), so the sys.path entry depended on
# where the tests were run.  Anchoring on __file__ pins it to this test file's
# own location: tests/ -> repository root.
current = os.path.dirname(os.path.realpath(__file__))
parent = os.path.dirname(current)
sys.path.append(parent)

from src.backend.api.tokenization_n_embedding import tokenize_text, generate_embeddings
from src.backend.api.matching_algorithm import get_free_text_match

def test_tokenize_text():
    """Tokenizing any string — including the empty one — yields exactly the
    three standard encoding fields."""
    expected_keys = {"input_ids", "token_type_ids", "attention_mask"}
    for text in ("test", ""):
        assert set(tokenize_text(text).keys()) == expected_keys

def test_generate_embeddings():
    """Embedding a text yields one 768-dim float32 vector (BERT-base hidden size).

    Fixes two defects in the original: it was named ``generate_embeddings``,
    so pytest never collected it (no ``test_`` prefix) and the ``def``
    shadowed the imported function — any call inside the body would have
    recursed into itself instead of calling the real implementation.
    """
    import math

    embeddings = generate_embeddings("any text")
    assert len(embeddings) == 1
    assert len(embeddings[0]) == 768

    empty_embeddings = generate_embeddings("")
    assert empty_embeddings.dtype == torch.float32
    # NOTE(review): the reference value depends on the exact published
    # bert-base-uncased weights; exact float equality is brittle, so allow a
    # small relative tolerance.
    assert math.isclose(empty_embeddings[0][0].item(), -0.00922924280166626, rel_tol=1e-5)

def test_get_free_text_match():
    """Cosine-similarity matching: opposite vectors -> -1, identical -> 1,
    orthogonal -> 0.

    Fixes two defects in the original: the function lacked the ``test_``
    prefix so pytest never collected it, and ``int(`` wrapped the *arguments*
    (``get_free_text_match(int(t1, t2))``) rather than the return value —
    ``int(tensor, tensor)`` raises TypeError before the match is computed.
    Float results are compared with a tolerance instead of int truncation.
    """
    import math

    opposite = get_free_text_match(
        torch.tensor([[1, 2, 3]]),
        torch.tensor([[-1, -2, -3]]),
    )
    assert math.isclose(float(opposite), -1.0, rel_tol=1e-6)

    identical = get_free_text_match(
        torch.tensor([[1, 2, 3]]),
        torch.tensor([[1, 2, 3]]),
    )
    assert math.isclose(float(identical), 1.0, rel_tol=1e-6)

    orthogonal = get_free_text_match(
        torch.tensor([[1, 0, 0]]),
        torch.tensor([[0, 1, 0]]),
    )
    assert math.isclose(float(orthogonal), 0.0, abs_tol=1e-9)

0 comments on commit a5ecbd1

Please sign in to comment.