
Commit

Merge pull request #53 from sbintuitions/improve/batch_size_setting
Improve: batch size setting and multi GPU inference with SentenceTransformers+DP
akiFQC authored Aug 13, 2024
2 parents dc96352 + 56f415d commit ca71155
Showing 6 changed files with 287 additions and 6 deletions.
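The headline change is a new DataParallelSentenceBertEmbedder that wraps a SentenceTransformer in torch.nn.DataParallel, so a single encode() call fans batches out across all visible GPUs and retries with a smaller batch size on CUDA OOM. A minimal usage sketch (the model name is the one used in this PR's tests; the second input string is illustrative):

    from jmteb.embedders import DataParallelSentenceBertEmbedder

    embedder = DataParallelSentenceBertEmbedder(
        "prajjwal1/bert-tiny",  # model used in the new tests; 128-dim output
        batch_size=64,          # starting batch size; halved on CUDA OOM while auto_find_batch_size=True
    )
    embeddings = embedder.encode(["任意のテキスト", "別のテキスト"])  # "any text", "another text"
    print(embeddings.shape)  # (2, 128)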
3 changes: 3 additions & 0 deletions src/jmteb/embedders/__init__.py
@@ -1,4 +1,7 @@
 from jmteb.embedders.base import TextEmbedder
+from jmteb.embedders.data_parallel_sbert_embedder import (
+    DataParallelSentenceBertEmbedder,
+)
 from jmteb.embedders.openai_embedder import OpenAIEmbedder
 from jmteb.embedders.sbert_embedder import SentenceBertEmbedder
 from jmteb.embedders.transformers_embedder import TransformersEmbedder
7 changes: 3 additions & 4 deletions src/jmteb/embedders/base.py
@@ -17,6 +17,7 @@ class TextEmbedder(ABC):

     convert_to_tensor: bool
     convert_to_numpy: bool
+    _chunk_size: int = 262144  # 2^18

     def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray | torch.Tensor:
         """Convert a text string or a list of texts to embedding.
@@ -40,7 +41,7 @@ def _batch_encode_and_save_on_disk(
         text_list: list[str],
         save_path: str | PathLike[str],
         prefix: str | None = None,
-        batch_size: int = 64,
+        batch_size: int = 262144,
         dtype: str = "float32",
     ) -> np.memmap | torch.Tensor:
         """
@@ -81,7 +82,6 @@ def batch_encode_with_cache(
         prefix: str | None = None,
         cache_path: str | PathLike[str] | None = None,
         overwrite_cache: bool = False,
-        batch_size: int = 64,
         dtype: str = "float32",
     ) -> np.ndarray | torch.Tensor:
         """
@@ -92,7 +92,6 @@
             prefix (str, optional): the prefix to use for encoding. Defaults to None.
             cache_path (str, optional): path to save the embeddings. Defaults to None.
             overwrite_cache (bool, optional): whether to overwrite the cache. Defaults to False.
-            batch_size (int): batch size. Defaults to 64.
             dtype (str, optional): data type. Defaults to "float32".
         """

@@ -106,7 +105,7 @@

         logger.info(f"Encoding and saving embeddings to {cache_path}")
         embeddings = self._batch_encode_and_save_on_disk(
-            text_list, cache_path, prefix=prefix, batch_size=batch_size, dtype=dtype
+            text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype
         )
         return embeddings

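With this change, callers of batch_encode_with_cache no longer pass a batch size; embeddings are flushed to the on-disk cache in fixed chunks of _chunk_size = 262144 (2^18) texts, while the true per-forward batch size stays a property of the concrete embedder. A rough sketch of the chunked write this implies — the helper name and memmap layout are illustrative, not taken from the diff:

    import numpy as np

    def encode_to_memmap_sketch(embedder, text_list, save_path, dtype="float32"):
        # Illustrative only: stream embeddings to disk chunk by chunk so the
        # full (n_texts, dim) matrix never has to sit in RAM at once.
        dim = embedder.get_output_dim()
        out = np.memmap(save_path, dtype=dtype, mode="w+", shape=(len(text_list), dim))
        chunk = embedder._chunk_size  # 262144 == 2**18
        for offset in range(0, len(text_list), chunk):
            batch = text_list[offset : offset + chunk]
            out[offset : offset + len(batch)] = embedder.encode(batch)
        out.flush()
        return out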
241 changes: 241 additions & 0 deletions src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -0,0 +1,241 @@
from __future__ import annotations

import sys
from typing import Literal

import numpy as np
import torch
from accelerate.utils import find_executable_batch_size
from loguru import logger
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
from sentence_transformers.util import truncate_embeddings
from torch import Tensor
from tqdm.autonotebook import trange

from jmteb.embedders.base import TextEmbedder


class DPSentenceTransformer(SentenceTransformer):
    """SentenceBERT with PyTorch torch.nn.DataParallel"""

    def __init__(self, sbert_model: SentenceTransformer):
        super(DPSentenceTransformer, self).__init__()
        self.dp_model = torch.nn.DataParallel(sbert_model)
        self.sbert = self.dp_model.module

    def forward(self, *args, **kwargs):
        return self.dp_model.forward(*args, **kwargs)

    def encode(
        self,
        sentences: str | list[str],
        prompt_name: str | None = None,
        prompt: str | None = None,
        batch_size: int = 64,
        show_progress_bar: bool | None = None,
        output_value: Literal["sentence_embedding", "token_embeddings"] | None = "sentence_embedding",
        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        device: str | None = None,
        normalize_embeddings: bool = False,
    ) -> list[Tensor] | np.ndarray | Tensor:
        self.eval()
        if show_progress_bar is None:
            logger.remove()
            logger.add(sys.stderr, level="INFO")

        if convert_to_tensor:
            convert_to_numpy = False

        if output_value != "sentence_embedding":
            convert_to_tensor = False
            convert_to_numpy = False

        input_was_string = False
        if isinstance(sentences, str) or not hasattr(
            sentences, "__len__"
        ):  # Cast an individual sentence to a list with length 1
            sentences = [sentences]
            input_was_string = True

        if prompt is None:
            if prompt_name is not None:
                try:
                    prompt = self.sbert.prompts[prompt_name]
                except KeyError:
                    raise ValueError(
                        f"Prompt name '{prompt_name}' not found in the configured "
                        f"prompts dictionary with keys {list(self.sbert.prompts.keys())!r}."
                    )
            elif self.sbert.default_prompt_name is not None:
                prompt = self.sbert.prompts.get(self.sbert.default_prompt_name, None)
        else:
            if prompt_name is not None:
                logger.warning(
                    "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
                    "Ignoring the `prompt_name` in favor of `prompt`."
                )

        extra_features = {}
        if prompt is not None:
            sentences = [prompt + sentence for sentence in sentences]

            # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
            # Tracking the prompt length allows us to remove the prompt during pooling
            tokenized_prompt = self.sbert.tokenize([prompt])
            if "input_ids" in tokenized_prompt:
                extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1

        all_embeddings = []
        length_sorted_idx = np.argsort([-self.sbert._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
            sentences_batch = sentences_sorted[start_index : start_index + batch_size]
            features = self.sbert.tokenize(sentences_batch)
            features.update(extra_features)

            with torch.no_grad():
                out_features = self.forward(features)

                out_features["sentence_embedding"] = truncate_embeddings(
                    out_features["sentence_embedding"], self.sbert.truncate_dim
                )

                if output_value == "token_embeddings":
                    embeddings = []
                    for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
                        last_mask_id = len(attention) - 1
                        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
                            last_mask_id -= 1

                        embeddings.append(token_emb[0 : last_mask_id + 1])
                elif output_value is None:  # Return all outputs
                    embeddings = []
                    for sent_idx in range(len(out_features["sentence_embedding"])):
                        row = {name: out_features[name][sent_idx] for name in out_features}
                        embeddings.append(row)
                else:  # Sentence embeddings
                    embeddings = out_features[output_value]
                    embeddings = embeddings.detach()
                    if normalize_embeddings:
                        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

                    # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
                    if convert_to_numpy:
                        embeddings = embeddings.cpu()

                all_embeddings.extend(embeddings)

        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

        if precision and precision != "float32":
            all_embeddings = quantize_embeddings(all_embeddings, precision=precision)

        if convert_to_tensor:
            if len(all_embeddings):
                if isinstance(all_embeddings, np.ndarray):
                    all_embeddings = torch.from_numpy(all_embeddings)
                else:
                    all_embeddings = torch.stack(all_embeddings)
            else:
                all_embeddings = torch.Tensor()
        elif convert_to_numpy:
            if not isinstance(all_embeddings, np.ndarray):
                if all_embeddings and all_embeddings[0].dtype == torch.bfloat16:
                    all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings])
                else:
                    all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
        elif isinstance(all_embeddings, np.ndarray):
            all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings]

        if input_was_string:
            all_embeddings = all_embeddings[0]

        return all_embeddings


class DataParallelSentenceBertEmbedder(TextEmbedder):
    """SentenceBERT embedder with PyTorch DataParallel"""

    def __init__(
        self,
        model_name_or_path: str,
        batch_size: int = 64,
        normalize_embeddings: bool = False,
        max_seq_length: int | None = None,
        add_eos: bool = False,
        truncate_dim: int | None = None,
        model_kwargs: dict | None = None,
        tokenizer_kwargs: dict | None = None,
        auto_find_batch_size: bool = True,
    ) -> None:
        model_kwargs = self._model_kwargs_parser(model_kwargs)
        model = SentenceTransformer(
            model_name_or_path,
            trust_remote_code=True,
            truncate_dim=truncate_dim,
            model_kwargs=model_kwargs,  # https://github.com/UKPLab/sentence-transformers/blob/84f69fee6dcde023f46a8807e89bc99a7700ba82/sentence_transformers/SentenceTransformer.py#L81-L105  # noqa: E501
            tokenizer_kwargs=tokenizer_kwargs,
        )
        self.dp_model = DPSentenceTransformer(sbert_model=model)
        self.model = self.dp_model.sbert
        if max_seq_length:
            self.model.max_seq_length = max_seq_length
        self.initial_batch_size = batch_size
        self.batch_size = int(self.initial_batch_size)
        self.normalize_embeddings = normalize_embeddings
        self.max_seq_length = getattr(self.model, "max_seq_length", None)
        self.add_eos = add_eos
        self.auto_find_batch_size = auto_find_batch_size

        if "torch_dtype" in model_kwargs:
            self.set_output_tensor()
        else:
            self.set_output_numpy()

    def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
        if self.add_eos:
            text = self._add_eos_func(text)
        if self.auto_find_batch_size:
            # Wrap the encode call so accelerate can retry with a halved
            # batch size whenever a CUDA OOM is raised.
            @find_executable_batch_size(starting_batch_size=self.batch_size)
            def _encode_with_auto_batch_size(batch_size, self, text, prefix):
                out = self.dp_model.encode(
                    text,
                    prompt=prefix,
                    convert_to_numpy=self.convert_to_numpy,
                    convert_to_tensor=self.convert_to_tensor,
                    batch_size=batch_size,
                    normalize_embeddings=self.normalize_embeddings,
                )

                # Remember the batch size that actually fit, so later calls
                # start from it instead of the initial value.
                self.batch_size = batch_size
                return out

            return _encode_with_auto_batch_size(self, text, prefix)
        else:
            return self.dp_model.encode(
                text,
                prompt=prefix,
                convert_to_numpy=self.convert_to_numpy,
                convert_to_tensor=self.convert_to_tensor,
                batch_size=self.batch_size,
                normalize_embeddings=self.normalize_embeddings,
            )

    def _add_eos_func(self, text: str | list[str]) -> str | list[str]:
        try:
            eos_token = getattr(self.model.tokenizer, "eos_token")
        except AttributeError:
            return text

        if isinstance(text, str):
            return text + eos_token
        elif isinstance(text, list):
            return [t + eos_token for t in text]

    def get_output_dim(self) -> int:
        return self.model.get_sentence_embedding_dimension()
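The auto_find_batch_size path leans on accelerate.utils.find_executable_batch_size: the decorator first calls the wrapped function with starting_batch_size and, whenever it sees a CUDA out-of-memory error, frees cached memory and retries with the batch size halved. A self-contained toy showing that contract (the workload is simulated, not part of this PR):

    import torch
    from accelerate.utils import find_executable_batch_size

    @find_executable_batch_size(starting_batch_size=1024)
    def run(batch_size: int) -> int:
        # Simulated workload: pretend anything above 256 exhausts GPU memory,
        # so accelerate retries at 512 (still OOM), then succeeds at 256.
        if batch_size > 256:
            raise torch.cuda.OutOfMemoryError("CUDA out of memory (simulated)")
        return batch_size

    print(run())  # 256 — the first batch size that completes without OOM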
2 changes: 1 addition & 1 deletion src/jmteb/evaluators/reranking/evaluator.py
@@ -151,7 +151,7 @@ def _compute_metrics(

         with tqdm.tqdm(total=len(query_dataset), desc="Reranking docs") as pbar:
             if torch.cuda.is_available():
-                if dist.is_available():
+                if dist.is_torchelastic_launched():
                     device = f"cuda:{dist.get_rank()}"
                 else:
                     device = "cuda"
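This guard swap (applied identically in the retrieval evaluator below) fixes device selection for single-process runs: dist.is_available() only reports that PyTorch was built with distributed support and is True even for a plain python invocation, where dist.get_rank() would then fail because no process group has been initialized. dist.is_torchelastic_launched() instead reports whether the process was actually started by torchrun/torchelastic (it checks the TORCHELASTIC_RUN_ID environment variable), so get_rank() is safe on that branch:

    import torch.distributed as dist

    # Per-rank device selection, as patched above:
    device = f"cuda:{dist.get_rank()}" if dist.is_torchelastic_launched() else "cuda"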
2 changes: 1 addition & 1 deletion src/jmteb/evaluators/retrieval/evaluator.py
@@ -160,7 +160,7 @@ def _compute_metrics(
             doc_embeddings_chunk = doc_embeddings[offset : offset + self.doc_chunk_size]

             if torch.cuda.is_available():
-                if dist.is_available():
+                if dist.is_torchelastic_launched():
                     device = f"cuda:{dist.get_rank()}"
                 else:
                     device = "cuda"
38 changes: 38 additions & 0 deletions tests/embedders/test_dp_sbert.py
@@ -0,0 +1,38 @@
import numpy as np
import torch

from jmteb.embedders.data_parallel_sbert_embedder import (
    DataParallelSentenceBertEmbedder,
)

MODEL_NAME_OR_PATH = "prajjwal1/bert-tiny"
OUTPUT_DIM = 128


class TestDPSentenceBertEmbedder:
    def setup_class(cls):
        cls.model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH)

    def test_encode(self):
        embeddings = self.model.encode("任意のテキスト")
        assert isinstance(embeddings, np.ndarray)
        assert embeddings.shape == (OUTPUT_DIM,)

    def test_get_output_dim(self):
        assert self.model.get_output_dim() == OUTPUT_DIM

    def test_tokenizer_kwargs(self):
        assert self.model.model.tokenizer.sep_token == "[SEP]"
        model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH, tokenizer_kwargs={"sep_token": "<sep>"})
        assert model.model.tokenizer.sep_token == "<sep>"

    def test_model_kwargs(self):
        model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.float16})
        assert model.convert_to_tensor
        assert model.encode("任意のテキスト").dtype is torch.float16

    def test_bf16(self):
        # As numpy doesn't support native bfloat16, add a test case for bf16
        model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.bfloat16})
        assert model.convert_to_tensor
        assert model.encode("任意のテキスト").dtype is torch.bfloat16
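The bf16 case gets its own test because NumPy has no native bfloat16 dtype: a bfloat16 tensor cannot be converted to an ndarray directly, which is why the embedder switches to tensor output whenever torch_dtype is set, and why DPSentenceTransformer.encode upcasts through float() when NumPy output is requested. A quick demonstration of the constraint:

    import torch

    t = torch.zeros(4, dtype=torch.bfloat16)
    # t.numpy() raises TypeError (NumPy has no bfloat16), so upcast first:
    arr = t.float().numpy()
    print(arr.dtype)  # float32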
