Rewrite encoding in TransformersEmbedder
lsz05 committed Jul 31, 2024
1 parent 7881bd7 commit 5127941
Showing 1 changed file with 106 additions and 58 deletions.
src/jmteb/embedders/transformers_embedder.py
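The headline change routes encoding through a model's own encoding method when one is configured. Below is a minimal usage sketch assuming the constructor arguments shown in the diff; the model id, pooling mode, keyword names, and import path are illustrative placeholders, not part of this commit.

# Hedged sketch: route encoding through the model's own method via the new
# encode_method_name / encode_method_*_argument options. Values are placeholders.
from jmteb.embedders.transformers_embedder import TransformersEmbedder  # import path assumed from the file path

embedder = TransformersEmbedder(
    model_name_or_path="some-org/some-encoder",   # placeholder model id
    batch_size=32,
    pooling_mode="mean",                          # only used when falling back to manual pooling
    encode_method_name="encode",                  # name of the model's built-in method, if any
    encode_method_text_argument="sentences",      # keyword under which the texts are passed
    encode_method_prefix_argument="prompt",       # keyword under which the prefix is passed
)
embeddings = embedder.encode(["これはテストです。", "もう一つの文です。"], prefix="query: ")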
@@ -1,14 +1,15 @@
import json
import os
from pathlib import Path
from typing import Literal

import numpy as np
import torch
import tqdm
from accelerate import PartialState
from accelerate.utils import gather_object
from loguru import logger
from sentence_transformers.models import Pooling
from tqdm.autonotebook import trange
from torch import Tensor
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer

from jmteb.embedders.base import TextEmbedder
@@ -28,11 +29,15 @@ def __init__(
pooling_mode: str | None = None,
model_kwargs: dict = {},
tokenizer_kwargs: dict = {},
encode_method_name: str | None = None,
encode_method_text_argument: str = "text",
encode_method_prefix_argument: str = "prefix",
) -> None:
model_kwargs = self._model_kwargs_parser(model_kwargs)
self.model: PreTrainedModel = AutoModel.from_pretrained(
model_name_or_path, trust_remote_code=True, **model_kwargs
)
logger.info(f"Model loaded:\n{self.model}")
self.batch_size = batch_size
if not device and torch.cuda.is_available():
self.device = "cuda"
@@ -41,6 +46,12 @@ def __init__(
self.normalize_embeddings = normalize_embeddings

self.distributed_state = PartialState() if torch.cuda.device_count() > 1 and self.device == "cuda" else None
if self.distributed_state and hasattr(self.distributed_state, "num_processes"):
assert (
self.batch_size % self.distributed_state.num_processes == 0
), f"""`batch_size` should be an integer multiple of the number of available GPUs,
but got {batch_size=}, {torch.cuda.device_count()=}. Note that `batch_size` is global batch size."""
logger.info(f"Distribution state: {self.distributed_state}")
if self.distributed_state:
self.model.to(self.distributed_state.device)
else:
@@ -79,91 +90,128 @@ def __init__(
else:
self.output_dim = self.pooling.get_sentence_embedding_dimension()

if "torch_dtype" in model_kwargs:
self.set_output_tensor()
else:
self.set_output_numpy()
# If the network has a built-in encoding method, use it instead of `_encode`
self.encode_method_name = encode_method_name
self.encode_method_text_argument = encode_method_text_argument
self.encode_method_prefix_argument = encode_method_prefix_argument

def get_output_dim(self) -> int:
return self.output_dim

def batch_encode_with_cache(
self,
text_list: list[str],
prefix: str | None = None,
cache_path: str | os.PathLike[str] | None = None,
overwrite_cache: bool = False,
batch_size: int = 64,
dtype: str = "float32",
) -> Tensor:
if cache_path is None:
logger.info("Encoding embeddings")
return self.encode(text_list, prefix=prefix).to(self._torch_dtype_parser(dtype))

if Path(cache_path).exists() and not overwrite_cache:
logger.info(f"Loading embeddings from {cache_path}")
return torch.load(cache_path)

logger.info(f"Encoding and saving embeddings to {cache_path}")
embeddings = self._batch_encode_and_save_on_disk(
text_list, cache_path, prefix=prefix, batch_size=batch_size, dtype=dtype
)
return embeddings

def _batch_encode_and_save_on_disk(
self,
text_list: list[str],
save_path: str | os.PathLike[str],
prefix: str | None = None,
batch_size: int = 64,
dtype: str = "float32",
) -> torch.Tensor:
num_samples = len(text_list)
output_dim = self.get_output_dim()
embeddings = torch.empty((num_samples, output_dim), dtype=self._torch_dtype_parser(dtype))

with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar:
for i in range(0, num_samples, batch_size):
batch = text_list[i : i + batch_size]
batch_embeddings: torch.Tensor = self.encode(batch, prefix)
embeddings[i : i + batch_size] = batch_embeddings
pbar.update(len(batch))

torch.save(embeddings, save_path)
return embeddings

def encode(
self,
text: str | list[str],
prefix: str | None = None,
show_progress_bar: bool = True,
):
dtype: Literal["float32", "float16", "bfloat16"] | None = None,
) -> torch.Tensor:
if self.distributed_state:
embeddings = self._encode_distributed(text, prefix)
else:
embeddings = self._encode(text, prefix)
return embeddings.to(dtype=dtype)

def _encode(self, text: str | list[str], prefix: str | None = None) -> torch.Tensor:
if isinstance(text, str):
text = [text]
text_was_str = True
else:
text_was_str = False

all_embeddings = []
length_sorted_idx = np.argsort([-len(t) for t in text])
text_sorted = [text[idx] for idx in length_sorted_idx]

for start_index in trange(0, len(text), self.batch_size, desc="Batches", disable=not show_progress_bar):
text_batch = text_sorted[start_index : start_index + self.batch_size]
if self.distributed_state:
batch_embeddings = self._encode_batch_distributed(text_batch, prefix)
else:
batch_embeddings = self._encode_batch(text_batch, prefix)
all_embeddings.extend(batch_embeddings)

all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
if self.add_eos:
text = self._add_eos_func(text)

if len(all_embeddings):
all_embeddings = torch.stack(all_embeddings)
else:
all_embeddings = torch.Tensor()
if self.encode_method_name and hasattr(self.model, self.encode_method_name):
# the built-in encoding method must accept the configured keyword arguments for text and prefix
sentence_embeddings = getattr(self.model, self.encode_method_name)(
**{self.encode_method_text_argument: text, self.encode_method_prefix_argument: prefix}
)
if not isinstance(sentence_embeddings, Tensor):
sentence_embeddings = Tensor(sentence_embeddings)

if text_was_str:
res = all_embeddings.view(-1)
else:
res = all_embeddings

if self.convert_to_numpy:
return res.numpy()
else:
return res

def _encode_batch(self, text: list[str], prefix: str | None = None) -> torch.Tensor:
if prefix:
text = [prefix + t for t in text]

if self.add_eos:
text = self._add_eos_func(text)
if prefix:
text = [prefix + t for t in text]

encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(
self.model.device
)
model_output = self.model(**encoded_input)
last_hidden_states = model_output["last_hidden_state"]
features = {
"input_ids": encoded_input["input_ids"],
"attention_mask": encoded_input["attention_mask"],
"token_embeddings": last_hidden_states,
}
if "token_type_ids" in encoded_input:
features["token_type_ids"] = encoded_input["token_type_ids"]

encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(self.model.device)
model_output = self.model(**encoded_input)
last_hidden_states = model_output["last_hidden_state"]
features = {
"input_ids": encoded_input["input_ids"],
"attention_mask": encoded_input["attention_mask"],
"token_embeddings": last_hidden_states,
}
if "token_type_ids" in encoded_input:
features["token_type_ids"] = encoded_input["token_type_ids"]
if prefix:
features["prompt_length"] = self.tokenizer([prefix], return_tensors="pt")["input_ids"].shape[-1] - 1

if prefix:
features["prompt_length"] = self.tokenizer([prefix], return_tensors="pt")["input_ids"].shape[-1] - 1
# TODO: feature["token_weights_sum"]

# TODO: feature["token_weights_sum"]
with torch.no_grad():
sentence_embeddings = self.pooling.forward(features)["sentence_embedding"]

with torch.no_grad():
sentence_embeddings = self.pooling.forward(features)["sentence_embedding"]
if self.truncate_dim:
sentence_embeddings = sentence_embeddings[..., : self.truncate_dim]
if self.normalize_embeddings:
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

if text_was_str:
sentence_embeddings = sentence_embeddings.view(-1)
return sentence_embeddings

def _encode_batch_distributed(self, text: list[str], prefix: str | None = None) -> torch.Tensor:
def _encode_distributed(self, text: list[str], prefix: str | None = None) -> torch.Tensor:
batch_gather = []
with self.distributed_state.split_between_processes(text) as t:
sentence_embeddings = self._encode_batch(t, prefix)
batch_gather.extend(sentence_embeddings.to("cpu"))
sentence_embeddings = self._encode(t, prefix)
batch_gather.extend(torch.Tensor(sentence_embeddings).to("cpu"))

batch_embeddings = gather_object(batch_gather)
return torch.stack(batch_embeddings)
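
The dispatch onto a built-in encoding method is done by name, with the text and prefix passed under configurable keyword arguments. The following self-contained sketch shows that pattern with a toy model class; the class and values are hypothetical, and only the lookup-and-call shape mirrors the diff.

import torch

class DummyModel:
    """Stand-in for a trust_remote_code model that ships its own encoder."""

    def my_encode(self, sentences: list[str], prompt: str | None = None) -> torch.Tensor:
        # Toy embedding: one scalar per sentence (length of the prefixed text).
        return torch.tensor([[float(len((prompt or "") + s))] for s in sentences])

model = DummyModel()
encode_method_name = "my_encode"
encode_method_text_argument = "sentences"
encode_method_prefix_argument = "prompt"

# Same lookup-and-call shape as in _encode above: getattr plus a keyword-argument dict.
method = getattr(model, encode_method_name)
embeddings = method(**{
    encode_method_text_argument: ["foo", "barbaz"],
    encode_method_prefix_argument: "query: ",
})
print(embeddings.shape)  # torch.Size([2, 1])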
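For the caching path, `batch_encode_with_cache` writes embeddings with `torch.save` on the first run and reloads them with `torch.load` afterwards. A hedged usage sketch, reusing the `embedder` constructed in the sketch near the top of this page; the path and texts are placeholders.

# Hedged sketch: encode once, persist to disk, reload on later runs.
embeddings = embedder.batch_encode_with_cache(
    ["文その一", "文その二"],
    prefix="passage: ",
    cache_path="embeddings.pt",   # written via torch.save when absent
    overwrite_cache=False,        # if the file already exists, it is loaded with torch.load
    dtype="float32",
)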
