Support prefix for retrieval and reranking
lsz05 committed Jun 17, 2024
1 parent 277471a commit 5458301
Showing 6 changed files with 121 additions and 15 deletions.
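
In short, the commit threads an optional prompt (prefix) from the retrieval and reranking evaluators down to the embedder, so queries and documents can be encoded with different prefixes. A minimal usage sketch, assuming the import path follows the file location below and that the datasets and embedder already exist (all variable names are illustrative):

    from jmteb.evaluators.retrieval.evaluator import RetrievalEvaluator

    evaluator = RetrievalEvaluator(
        val_query_dataset=val_queries,    # any RetrievalQueryDataset
        test_query_dataset=test_queries,  # any RetrievalQueryDataset
        doc_dataset=docs,                 # any RetrievalDocDataset
        query_prefix="クエリ: ",           # passed as the encoding prompt for queries
        doc_prefix="ドキュメント: ",        # passed as the encoding prompt for documents
    )
    results = evaluator(model=embedder)   # embedder is any TextEmbedder instance
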
21 changes: 16 additions & 5 deletions src/jmteb/embedders/base.py
@@ -14,11 +14,12 @@ class TextEmbedder(ABC):
The base class of text embedder.
"""

def encode(self, text: str | list[str]) -> np.ndarray:
def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray:
"""Convert a text string or a list of texts to embedding.
Args:
text (str | list[str]): text string, or a list of texts.
prompt (str, optional): the prompt to use for encoding. Default to None.
"""
raise NotImplementedError

@@ -31,14 +32,20 @@ def get_output_dim(self) -> int:
raise NotImplementedError

def _batch_encode_and_save_on_disk(
self, text_list: list[str], save_path: str | PathLike[str], batch_size: int = 64, dtype: str = "float32"
self,
text_list: list[str],
save_path: str | PathLike[str],
prompt: str | None = None,
batch_size: int = 64,
dtype: str = "float32",
) -> np.memmap:
"""
Encode a list of texts and save the embeddings on disk using memmap.
Args:
text_list (list[str]): list of texts
save_path (str): path to save the embeddings
prompt (str, optional): the prompt to use for encoding. Default to None.
dtype (str, optional): data type. Defaults to "float32".
batch_size (int): batch size. Defaults to 64.
"""
@@ -50,7 +57,7 @@ def _batch_encode_and_save_on_disk(
with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar:
for i in range(0, num_samples, batch_size):
batch = text_list[i : i + batch_size]
batch_embeddings = self.encode(batch)
batch_embeddings = self.encode(batch, prompt=prompt)
batch_embeddings = np.asarray(batch_embeddings, dtype=dtype)
embeddings[i : i + batch_size] = batch_embeddings
pbar.update(len(batch))
@@ -61,6 +68,7 @@
def batch_encode_with_cache(
self,
text_list: list[str],
prompt: str | None = None,
cache_path: str | PathLike[str] | None = None,
overwrite_cache: bool = False,
batch_size: int = 64,
@@ -71,6 +79,7 @@ def batch_encode_with_cache(
Args:
text_list (list[str]): list of texts
prompt (str, optional): the prompt to use for encoding. Default to None.
cache_path (str, optional): path to save the embeddings. Defaults to None.
overwrite_cache (bool, optional): whether to overwrite the cache. Defaults to False.
batch_size (int): batch size. Defaults to 64.
@@ -79,12 +88,14 @@ def batch_encode_with_cache(

if cache_path is None:
logger.info("Encoding embeddings")
return self.encode(text_list).astype(dtype)
return self.encode(text_list, prompt=prompt).astype(dtype)

if Path(cache_path).exists() and not overwrite_cache:
logger.info(f"Loading embeddings from {cache_path}")
return np.memmap(cache_path, dtype=dtype, mode="r", shape=(len(text_list), self.get_output_dim()))

logger.info(f"Encoding and saving embeddings to {cache_path}")
embeddings = self._batch_encode_and_save_on_disk(text_list, cache_path, batch_size=batch_size, dtype=dtype)
embeddings = self._batch_encode_and_save_on_disk(
text_list, cache_path, prompt=prompt, batch_size=batch_size, dtype=dtype
)
return embeddings
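
For reference, a toy sketch (not part of the commit) of how the new prompt argument flows through batch_encode_with_cache: a hypothetical TextEmbedder subclass that prepends the prompt itself, assuming the base class needs no constructor arguments:

    import numpy as np

    from jmteb.embedders.base import TextEmbedder


    class ToyPrefixEmbedder(TextEmbedder):
        """Hash characters into a small vector, prepending the prompt when given."""

        def encode(self, text, prompt=None):
            if isinstance(text, str):
                text = [text]
            if prompt:
                text = [prompt + t for t in text]
            vectors = np.zeros((len(text), self.get_output_dim()), dtype="float32")
            for row, t in enumerate(text):
                for ch in t:
                    vectors[row, ord(ch) % self.get_output_dim()] += 1.0
            return vectors

        def get_output_dim(self):
            return 8


    embedder = ToyPrefixEmbedder()
    # With cache_path=None, this reduces to encode(text_list, prompt=...).astype(dtype).
    embeddings = embedder.batch_encode_with_cache(["dummy query 0"], prompt="クエリ: ")
    print(embeddings.shape)  # (1, 8)
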
37 changes: 35 additions & 2 deletions src/jmteb/embedders/sbert_embedder.py
@@ -15,20 +15,53 @@ def __init__(
batch_size: int = 32,
device: str | None = None,
normalize_embeddings: bool = False,
max_seq_length: int | None = None,
tokenizer_padding_side: str | None = None,
add_eos: bool = False,
) -> None:
self.model = SentenceTransformer(model_name_or_path)
self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True)
if max_seq_length:
self.model.max_seq_length = max_seq_length
if tokenizer_padding_side:
try:
self.model.tokenizer.padding_side = "right"
except AttributeError:
pass

self.batch_size = batch_size
self.device = device
self.normalize_embeddings = normalize_embeddings
self.max_seq_length = max_seq_length
self.tokenizer_padding_side = tokenizer_padding_side
self.add_eos = add_eos

if self.max_seq_length:
self.model.max_seq_length = self.max_seq_length
if self.tokenizer_padding_side:
setattr(self.model.tokenizer, "padding_side", self.tokenizer_padding_side)

def encode(self, text: str | list[str]) -> np.ndarray:
def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray:
if self.add_eos:
text = self.add_eos_func(text)
return self.model.encode(
text,
prompt=prompt,
convert_to_numpy=True,
batch_size=self.batch_size,
device=self.device,
normalize_embeddings=self.normalize_embeddings,
)

def add_eos_func(self, text: str | list[str]) -> str | list[str]:
try:
eos_token = getattr(self.model.tokenizer, "eos_token")
except AttributeError:
return text

if isinstance(text, str):
return text + eos_token
elif isinstance(text, list):
return [t + eos_token for t in text]

def get_output_dim(self) -> int:
return self.model.get_sentence_embedding_dimension()
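
With the Sentence-Transformers backend the prefix ends up as the prompt argument of SentenceTransformer.encode, which (since sentence-transformers 2.4) is prepended to each input before tokenization. A short sketch, assuming the class defined in this file is SentenceBertEmbedder and using an arbitrary example checkpoint:

    from jmteb.embedders.sbert_embedder import SentenceBertEmbedder

    # Example checkpoint only; any SentenceTransformer model behaves the same way.
    embedder = SentenceBertEmbedder("cl-nagoya/sup-simcse-ja-base", batch_size=8)

    # Encoded roughly as "クエリ: 今日の天気" by the underlying model.
    query_vec = embedder.encode("今日の天気", prompt="クエリ: ")

    # With add_eos=True, the tokenizer's EOS token is also appended to each input,
    # which some decoder-style embedding models expect.
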
9 changes: 9 additions & 0 deletions src/jmteb/evaluators/reranking/evaluator.py
@@ -28,6 +28,8 @@ class RerankingEvaluator(EmbeddingEvaluator):
test_query_dataset (RerankingQueryDataset): test query dataset used for computing the scores
doc_dataset (RerankingDocDataset): document dataset
ndcg_at_k (list[int] | None): top k documents to consider in NDCG (Normalized Documented Cumulative Gain).
query_prefix (str | None): prefix for queries. Defaults to None.
doc_prefix (str | None): prefix for documents. Defaults to None.
"""

def __init__(
@@ -36,12 +38,16 @@ def __init__(
test_query_dataset: RerankingQueryDataset,
doc_dataset: RerankingDocDataset,
ndcg_at_k: list[int] | None = None,
query_prefix: str | None = None,
doc_prefix: str | None = None,
) -> None:
self.test_query_dataset = test_query_dataset
self.val_query_dataset = val_query_dataset
self.doc_dataset = doc_dataset
self.ndcg_at_k = ndcg_at_k or [10, 20, 40]
self.main_metric = f"ndcg@{self.ndcg_at_k[0]}"
self.query_prefix = query_prefix
self.doc_prefix = doc_prefix

def __call__(
self,
@@ -54,6 +60,7 @@

val_query_embeddings = model.batch_encode_with_cache(
text_list=[item.query for item in self.val_query_dataset],
prompt=self.query_prefix,
cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
@@ -62,11 +69,13 @@
else:
test_query_embeddings = model.batch_encode_with_cache(
text_list=[item.query for item in self.test_query_dataset],
prompt=self.query_prefix,
cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
doc_embeddings = model.batch_encode_with_cache(
text_list=[item.text for item in self.doc_dataset],
prompt=self.doc_prefix,
cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
10 changes: 10 additions & 0 deletions src/jmteb/evaluators/retrieval/evaluator.py
@@ -31,6 +31,8 @@ class RetrievalEvaluator(EmbeddingEvaluator):
doc_chunk_size (int): The maximum size of corpus chunk. Smaller chunk requires less memory but lowers speed.
ndcg_at_k (list[int] | None): top k documents to consider in NDCG (Normalized Documented Cumulative Gain).
accuracy_at_k (list[int] | None): accuracy in top k hits.
query_prefix (str | None): prefix for queries. Defaults to None.
doc_prefix (str | None): prefix for documents. Defaults to None.
"""

def __init__(
@@ -41,6 +43,8 @@ def __init__(
doc_chunk_size: int = 1000000,
accuracy_at_k: list[int] | None = None,
ndcg_at_k: list[int] | None = None,
query_prefix: str | None = None,
doc_prefix: str | None = None,
) -> None:
self.val_query_dataset = val_query_dataset
self.test_query_dataset = test_query_dataset
@@ -53,6 +57,9 @@ def __init__(
self.max_top_k = max(sum([self.accuracy_at_k, self.ndcg_at_k], []))
self.main_metric = f"ndcg@{self.ndcg_at_k[0]}"

self.query_prefix = query_prefix
self.doc_prefix = doc_prefix

def __call__(
self,
model: TextEmbedder,
@@ -64,6 +71,7 @@

val_query_embeddings = model.batch_encode_with_cache(
text_list=[item.query for item in self.val_query_dataset],
prompt=self.query_prefix,
cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
@@ -72,12 +80,14 @@
else:
test_query_embeddings = model.batch_encode_with_cache(
text_list=[item.query for item in self.test_query_dataset],
prompt=self.query_prefix,
cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)

doc_embeddings = model.batch_encode_with_cache(
text_list=[item.text for item in self.doc_dataset],
prompt=self.doc_prefix,
cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
27 changes: 23 additions & 4 deletions tests/evaluator/test_reranking_evaluator.py
@@ -12,11 +12,13 @@

EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_distance_metric"}
EXPECTED_DIST_FUNC_NAMES = {"cosine_similarity", "euclidean_distance", "dot_score"}
QUERY_PREFIX = "クエリ: "
DOC_PREFIX = "ドキュメント: "


class DummyDocDataset(RerankingDocDataset):
def __init__(self):
self._items = [RerankingDoc(id=str(i), text=f"dummy document {i}") for i in range(30)]
def __init__(self, prefix: str = ""):
self._items = [RerankingDoc(id=str(i), text=f"{prefix}dummy document {i}") for i in range(30)]

def __len__(self):
return len(self._items)
@@ -26,9 +28,10 @@ def __getitem__(self, idx):


class DummyQueryDataset(RerankingQueryDataset):
def __init__(self):
def __init__(self, prefix: str = ""):
self._items = [
RerankingQuery(query=f"dummy query {i}", retrieved_docs=[str(i)], relevance_scores=[1]) for i in range(10)
RerankingQuery(query=f"{prefix}dummy query {i}", retrieved_docs=[str(i)], relevance_scores=[1])
for i in range(10)
]

def __len__(self):
@@ -57,6 +60,22 @@ def test_reranking_evaluator(embedder):
assert any(score.startswith(metric) for metric in ["ndcg"])


def test_reranking_evaluator_with_prefix(embedder):
evaluator_with_prefix = RerankingEvaluator(
val_query_dataset=DummyQueryDataset(),
test_query_dataset=DummyQueryDataset(),
doc_dataset=DummyDocDataset(),
query_prefix=QUERY_PREFIX,
doc_prefix=DOC_PREFIX,
)
evaluator_with_manual_prefix = RerankingEvaluator(
val_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX),
test_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX),
doc_dataset=DummyDocDataset(prefix=DOC_PREFIX),
)
assert evaluator_with_prefix(model=embedder) == evaluator_with_manual_prefix(model=embedder)


def test_jsonl_reranking_datasets():
query = JsonlRerankingQueryDataset(
filename="tests/test_data/dummy_reranking/val.jsonl",
32 changes: 28 additions & 4 deletions tests/evaluator/test_retrieval_evaluator.py
@@ -12,11 +12,13 @@

EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_distance_metric"}
EXPECTED_DIST_FUNC_NAMES = {"cosine_similarity", "euclidean_distance", "dot_score"}
QUERY_PREFIX = "クエリ: "
DOC_PREFIX = "ドキュメント: "


class DummyDocDataset(RetrievalDocDataset):
def __init__(self):
self._items = [RetrievalDoc(id=str(i), text=f"dummy document {i}") for i in range(30)]
def __init__(self, prefix: str = ""):
self._items = [RetrievalDoc(id=str(i), text=f"{prefix}dummy document {i}") for i in range(30)]

def __len__(self):
return len(self._items)
@@ -26,8 +28,8 @@ def __getitem__(self, idx):


class DummyQueryDataset(RetrievalQueryDataset):
def __init__(self):
self._items = [RetrievalQuery(f"dummy query {i}", relevant_docs=[str(i)]) for i in range(10)]
def __init__(self, prefix: str = ""):
self._items = [RetrievalQuery(f"{prefix}dummy query {i}", relevant_docs=[str(i)]) for i in range(10)]

def __len__(self):
return len(self._items)
@@ -58,6 +60,28 @@ def test_retrieval_evaluator(embedder):
assert any(score.startswith(metric) for metric in ["accuracy", "mrr", "ndcg"])


def test_retrieval_evaluator_with_prefix(embedder):
evaluator_with_prefix = RetrievalEvaluator(
val_query_dataset=DummyQueryDataset(),
test_query_dataset=DummyQueryDataset(),
doc_dataset=DummyDocDataset(),
query_prefix=QUERY_PREFIX,
doc_prefix=DOC_PREFIX,
accuracy_at_k=[1, 3, 5, 10],
ndcg_at_k=[1, 3, 5],
doc_chunk_size=3,
)
evaluator_with_manual_prefix = RetrievalEvaluator(
val_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX),
test_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX),
doc_dataset=DummyDocDataset(prefix=DOC_PREFIX),
accuracy_at_k=[1, 3, 5, 10],
ndcg_at_k=[1, 3, 5],
doc_chunk_size=3,
)
assert evaluator_with_prefix(model=embedder) == evaluator_with_manual_prefix(model=embedder)


def test_if_chunking_does_not_change_result(embedder):
evaluator1 = RetrievalEvaluator(
val_query_dataset=DummyQueryDataset(),
