From 4ffd114281e0e200a024ed69afeb919765e5161b Mon Sep 17 00:00:00 2001 From: akiFQC Date: Wed, 7 Aug 2024 15:25:10 +0900 Subject: [PATCH 01/16] use sbert embedder with encode_multi_process --- src/jmteb/embedders/base.py | 7 +-- src/jmteb/embedders/sbert_embedder.py | 91 ++++++++++++++++++++++----- 2 files changed, 77 insertions(+), 21 deletions(-) diff --git a/src/jmteb/embedders/base.py b/src/jmteb/embedders/base.py index bede32b..afefec1 100644 --- a/src/jmteb/embedders/base.py +++ b/src/jmteb/embedders/base.py @@ -17,6 +17,7 @@ class TextEmbedder(ABC): convert_to_tensor: bool convert_to_numpy: bool + _chunk_size: int = 262144 # 2^18 def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray | torch.Tensor: """Convert a text string or a list of texts to embedding. @@ -40,7 +41,7 @@ def _batch_encode_and_save_on_disk( text_list: list[str], save_path: str | PathLike[str], prefix: str | None = None, - batch_size: int = 64, + batch_size: int = 262144, dtype: str = "float32", ) -> np.memmap | torch.Tensor: """ @@ -81,7 +82,6 @@ def batch_encode_with_cache( prefix: str | None = None, cache_path: str | PathLike[str] | None = None, overwrite_cache: bool = False, - batch_size: int = 64, dtype: str = "float32", ) -> np.ndarray | torch.Tensor: """ @@ -92,7 +92,6 @@ def batch_encode_with_cache( prefix (str, optional): the prefix to use for encoding. Default to None. cache_path (str, optional): path to save the embeddings. Defaults to None. overwrite_cache (bool, optional): whether to overwrite the cache. Defaults to False. - batch_size (int): batch size. Defaults to 64. dtype (str, optional): data type. Defaults to "float32". """ @@ -106,7 +105,7 @@ def batch_encode_with_cache( logger.info(f"Encoding and saving embeddings to {cache_path}") embeddings = self._batch_encode_and_save_on_disk( - text_list, cache_path, prefix=prefix, batch_size=batch_size, dtype=dtype + text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype ) return embeddings diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index 0188e7d..cfea3a3 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -1,11 +1,32 @@ from __future__ import annotations +from contextlib import contextmanager +from os import PathLike +from pathlib import Path +from typing import Optional + import numpy as np +import torch +import tqdm +from loguru import logger from sentence_transformers import SentenceTransformer from jmteb.embedders.base import TextEmbedder +@contextmanager +def sbert_multi_proc_pool(sbert_model: SentenceTransformer, target_devices: Optional[list[str]] = None): + pool = sbert_model.start_multi_process_pool(target_devices=target_devices) + logger.info("pool of encoding processing: ") + for k, v in pool.items(): + logger.info(f" {k}: {v}") + try: + yield pool + finally: + logger.info("stop pool") + sbert_model.stop_multi_process_pool(pool) + + class SentenceBertEmbedder(TextEmbedder): """SentenceBERT embedder.""" @@ -37,24 +58,60 @@ def __init__( self.normalize_embeddings = normalize_embeddings self.max_seq_length = getattr(self.model, "max_seq_length", None) self.add_eos = add_eos + self.set_output_numpy() + self.model.eval() - if "torch_dtype" in model_kwargs: - self.set_output_tensor() - else: - self.set_output_numpy() - - def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: - if self.add_eos: - text = self._add_eos_func(text) - return self.model.encode( - text, - prompt=prefix, - 
convert_to_numpy=self.convert_to_numpy, - convert_to_tensor=self.convert_to_tensor, - batch_size=self.batch_size, - device=self.device, - normalize_embeddings=self.normalize_embeddings, - ) + # override + def _batch_encode_and_save_on_disk( + self, + text_list: list[str], + save_path: str | PathLike[str], + prefix: str | None = None, + batch_size: int = 262144, + dtype: str = "float32", + ) -> np.memmap: + """ + Encode a list of texts and save the embeddings on disk using memmap. + + Args: + text_list (list[str]): list of texts + save_path (str): path to save the embeddings + prefix (str, optional): the prefix to use for encoding. Default to None. + dtype (str, optional): data type. Defaults to "float32". + batch_size (int): batch size. Defaults to 262144. + """ + self.set_output_numpy() + self.model.eval() + logger.info(f"use numpy") + + num_samples = len(text_list) + output_dim = self.get_output_dim() + + embeddings = np.memmap(save_path, dtype=dtype, mode="w+", shape=(num_samples, output_dim)) + + with sbert_multi_proc_pool(self.model) as pool: + with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar: + chunk_size = min( + self.batch_size * 4, + np.ceil(num_samples / len(pool["processes"])), + ) + logger.info(f"chunk size={chunk_size}") + for i in range(0, num_samples, batch_size): + batch: list[str] = text_list[i : i + batch_size] + batch = self._add_eos_func(batch) + batch_embeddings: np.ndarray = self.model.encode_multi_process( + batch, + pool=pool, + prompt=prefix, + chunk_size=chunk_size, + batch_size=self.batch_size, + normalize_embeddings=self.normalize_embeddings, + ) + embeddings[i : i + batch_size] = batch_embeddings + pbar.update(len(batch)) + + embeddings.flush() + return np.memmap(save_path, dtype=dtype, mode="r", shape=(num_samples, output_dim)) def _add_eos_func(self, text: str | list[str]) -> str | list[str]: try: From 7e8e031eec26c2c2ee438e03ed18a52de944f76d Mon Sep 17 00:00:00 2001 From: akiFQC Date: Wed, 7 Aug 2024 15:33:02 +0900 Subject: [PATCH 02/16] add chunk_size_factor --- src/jmteb/embedders/sbert_embedder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index cfea3a3..c9c11be 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -41,6 +41,7 @@ def __init__( truncate_dim: int | None = None, model_kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, + chunk_size_factor: int = 4, ) -> None: model_kwargs = self._model_kwargs_parser(model_kwargs) self.model = SentenceTransformer( @@ -60,6 +61,7 @@ def __init__( self.add_eos = add_eos self.set_output_numpy() self.model.eval() + self.chunk_size_factor = 4 # override def _batch_encode_and_save_on_disk( @@ -92,7 +94,7 @@ def _batch_encode_and_save_on_disk( with sbert_multi_proc_pool(self.model) as pool: with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar: chunk_size = min( - self.batch_size * 4, + self.batch_size * self.chunk_size_factor, np.ceil(num_samples / len(pool["processes"])), ) logger.info(f"chunk size={chunk_size}") From 6fb4a6e5c3889a8b210827285068e7834cd74015 Mon Sep 17 00:00:00 2001 From: akiFQC Date: Wed, 7 Aug 2024 15:54:08 +0900 Subject: [PATCH 03/16] fix chunk_size_factor --- src/jmteb/embedders/sbert_embedder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index c9c11be..d76ca3d 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ 
b/src/jmteb/embedders/sbert_embedder.py @@ -41,7 +41,7 @@ def __init__( truncate_dim: int | None = None, model_kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, - chunk_size_factor: int = 4, + chunk_size_factor: int = 128, ) -> None: model_kwargs = self._model_kwargs_parser(model_kwargs) self.model = SentenceTransformer( @@ -61,7 +61,7 @@ def __init__( self.add_eos = add_eos self.set_output_numpy() self.model.eval() - self.chunk_size_factor = 4 + self.chunk_size_factor = chunk_size_factor # override def _batch_encode_and_save_on_disk( From c361727d411354589eee70695b597573a110fbc9 Mon Sep 17 00:00:00 2001 From: akiFQC Date: Wed, 7 Aug 2024 16:02:59 +0900 Subject: [PATCH 04/16] small fix chunk_size --- src/jmteb/embedders/sbert_embedder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index d76ca3d..5989783 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -93,10 +93,10 @@ def _batch_encode_and_save_on_disk( with sbert_multi_proc_pool(self.model) as pool: with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar: - chunk_size = min( + chunk_size = int(min( self.batch_size * self.chunk_size_factor, np.ceil(num_samples / len(pool["processes"])), - ) + )) logger.info(f"chunk size={chunk_size}") for i in range(0, num_samples, batch_size): batch: list[str] = text_list[i : i + batch_size] From 2809731d88a52bdacaa31676b88b4b152dd90715 Mon Sep 17 00:00:00 2001 From: akiFQC Date: Wed, 7 Aug 2024 16:03:15 +0900 Subject: [PATCH 05/16] format --- src/jmteb/embedders/sbert_embedder.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index 5989783..50500e4 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -93,10 +93,12 @@ def _batch_encode_and_save_on_disk( with sbert_multi_proc_pool(self.model) as pool: with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar: - chunk_size = int(min( - self.batch_size * self.chunk_size_factor, - np.ceil(num_samples / len(pool["processes"])), - )) + chunk_size = int( + min( + self.batch_size * self.chunk_size_factor, + np.ceil(num_samples / len(pool["processes"])), + ) + ) logger.info(f"chunk size={chunk_size}") for i in range(0, num_samples, batch_size): batch: list[str] = text_list[i : i + batch_size] From 62368bcbbc7a931ea838fbd6617e3bccbeba244d Mon Sep 17 00:00:00 2001 From: akiFQC Date: Thu, 8 Aug 2024 15:42:22 +0900 Subject: [PATCH 06/16] add: code and tests of multi-gpu inference with pytorch DP --- .../embedders/data_parallel_sbert_embedder.py | 218 ++++++++++++++++++ src/jmteb/embedders/sbert_embedder.py | 12 + tests/embedders/test_dp_sbert.py | 38 +++ tests/embedders/test_sbert.py | 10 +- 4 files changed, 274 insertions(+), 4 deletions(-) create mode 100644 src/jmteb/embedders/data_parallel_sbert_embedder.py create mode 100644 tests/embedders/test_dp_sbert.py diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py new file mode 100644 index 0000000..470b57d --- /dev/null +++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import logging +from typing import Literal + +import numpy as np +import torch +from loguru import logger +from sentence_transformers import SentenceTransformer +from sentence_transformers.quantization 
import quantize_embeddings +from sentence_transformers.util import batch_to_device, truncate_embeddings +from torch import Tensor +from tqdm.autonotebook import trange + +from jmteb.embedders.base import TextEmbedder + + +class DPSentenceTransformer(SentenceTransformer): + + def __init__(self, sbert_model: SentenceTransformer): + super(DPSentenceTransformer, self).__init__() + self.dp_model = torch.nn.DataParallel(sbert_model) + self.sbert = self.dp_model.module + + def forward(self, *args, **kargs): + return self.dp_model.forward(*args, **kargs) + + def encode( + self, + sentences: str | list[str], + prompt_name: str | None = None, + prompt: str | None = None, + batch_size: int = 64, + show_progress_bar: bool | None = None, + output_value: Literal["sentence_embedding", "token_embeddings"] | None = "sentence_embedding", + precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", + convert_to_numpy: bool = True, + convert_to_tensor: bool = False, + device: str = None, + normalize_embeddings: bool = False, + ) -> list[Tensor] | np.ndarray | Tensor: + self.eval() + if show_progress_bar is None: + show_progress_bar = logger.level in (logging.INFO, logging.DEBUG) + + if convert_to_tensor: + convert_to_numpy = False + + if output_value != "sentence_embedding": + convert_to_tensor = False + convert_to_numpy = False + + input_was_string = False + if isinstance(sentences, str) or not hasattr( + sentences, "__len__" + ): # Cast an individual sentence to a list with length 1 + sentences = [sentences] + input_was_string = True + + if prompt is None: + if prompt_name is not None: + try: + prompt = self.sbert.prompts[prompt_name] + except KeyError: + raise ValueError( + f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(self.sbert.prompts.keys())!r}." + ) + elif self.default_prompt_name is not None: + prompt = self.sbert.prompts.get(self.sbert.default_prompt_name, None) + else: + if prompt_name is not None: + logger.warning( + "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. " + "Ignoring the `prompt_name` in favor of `prompt`." + ) + + extra_features = {} + if prompt is not None: + sentences = [prompt + sentence for sentence in sentences] + + # Some models (e.g. 
INSTRUCTOR, GRIT) require removing the prompt before pooling + # Tracking the prompt length allow us to remove the prompt during pooling + tokenized_prompt = self.sbert.tokenize([prompt]) + if "input_ids" in tokenized_prompt: + extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1 + + all_embeddings = [] + length_sorted_idx = np.argsort([-self.sbert._text_length(sen) for sen in sentences]) + sentences_sorted = [sentences[idx] for idx in length_sorted_idx] + + for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): + sentences_batch = sentences_sorted[start_index : start_index + batch_size] + features = self.sbert.tokenize(sentences_batch) + features = batch_to_device(features, device) + features.update(extra_features) + + with torch.no_grad(): + out_features = self.forward(features) + + out_features["sentence_embedding"] = truncate_embeddings( + out_features["sentence_embedding"], self.sbert.truncate_dim + ) + + if output_value == "token_embeddings": + embeddings = [] + for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]): + last_mask_id = len(attention) - 1 + while last_mask_id > 0 and attention[last_mask_id].item() == 0: + last_mask_id -= 1 + + embeddings.append(token_emb[0 : last_mask_id + 1]) + elif output_value is None: # Return all outputs + embeddings = [] + for sent_idx in range(len(out_features["sentence_embedding"])): + row = {name: out_features[name][sent_idx] for name in out_features} + embeddings.append(row) + else: # Sentence embeddings + embeddings = out_features[output_value] + embeddings = embeddings.detach() + if normalize_embeddings: + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + + # fixes for #522 and #487 to avoid oom problems on gpu with large datasets + if convert_to_numpy: + embeddings = embeddings.cpu() + + all_embeddings.extend(embeddings) + + all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] + + if precision and precision != "float32": + all_embeddings = quantize_embeddings(all_embeddings, precision=precision) + + if convert_to_tensor: + if len(all_embeddings): + if isinstance(all_embeddings, np.ndarray): + all_embeddings = torch.from_numpy(all_embeddings) + else: + all_embeddings = torch.stack(all_embeddings) + else: + all_embeddings = torch.Tensor() + elif convert_to_numpy: + if not isinstance(all_embeddings, np.ndarray): + if all_embeddings and all_embeddings[0].dtype == torch.bfloat16: + all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings]) + else: + all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) + elif isinstance(all_embeddings, np.ndarray): + all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings] + + if input_was_string: + all_embeddings = all_embeddings[0] + + return all_embeddings + + +class DataParallelSentenceBertEmbedder(TextEmbedder): + """SentenceBERT embedder with pytorch data parallel""" + + def __init__( + self, + model_name_or_path: str, + batch_size: int = 64, + normalize_embeddings: bool = False, + max_seq_length: int | None = None, + add_eos: bool = False, + truncate_dim: int | None = None, + model_kwargs: dict | None = None, + tokenizer_kwargs: dict | None = None, + ) -> None: + model_kwargs = self._model_kwargs_parser(model_kwargs) + model = SentenceTransformer( + model_name_or_path, + trust_remote_code=True, + truncate_dim=truncate_dim, + model_kwargs=model_kwargs, # 
https://github.com/UKPLab/sentence-transformers/blob/84f69fee6dcde023f46a8807e89bc99a7700ba82/sentence_transformers/SentenceTransformer.py#L81-L105 # noqa: E501 + tokenizer_kwargs=tokenizer_kwargs, + ) + self.dp_model = DPSentenceTransformer(sbert_model=model) + self.model = self.dp_model.sbert + if max_seq_length: + self.model.max_seq_length = max_seq_length + + self.batch_size = batch_size + self.normalize_embeddings = normalize_embeddings + self.max_seq_length = getattr(self.model, "max_seq_length", None) + self.add_eos = add_eos + + if "torch_dtype" in model_kwargs: + self.set_output_tensor() + else: + self.set_output_numpy() + + def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: + if self.add_eos: + text = self._add_eos_func(text) + return self.dp_model.encode( + text, + prompt=prefix, + convert_to_numpy=self.convert_to_numpy, + convert_to_tensor=self.convert_to_tensor, + batch_size=self.batch_size, + normalize_embeddings=self.normalize_embeddings, + ) + + def _add_eos_func(self, text: str | list[str]) -> str | list[str]: + try: + eos_token = getattr(self.model.savetokenizer, "eos_token") + except AttributeError: + return text + + if isinstance(text, str): + return text + eos_token + elif isinstance(text, list): + return [t + eos_token for t in text] + + def get_output_dim(self) -> int: + return self.model.get_sentence_embedding_dimension() diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index 50500e4..d4cb2f1 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -63,6 +63,18 @@ def __init__( self.model.eval() self.chunk_size_factor = chunk_size_factor + def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: + if self.add_eos: + text = self._add_eos_func(text) + return self.model.encode( + text, + prompt=prefix, + convert_to_numpy=self.convert_to_numpy, + convert_to_tensor=self.convert_to_tensor, + batch_size=self.batch_size, + normalize_embeddings=self.normalize_embeddings, + ) + # override def _batch_encode_and_save_on_disk( self, diff --git a/tests/embedders/test_dp_sbert.py b/tests/embedders/test_dp_sbert.py new file mode 100644 index 0000000..241d5a6 --- /dev/null +++ b/tests/embedders/test_dp_sbert.py @@ -0,0 +1,38 @@ +import numpy as np +import torch + +from jmteb.embedders.data_parallel_sbert_embedder import ( + DataParallelSentenceBertEmbedder, +) + +MODEL_NAME_OR_PATH = "prajjwal1/bert-tiny" +OUTPUT_DIM = 128 + + +class TestSentenceBertEmbedder: + def setup_class(cls): + cls.model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH) + + def test_encode(self): + embeddings = self.model.encode("任意のテキスト") + assert isinstance(embeddings, np.ndarray) + assert embeddings.shape == (OUTPUT_DIM,) + + def test_get_output_dim(self): + assert self.model.get_output_dim() == OUTPUT_DIM + + def test_tokenizer_kwargs(self): + assert self.model.model.tokenizer.sep_token == "[SEP]" + model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH, tokenizer_kwargs={"sep_token": ""}) + assert model.model.tokenizer.sep_token == "" + + def test_model_kwargs(self): + model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.float16}) + assert model.convert_to_tensor + assert model.encode("任意のテキスト").dtype is torch.float16 + + def test_bf16(self): + # As numpy doesn't support native bfloat16, add a test case for bf16 + model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": 
torch.bfloat16}) + assert model.convert_to_tensor + assert model.encode("任意のテキスト").dtype is torch.bfloat16 diff --git a/tests/embedders/test_sbert.py b/tests/embedders/test_sbert.py index aa184f9..e38472d 100644 --- a/tests/embedders/test_sbert.py +++ b/tests/embedders/test_sbert.py @@ -26,11 +26,13 @@ def test_tokenizer_kwargs(self): def test_model_kwargs(self): model = SentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.float16}) - assert model.convert_to_tensor - assert model.encode("任意のテキスト").dtype is torch.float16 + assert not model.convert_to_tensor + assert model.convert_to_numpy + assert model.encode("任意のテキスト").dtype is np.dtype("float16") def test_bf16(self): # As numpy doesn't support native bfloat16, add a test case for bf16 model = SentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.bfloat16}) - assert model.convert_to_tensor - assert model.encode("任意のテキスト").dtype is torch.bfloat16 + assert not model.convert_to_tensor + assert model.convert_to_numpy + assert model.encode("任意のテキスト").dtype is np.dtype("float32") From b7376c05bf7f7a0400285913c6d65af139b62256 Mon Sep 17 00:00:00 2001 From: akiFQC Date: Thu, 8 Aug 2024 16:23:42 +0900 Subject: [PATCH 07/16] update init --- src/jmteb/embedders/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/jmteb/embedders/__init__.py b/src/jmteb/embedders/__init__.py index eeffb98..a0649f7 100644 --- a/src/jmteb/embedders/__init__.py +++ b/src/jmteb/embedders/__init__.py @@ -2,3 +2,4 @@ from jmteb.embedders.openai_embedder import OpenAIEmbedder from jmteb.embedders.sbert_embedder import SentenceBertEmbedder from jmteb.embedders.transformers_embedder import TransformersEmbedder +from jmteb.embedders.data_parallel_sbert_embedder import DataParallelSentenceBertEmbedder From 9508f3c918324c3c8e47855a70941d194019d272 Mon Sep 17 00:00:00 2001 From: akiFQC Date: Thu, 8 Aug 2024 16:34:51 +0900 Subject: [PATCH 08/16] debug DP --- src/jmteb/embedders/data_parallel_sbert_embedder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py index 470b57d..dceb881 100644 --- a/src/jmteb/embedders/data_parallel_sbert_embedder.py +++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py @@ -91,7 +91,6 @@ def encode( for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): sentences_batch = sentences_sorted[start_index : start_index + batch_size] features = self.sbert.tokenize(sentences_batch) - features = batch_to_device(features, device) features.update(extra_features) with torch.no_grad(): From b9a50c61263ee4e3dd21675b12db0b794d965bf4 Mon Sep 17 00:00:00 2001 From: akiFQC Date: Thu, 8 Aug 2024 16:52:58 +0900 Subject: [PATCH 09/16] revert sbert embedder --- src/jmteb/embedders/__init__.py | 4 +- src/jmteb/embedders/sbert_embedder.py | 85 ++------------------------- tests/embedders/test_sbert.py | 10 ++-- 3 files changed, 13 insertions(+), 86 deletions(-) diff --git a/src/jmteb/embedders/__init__.py b/src/jmteb/embedders/__init__.py index a0649f7..f28f038 100644 --- a/src/jmteb/embedders/__init__.py +++ b/src/jmteb/embedders/__init__.py @@ -1,5 +1,7 @@ from jmteb.embedders.base import TextEmbedder +from jmteb.embedders.data_parallel_sbert_embedder import ( + DataParallelSentenceBertEmbedder, +) from jmteb.embedders.openai_embedder import OpenAIEmbedder from jmteb.embedders.sbert_embedder import SentenceBertEmbedder from jmteb.embedders.transformers_embedder 
import TransformersEmbedder -from jmteb.embedders.data_parallel_sbert_embedder import DataParallelSentenceBertEmbedder diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index d4cb2f1..0188e7d 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -1,32 +1,11 @@ from __future__ import annotations -from contextlib import contextmanager -from os import PathLike -from pathlib import Path -from typing import Optional - import numpy as np -import torch -import tqdm -from loguru import logger from sentence_transformers import SentenceTransformer from jmteb.embedders.base import TextEmbedder -@contextmanager -def sbert_multi_proc_pool(sbert_model: SentenceTransformer, target_devices: Optional[list[str]] = None): - pool = sbert_model.start_multi_process_pool(target_devices=target_devices) - logger.info("pool of encoding processing: ") - for k, v in pool.items(): - logger.info(f" {k}: {v}") - try: - yield pool - finally: - logger.info("stop pool") - sbert_model.stop_multi_process_pool(pool) - - class SentenceBertEmbedder(TextEmbedder): """SentenceBERT embedder.""" @@ -41,7 +20,6 @@ def __init__( truncate_dim: int | None = None, model_kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, - chunk_size_factor: int = 128, ) -> None: model_kwargs = self._model_kwargs_parser(model_kwargs) self.model = SentenceTransformer( @@ -59,9 +37,11 @@ def __init__( self.normalize_embeddings = normalize_embeddings self.max_seq_length = getattr(self.model, "max_seq_length", None) self.add_eos = add_eos - self.set_output_numpy() - self.model.eval() - self.chunk_size_factor = chunk_size_factor + + if "torch_dtype" in model_kwargs: + self.set_output_tensor() + else: + self.set_output_numpy() def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: if self.add_eos: @@ -72,63 +52,10 @@ def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray convert_to_numpy=self.convert_to_numpy, convert_to_tensor=self.convert_to_tensor, batch_size=self.batch_size, + device=self.device, normalize_embeddings=self.normalize_embeddings, ) - # override - def _batch_encode_and_save_on_disk( - self, - text_list: list[str], - save_path: str | PathLike[str], - prefix: str | None = None, - batch_size: int = 262144, - dtype: str = "float32", - ) -> np.memmap: - """ - Encode a list of texts and save the embeddings on disk using memmap. - - Args: - text_list (list[str]): list of texts - save_path (str): path to save the embeddings - prefix (str, optional): the prefix to use for encoding. Default to None. - dtype (str, optional): data type. Defaults to "float32". - batch_size (int): batch size. Defaults to 262144. 
- """ - self.set_output_numpy() - self.model.eval() - logger.info(f"use numpy") - - num_samples = len(text_list) - output_dim = self.get_output_dim() - - embeddings = np.memmap(save_path, dtype=dtype, mode="w+", shape=(num_samples, output_dim)) - - with sbert_multi_proc_pool(self.model) as pool: - with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar: - chunk_size = int( - min( - self.batch_size * self.chunk_size_factor, - np.ceil(num_samples / len(pool["processes"])), - ) - ) - logger.info(f"chunk size={chunk_size}") - for i in range(0, num_samples, batch_size): - batch: list[str] = text_list[i : i + batch_size] - batch = self._add_eos_func(batch) - batch_embeddings: np.ndarray = self.model.encode_multi_process( - batch, - pool=pool, - prompt=prefix, - chunk_size=chunk_size, - batch_size=self.batch_size, - normalize_embeddings=self.normalize_embeddings, - ) - embeddings[i : i + batch_size] = batch_embeddings - pbar.update(len(batch)) - - embeddings.flush() - return np.memmap(save_path, dtype=dtype, mode="r", shape=(num_samples, output_dim)) - def _add_eos_func(self, text: str | list[str]) -> str | list[str]: try: eos_token = getattr(self.model.tokenizer, "eos_token") diff --git a/tests/embedders/test_sbert.py b/tests/embedders/test_sbert.py index e38472d..aa184f9 100644 --- a/tests/embedders/test_sbert.py +++ b/tests/embedders/test_sbert.py @@ -26,13 +26,11 @@ def test_tokenizer_kwargs(self): def test_model_kwargs(self): model = SentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.float16}) - assert not model.convert_to_tensor - assert model.convert_to_numpy - assert model.encode("任意のテキスト").dtype is np.dtype("float16") + assert model.convert_to_tensor + assert model.encode("任意のテキスト").dtype is torch.float16 def test_bf16(self): # As numpy doesn't support native bfloat16, add a test case for bf16 model = SentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.bfloat16}) - assert not model.convert_to_tensor - assert model.convert_to_numpy - assert model.encode("任意のテキスト").dtype is np.dtype("float32") + assert model.convert_to_tensor + assert model.encode("任意のテキスト").dtype is torch.bfloat16 From c6f079a1a95aed354060b99122615bad432a05b2 Mon Sep 17 00:00:00 2001 From: akiFQC Date: Thu, 8 Aug 2024 17:15:04 +0900 Subject: [PATCH 10/16] format --- src/jmteb/embedders/data_parallel_sbert_embedder.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py index dceb881..07c5a67 100644 --- a/src/jmteb/embedders/data_parallel_sbert_embedder.py +++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py @@ -8,7 +8,7 @@ from loguru import logger from sentence_transformers import SentenceTransformer from sentence_transformers.quantization import quantize_embeddings -from sentence_transformers.util import batch_to_device, truncate_embeddings +from sentence_transformers.util import truncate_embeddings from torch import Tensor from tqdm.autonotebook import trange @@ -63,7 +63,8 @@ def encode( prompt = self.sbert.prompts[prompt_name] except KeyError: raise ValueError( - f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(self.sbert.prompts.keys())!r}." + f"Prompt name '{prompt_name}' not found in the configured " + f"prompts dictionary with keys {list(self.sbert.prompts.keys())!r}." 
) elif self.default_prompt_name is not None: prompt = self.sbert.prompts.get(self.sbert.default_prompt_name, None) From 4261cf5f857e564694d5b428a32ccb3e5b482fbc Mon Sep 17 00:00:00 2001 From: akiFQC Date: Thu, 8 Aug 2024 18:43:18 +0900 Subject: [PATCH 11/16] find_executable_batch_size --- .../embedders/data_parallel_sbert_embedder.py | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py index 07c5a67..fba9ceb 100644 --- a/src/jmteb/embedders/data_parallel_sbert_embedder.py +++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py @@ -5,6 +5,7 @@ import numpy as np import torch +from accelerate.utils import find_executable_batch_size from loguru import logger from sentence_transformers import SentenceTransformer from sentence_transformers.quantization import quantize_embeddings @@ -167,6 +168,7 @@ def __init__( truncate_dim: int | None = None, model_kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, + auto_find_batch_size: bool = True, ) -> None: model_kwargs = self._model_kwargs_parser(model_kwargs) model = SentenceTransformer( @@ -180,11 +182,12 @@ def __init__( self.model = self.dp_model.sbert if max_seq_length: self.model.max_seq_length = max_seq_length - - self.batch_size = batch_size + self.initital_batch_size = batch_size + self.batch_size = int(self.initital_batch_size) self.normalize_embeddings = normalize_embeddings self.max_seq_length = getattr(self.model, "max_seq_length", None) self.add_eos = add_eos + self.auto_find_batch_size = auto_find_batch_size if "torch_dtype" in model_kwargs: self.set_output_tensor() @@ -194,14 +197,31 @@ def __init__( def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: if self.add_eos: text = self._add_eos_func(text) - return self.dp_model.encode( - text, - prompt=prefix, - convert_to_numpy=self.convert_to_numpy, - convert_to_tensor=self.convert_to_tensor, - batch_size=self.batch_size, - normalize_embeddings=self.normalize_embeddings, - ) + if self.auto_find_batch_size: + # wrap function + @find_executable_batch_size(starting_batch_size=self.batch_size) + def _encode_with_auto_batch_size(batch_size, self, text, prefix): + out = self.dp_model.encode( + text, + prompt=prefix, + convert_to_numpy=self.convert_to_numpy, + convert_to_tensor=self.convert_to_tensor, + batch_size=batch_size, + normalize_embeddings=self.normalize_embeddings, + ) + self.batch_size = batch_size + return out + + return _encode_with_auto_batch_size(self, text, prefix) + else: + return self.dp_model.encode( + text, + prompt=prefix, + convert_to_numpy=self.convert_to_numpy, + convert_to_tensor=self.convert_to_tensor, + batch_size=self.batch_size, + normalize_embeddings=self.normalize_embeddings, + ) def _add_eos_func(self, text: str | list[str]) -> str | list[str]: try: From 39f98a3a579849d263ec83f4d5aa74c3d0994710 Mon Sep 17 00:00:00 2001 From: akiFQC Date: Thu, 8 Aug 2024 18:47:30 +0900 Subject: [PATCH 12/16] add comment --- src/jmteb/embedders/data_parallel_sbert_embedder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py index fba9ceb..698fe0f 100644 --- a/src/jmteb/embedders/data_parallel_sbert_embedder.py +++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py @@ -17,6 +17,7 @@ class DPSentenceTransformer(SentenceTransformer): + """SentenceBERT with pytorch torch.nn.DataParallel""" def 
__init__(self, sbert_model: SentenceTransformer): super(DPSentenceTransformer, self).__init__() @@ -209,6 +210,7 @@ def _encode_with_auto_batch_size(batch_size, self, text, prefix): batch_size=batch_size, normalize_embeddings=self.normalize_embeddings, ) + self.batch_size = batch_size return out From abe4f8816440d72acfa6e79aeebcfdd41cb15687 Mon Sep 17 00:00:00 2001 From: akiFQC Date: Fri, 9 Aug 2024 09:41:21 +0900 Subject: [PATCH 13/16] debug --- src/jmteb/evaluators/reranking/evaluator.py | 2 +- src/jmteb/evaluators/retrieval/evaluator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jmteb/evaluators/reranking/evaluator.py b/src/jmteb/evaluators/reranking/evaluator.py index f2e136b..85b5161 100644 --- a/src/jmteb/evaluators/reranking/evaluator.py +++ b/src/jmteb/evaluators/reranking/evaluator.py @@ -151,7 +151,7 @@ def _compute_metrics( with tqdm.tqdm(total=len(query_dataset), desc="Reranking docs") as pbar: if torch.cuda.is_available(): - if dist.is_available(): + if dist.is_available() and dist.is_torchelastic_launched(): device = f"cuda:{dist.get_rank()}" else: device = "cuda" diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py index 3d91633..2d9bdd0 100644 --- a/src/jmteb/evaluators/retrieval/evaluator.py +++ b/src/jmteb/evaluators/retrieval/evaluator.py @@ -160,7 +160,7 @@ def _compute_metrics( doc_embeddings_chunk = doc_embeddings[offset : offset + self.doc_chunk_size] if torch.cuda.is_available(): - if dist.is_available(): + if dist.is_available() and dist.is_torchelastic_launched(): device = f"cuda:{dist.get_rank()}" else: device = "cuda" From 61aa4dac34206ddac223dbaeba8d5c3f5ee4b876 Mon Sep 17 00:00:00 2001 From: akiFQC Date: Fri, 9 Aug 2024 15:30:40 +0900 Subject: [PATCH 14/16] fix to review --- src/jmteb/embedders/data_parallel_sbert_embedder.py | 2 +- tests/embedders/test_dp_sbert.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py index 698fe0f..e86f661 100644 --- a/src/jmteb/embedders/data_parallel_sbert_embedder.py +++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py @@ -227,7 +227,7 @@ def _encode_with_auto_batch_size(batch_size, self, text, prefix): def _add_eos_func(self, text: str | list[str]) -> str | list[str]: try: - eos_token = getattr(self.model.savetokenizer, "eos_token") + eos_token = getattr(self.model.tokenizer, "eos_token") except AttributeError: return text diff --git a/tests/embedders/test_dp_sbert.py b/tests/embedders/test_dp_sbert.py index 241d5a6..028e240 100644 --- a/tests/embedders/test_dp_sbert.py +++ b/tests/embedders/test_dp_sbert.py @@ -9,7 +9,7 @@ OUTPUT_DIM = 128 -class TestSentenceBertEmbedder: +class TestDPSentenceBertEmbedder: def setup_class(cls): cls.model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH) From 79a6c8b26ad95a868026ce3e6990b4261767524e Mon Sep 17 00:00:00 2001 From: akiFQC Date: Fri, 9 Aug 2024 16:41:41 +0900 Subject: [PATCH 15/16] update --- src/jmteb/embedders/data_parallel_sbert_embedder.py | 4 +++- src/jmteb/evaluators/reranking/evaluator.py | 2 +- src/jmteb/evaluators/retrieval/evaluator.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py index e86f661..d12d3e3 100644 --- a/src/jmteb/embedders/data_parallel_sbert_embedder.py +++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py @@ -1,6 
+1,7 @@ from __future__ import annotations import logging +import sys from typing import Literal import numpy as np @@ -43,7 +44,8 @@ def encode( ) -> list[Tensor] | np.ndarray | Tensor: self.eval() if show_progress_bar is None: - show_progress_bar = logger.level in (logging.INFO, logging.DEBUG) + logger.remove() + logger.add(sys.stderr, level="INFO") if convert_to_tensor: convert_to_numpy = False diff --git a/src/jmteb/evaluators/reranking/evaluator.py b/src/jmteb/evaluators/reranking/evaluator.py index 85b5161..5c4ba34 100644 --- a/src/jmteb/evaluators/reranking/evaluator.py +++ b/src/jmteb/evaluators/reranking/evaluator.py @@ -151,7 +151,7 @@ def _compute_metrics( with tqdm.tqdm(total=len(query_dataset), desc="Reranking docs") as pbar: if torch.cuda.is_available(): - if dist.is_available() and dist.is_torchelastic_launched(): + if dist.is_torchelastic_launched(): device = f"cuda:{dist.get_rank()}" else: device = "cuda" diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py index 2d9bdd0..73c0981 100644 --- a/src/jmteb/evaluators/retrieval/evaluator.py +++ b/src/jmteb/evaluators/retrieval/evaluator.py @@ -160,7 +160,7 @@ def _compute_metrics( doc_embeddings_chunk = doc_embeddings[offset : offset + self.doc_chunk_size] if torch.cuda.is_available(): - if dist.is_available() and dist.is_torchelastic_launched(): + if dist.is_torchelastic_launched(): device = f"cuda:{dist.get_rank()}" else: device = "cuda" From 56f415dad0fce5f14585ee22b4cd1492f5b5eb77 Mon Sep 17 00:00:00 2001 From: akiFQC Date: Fri, 9 Aug 2024 17:20:26 +0900 Subject: [PATCH 16/16] del unused import --- src/jmteb/embedders/data_parallel_sbert_embedder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py index d12d3e3..6fb7e87 100644 --- a/src/jmteb/embedders/data_parallel_sbert_embedder.py +++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import sys from typing import Literal
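
For reference, here is a minimal usage sketch of the DataParallelSentenceBertEmbedder introduced in PATCH 06/16 and exported from jmteb.embedders in PATCH 09/16. It mirrors tests/embedders/test_dp_sbert.py: the model name is the tiny model used by those tests, while the batch size and the encoded text are illustrative assumptions rather than values taken from the patches.

    # Usage sketch (not part of the patch series). The model name comes from the tests;
    # batch_size and the encoded text are illustrative.
    import numpy as np

    from jmteb.embedders import DataParallelSentenceBertEmbedder

    embedder = DataParallelSentenceBertEmbedder(
        "prajjwal1/bert-tiny",      # tiny model used in tests/embedders/test_dp_sbert.py
        batch_size=64,              # starting batch size
        auto_find_batch_size=True,  # default: encode() is wrapped with accelerate's
                                    # find_executable_batch_size, which retries with a
                                    # smaller batch size when the current one fails
    )

    # encode() tokenizes on the driver process and runs the forward pass through
    # torch.nn.DataParallel, so each batch is split across the visible GPUs
    # (it simply runs on a single device when only one is available).
    embedding = embedder.encode("任意のテキスト")
    assert isinstance(embedding, np.ndarray)
    assert embedding.shape == (embedder.get_output_dim(),)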