From 5458301beacbe03a09ed75da8cc74f4ac8f8c537 Mon Sep 17 00:00:00 2001
From: "shengzhe.li"
Date: Tue, 18 Jun 2024 05:37:40 +0900
Subject: [PATCH] Support prefix for retrieval and reranking

---
 src/jmteb/embedders/base.py                 | 21 +++++++++---
 src/jmteb/embedders/sbert_embedder.py       | 37 +++++++++++++++++++--
 src/jmteb/evaluators/reranking/evaluator.py |  9 +++++
 src/jmteb/evaluators/retrieval/evaluator.py | 10 ++++++
 tests/evaluator/test_reranking_evaluator.py | 27 ++++++++++++---
 tests/evaluator/test_retrieval_evaluator.py | 32 +++++++++++++++---
 6 files changed, 121 insertions(+), 15 deletions(-)

diff --git a/src/jmteb/embedders/base.py b/src/jmteb/embedders/base.py
index 145f543..4276d3d 100644
--- a/src/jmteb/embedders/base.py
+++ b/src/jmteb/embedders/base.py
@@ -14,11 +14,12 @@ class TextEmbedder(ABC):
     The base class of text embedder.
     """
 
-    def encode(self, text: str | list[str]) -> np.ndarray:
+    def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray:
         """Convert a text string or a list of texts to embedding.
 
         Args:
             text (str | list[str]): text string, or a list of texts.
+            prompt (str, optional): the prompt to use for encoding. Defaults to None.
         """
         raise NotImplementedError
 
@@ -31,7 +32,12 @@ def get_output_dim(self) -> int:
         raise NotImplementedError
 
     def _batch_encode_and_save_on_disk(
-        self, text_list: list[str], save_path: str | PathLike[str], batch_size: int = 64, dtype: str = "float32"
+        self,
+        text_list: list[str],
+        save_path: str | PathLike[str],
+        prompt: str | None = None,
+        batch_size: int = 64,
+        dtype: str = "float32",
     ) -> np.memmap:
         """
         Encode a list of texts and save the embeddings on disk using memmap.
@@ -39,6 +45,7 @@ def _batch_encode_and_save_on_disk(
         Args:
             text_list (list[str]): list of texts
             save_path (str): path to save the embeddings
+            prompt (str, optional): the prompt to use for encoding. Defaults to None.
             dtype (str, optional): data type. Defaults to "float32".
             batch_size (int): batch size. Defaults to 64.
         """
@@ -50,7 +57,7 @@ def _batch_encode_and_save_on_disk(
         with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar:
             for i in range(0, num_samples, batch_size):
                 batch = text_list[i : i + batch_size]
-                batch_embeddings = self.encode(batch)
+                batch_embeddings = self.encode(batch, prompt=prompt)
                 batch_embeddings = np.asarray(batch_embeddings, dtype=dtype)
                 embeddings[i : i + batch_size] = batch_embeddings
                 pbar.update(len(batch))
@@ -61,6 +68,7 @@ def _batch_encode_and_save_on_disk(
     def batch_encode_with_cache(
         self,
         text_list: list[str],
+        prompt: str | None = None,
         cache_path: str | PathLike[str] | None = None,
         overwrite_cache: bool = False,
         batch_size: int = 64,
@@ -71,6 +79,7 @@ def batch_encode_with_cache(
 
         Args:
             text_list (list[str]): list of texts
+            prompt (str, optional): the prompt to use for encoding. Defaults to None.
             cache_path (str, optional): path to save the embeddings. Defaults to None.
             overwrite_cache (bool, optional): whether to overwrite the cache. Defaults to False.
             batch_size (int): batch size. Defaults to 64.
@@ -79,12 +88,14 @@ def batch_encode_with_cache(
 
         if cache_path is None:
             logger.info("Encoding embeddings")
-            return self.encode(text_list).astype(dtype)
+            return self.encode(text_list, prompt=prompt).astype(dtype)
 
         if Path(cache_path).exists() and not overwrite_cache:
             logger.info(f"Loading embeddings from {cache_path}")
             return np.memmap(cache_path, dtype=dtype, mode="r", shape=(len(text_list), self.get_output_dim()))
 
         logger.info(f"Encoding and saving embeddings to {cache_path}")
-        embeddings = self._batch_encode_and_save_on_disk(text_list, cache_path, batch_size=batch_size, dtype=dtype)
+        embeddings = self._batch_encode_and_save_on_disk(
+            text_list, cache_path, prompt=prompt, batch_size=batch_size, dtype=dtype
+        )
         return embeddings
diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py
index 48ab984..6fbc48e 100644
--- a/src/jmteb/embedders/sbert_embedder.py
+++ b/src/jmteb/embedders/sbert_embedder.py
@@ -15,20 +15,53 @@ def __init__(
         batch_size: int = 32,
         device: str | None = None,
         normalize_embeddings: bool = False,
+        max_seq_length: int | None = None,
+        tokenizer_padding_side: str | None = None,
+        add_eos: bool = False,
     ) -> None:
-        self.model = SentenceTransformer(model_name_or_path)
+        self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True)
+        if max_seq_length:
+            self.model.max_seq_length = max_seq_length
+        if tokenizer_padding_side:
+            try:
+                self.model.tokenizer.padding_side = tokenizer_padding_side
+            except AttributeError:
+                pass
+
         self.batch_size = batch_size
         self.device = device
         self.normalize_embeddings = normalize_embeddings
+        self.max_seq_length = max_seq_length
+        self.tokenizer_padding_side = tokenizer_padding_side
+        self.add_eos = add_eos
+
+        if self.max_seq_length:
+            self.model.max_seq_length = self.max_seq_length
+        if self.tokenizer_padding_side:
+            setattr(self.model.tokenizer, "padding_side", self.tokenizer_padding_side)
 
-    def encode(self, text: str | list[str]) -> np.ndarray:
+    def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray:
+        if self.add_eos:
+            text = self.add_eos_func(text)
         return self.model.encode(
             text,
+            prompt=prompt,
             convert_to_numpy=True,
             batch_size=self.batch_size,
             device=self.device,
             normalize_embeddings=self.normalize_embeddings,
         )
 
+    def add_eos_func(self, text: str | list[str]) -> str | list[str]:
+        try:
+            eos_token = getattr(self.model.tokenizer, "eos_token")
+        except AttributeError:
+            return text
+
+        if isinstance(text, str):
+            return text + eos_token
+        elif isinstance(text, list):
+            return [t + eos_token for t in text]
+
     def get_output_dim(self) -> int:
         return self.model.get_sentence_embedding_dimension()
diff --git a/src/jmteb/evaluators/reranking/evaluator.py b/src/jmteb/evaluators/reranking/evaluator.py
index 4b71dfe..0089c0a 100644
--- a/src/jmteb/evaluators/reranking/evaluator.py
+++ b/src/jmteb/evaluators/reranking/evaluator.py
@@ -28,6 +28,8 @@ class RerankingEvaluator(EmbeddingEvaluator):
         test_query_dataset (RerankingQueryDataset): test query dataset used for computing the scores
         doc_dataset (RerankingDocDataset): document dataset
         ndcg_at_k (list[int] | None): top k documents to consider in NDCG (Normalized Documented Cumulative Gain).
+        query_prefix (str | None): prefix for queries. Defaults to None.
+        doc_prefix (str | None): prefix for documents. Defaults to None.
""" def __init__( @@ -36,12 +38,16 @@ def __init__( test_query_dataset: RerankingQueryDataset, doc_dataset: RerankingDocDataset, ndcg_at_k: list[int] | None = None, + query_prefix: str | None = None, + doc_prefix: str | None = None, ) -> None: self.test_query_dataset = test_query_dataset self.val_query_dataset = val_query_dataset self.doc_dataset = doc_dataset self.ndcg_at_k = ndcg_at_k or [10, 20, 40] self.main_metric = f"ndcg@{self.ndcg_at_k[0]}" + self.query_prefix = query_prefix + self.doc_prefix = doc_prefix def __call__( self, @@ -54,6 +60,7 @@ def __call__( val_query_embeddings = model.batch_encode_with_cache( text_list=[item.query for item in self.val_query_dataset], + prompt=self.query_prefix, cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) @@ -62,11 +69,13 @@ def __call__( else: test_query_embeddings = model.batch_encode_with_cache( text_list=[item.query for item in self.test_query_dataset], + prompt=self.query_prefix, cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) doc_embeddings = model.batch_encode_with_cache( text_list=[item.text for item in self.doc_dataset], + prompt=self.doc_prefix, cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py index bc97e33..9af7af4 100644 --- a/src/jmteb/evaluators/retrieval/evaluator.py +++ b/src/jmteb/evaluators/retrieval/evaluator.py @@ -31,6 +31,8 @@ class RetrievalEvaluator(EmbeddingEvaluator): doc_chunk_size (int): The maximum size of corpus chunk. Smaller chunk requires less memory but lowers speed. ndcg_at_k (list[int] | None): top k documents to consider in NDCG (Normalized Documented Cumulative Gain). accuracy_at_k (list[int] | None): accuracy in top k hits. + query_prefix (str | None): prefix for queries. Defaults to None. + doc_prefix (str | None): prefix for documents. Defaults to None. 
""" def __init__( @@ -41,6 +43,8 @@ def __init__( doc_chunk_size: int = 1000000, accuracy_at_k: list[int] | None = None, ndcg_at_k: list[int] | None = None, + query_prefix: str | None = None, + doc_prefix: str | None = None, ) -> None: self.val_query_dataset = val_query_dataset self.test_query_dataset = test_query_dataset @@ -53,6 +57,9 @@ def __init__( self.max_top_k = max(sum([self.accuracy_at_k, self.ndcg_at_k], [])) self.main_metric = f"ndcg@{self.ndcg_at_k[0]}" + self.query_prefix = query_prefix + self.doc_prefix = doc_prefix + def __call__( self, model: TextEmbedder, @@ -64,6 +71,7 @@ def __call__( val_query_embeddings = model.batch_encode_with_cache( text_list=[item.query for item in self.val_query_dataset], + prompt=self.query_prefix, cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) @@ -72,12 +80,14 @@ def __call__( else: test_query_embeddings = model.batch_encode_with_cache( text_list=[item.query for item in self.test_query_dataset], + prompt=self.query_prefix, cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) doc_embeddings = model.batch_encode_with_cache( text_list=[item.text for item in self.doc_dataset], + prompt=self.doc_prefix, cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) diff --git a/tests/evaluator/test_reranking_evaluator.py b/tests/evaluator/test_reranking_evaluator.py index 0d903cb..ef847a9 100644 --- a/tests/evaluator/test_reranking_evaluator.py +++ b/tests/evaluator/test_reranking_evaluator.py @@ -12,11 +12,13 @@ EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_distance_metric"} EXPECTED_DIST_FUNC_NAMES = {"cosine_similarity", "euclidean_distance", "dot_score"} +QUERY_PREFIX = "クエリ: " +DOC_PREFIX = "ドキュメント: " class DummyDocDataset(RerankingDocDataset): - def __init__(self): - self._items = [RerankingDoc(id=str(i), text=f"dummy document {i}") for i in range(30)] + def __init__(self, prefix: str = ""): + self._items = [RerankingDoc(id=str(i), text=f"{prefix}dummy document {i}") for i in range(30)] def __len__(self): return len(self._items) @@ -26,9 +28,10 @@ def __getitem__(self, idx): class DummyQueryDataset(RerankingQueryDataset): - def __init__(self): + def __init__(self, prefix: str = ""): self._items = [ - RerankingQuery(query=f"dummy query {i}", retrieved_docs=[str(i)], relevance_scores=[1]) for i in range(10) + RerankingQuery(query=f"{prefix}dummy query {i}", retrieved_docs=[str(i)], relevance_scores=[1]) + for i in range(10) ] def __len__(self): @@ -57,6 +60,22 @@ def test_reranking_evaluator(embedder): assert any(score.startswith(metric) for metric in ["ndcg"]) +def test_reranking_evaluator_with_prefix(embedder): + evaluator_with_prefix = RerankingEvaluator( + val_query_dataset=DummyQueryDataset(), + test_query_dataset=DummyQueryDataset(), + doc_dataset=DummyDocDataset(), + query_prefix=QUERY_PREFIX, + doc_prefix=DOC_PREFIX, + ) + evaluator_with_manual_prefix = RerankingEvaluator( + val_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX), + test_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX), + doc_dataset=DummyDocDataset(prefix=DOC_PREFIX), + ) + assert evaluator_with_prefix(model=embedder) == evaluator_with_manual_prefix(model=embedder) + + def test_jsonl_reranking_datasets(): query = JsonlRerankingQueryDataset( filename="tests/test_data/dummy_reranking/val.jsonl", diff --git a/tests/evaluator/test_retrieval_evaluator.py 
b/tests/evaluator/test_retrieval_evaluator.py index 82b7944..fa52c52 100644 --- a/tests/evaluator/test_retrieval_evaluator.py +++ b/tests/evaluator/test_retrieval_evaluator.py @@ -12,11 +12,13 @@ EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_distance_metric"} EXPECTED_DIST_FUNC_NAMES = {"cosine_similarity", "euclidean_distance", "dot_score"} +QUERY_PREFIX = "クエリ: " +DOC_PREFIX = "ドキュメント: " class DummyDocDataset(RetrievalDocDataset): - def __init__(self): - self._items = [RetrievalDoc(id=str(i), text=f"dummy document {i}") for i in range(30)] + def __init__(self, prefix: str = ""): + self._items = [RetrievalDoc(id=str(i), text=f"{prefix}dummy document {i}") for i in range(30)] def __len__(self): return len(self._items) @@ -26,8 +28,8 @@ def __getitem__(self, idx): class DummyQueryDataset(RetrievalQueryDataset): - def __init__(self): - self._items = [RetrievalQuery(f"dummy query {i}", relevant_docs=[str(i)]) for i in range(10)] + def __init__(self, prefix: str = ""): + self._items = [RetrievalQuery(f"{prefix}dummy query {i}", relevant_docs=[str(i)]) for i in range(10)] def __len__(self): return len(self._items) @@ -58,6 +60,28 @@ def test_retrieval_evaluator(embedder): assert any(score.startswith(metric) for metric in ["accuracy", "mrr", "ndcg"]) +def test_retrieval_evaluator_with_prefix(embedder): + evaluator_with_prefix = RetrievalEvaluator( + val_query_dataset=DummyQueryDataset(), + test_query_dataset=DummyQueryDataset(), + doc_dataset=DummyDocDataset(), + query_prefix=QUERY_PREFIX, + doc_prefix=DOC_PREFIX, + accuracy_at_k=[1, 3, 5, 10], + ndcg_at_k=[1, 3, 5], + doc_chunk_size=3, + ) + evaluator_with_manual_prefix = RetrievalEvaluator( + val_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX), + test_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX), + doc_dataset=DummyDocDataset(prefix=DOC_PREFIX), + accuracy_at_k=[1, 3, 5, 10], + ndcg_at_k=[1, 3, 5], + doc_chunk_size=3, + ) + assert evaluator_with_prefix(model=embedder) == evaluator_with_manual_prefix(model=embedder) + + def test_if_chunking_does_not_change_result(embedder): evaluator1 = RetrievalEvaluator( val_query_dataset=DummyQueryDataset(),
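
Usage sketch (illustrative, not part of the patch): how the new embedder options would be exercised directly. The class and module names follow src/jmteb/embedders/sbert_embedder.py as modified above; the model name is only a placeholder, and `prompt` is forwarded to SentenceTransformer.encode(), which in recent sentence-transformers releases prepends it to each input text before tokenization.

# Illustrative only, not part of the patch. The model name below is a placeholder;
# any SentenceTransformers checkpoint works the same way.
from jmteb.embedders.sbert_embedder import SentenceBertEmbedder

embedder = SentenceBertEmbedder(
    model_name_or_path="cl-nagoya/sup-simcse-ja-base",  # placeholder model
    batch_size=32,
    max_seq_length=512,              # new: overrides model.max_seq_length when set
    tokenizer_padding_side="right",  # new: overrides tokenizer.padding_side when set
    add_eos=False,                   # new: append tokenizer.eos_token to every input
)

# `prompt` is passed through to SentenceTransformer.encode(); recent
# sentence-transformers releases prepend it to each input text.
query_embeddings = embedder.encode(["日本で一番高い山は?"], prompt="クエリ: ")
print(query_embeddings.shape)  # -> (1, embedder.get_output_dim())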
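
A corresponding sketch for the evaluator side (also illustrative, not part of the patch): query_prefix and doc_prefix are forwarded as the `prompt` argument of batch_encode_with_cache() when encoding queries and documents; RerankingEvaluator accepts the same two arguments. The dummy datasets are copied from tests/evaluator/test_retrieval_evaluator.py above, and the imports assume the retrieval package re-exports these classes the same way the tests do.

# Illustrative only, not part of the patch: wiring the new prefix arguments
# through RetrievalEvaluator with tiny in-memory datasets.
from jmteb.evaluators.retrieval import (
    RetrievalDoc,
    RetrievalDocDataset,
    RetrievalEvaluator,
    RetrievalQuery,
    RetrievalQueryDataset,
)


class DummyDocDataset(RetrievalDocDataset):
    def __init__(self):
        self._items = [RetrievalDoc(id=str(i), text=f"dummy document {i}") for i in range(30)]

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        return self._items[idx]


class DummyQueryDataset(RetrievalQueryDataset):
    def __init__(self):
        self._items = [RetrievalQuery(f"dummy query {i}", relevant_docs=[str(i)]) for i in range(10)]

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        return self._items[idx]


evaluator = RetrievalEvaluator(
    val_query_dataset=DummyQueryDataset(),
    test_query_dataset=DummyQueryDataset(),
    doc_dataset=DummyDocDataset(),
    query_prefix="クエリ: ",      # forwarded as `prompt` when encoding queries
    doc_prefix="ドキュメント: ",  # forwarded as `prompt` when encoding documents
    accuracy_at_k=[1, 3, 5, 10],
    ndcg_at_k=[1, 3, 5],
)
results = evaluator(model=embedder)  # `embedder` from the previous sketch
print(results)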