
Cross-Encoder: use configurable batch size.
The default is 32. It can be overridden with the embedding batch_size setting in config or with the EMBEDDING_BATCH_SIZE env var.
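
For example, a minimal sketch of the env var override, assuming the variable is set before caikit first loads its configuration (the value 64 is illustrative; the variable-to-key mapping is as stated above):

    import os

    # EMBEDDING_BATCH_SIZE overrides the embedding.batch_size config value;
    # it must be in the environment before the config is first read.
    os.environ["EMBEDDING_BATCH_SIZE"] = "64"

The config-file equivalent is a batch_size entry under the embedding section.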

Signed-off-by: Mark Sturdevant <[email protected]>
markstur committed Sep 12, 2024
commit 4e9c5aa, parent ac46993
Showing 1 changed file with 9 additions and 1 deletion.
caikit_nlp/modules/text_embedding/crossencoder.py

@@ -88,6 +88,14 @@ def __init__(
         # model_max_length attribute availability might(?) vary by model/tokenizer
         self.model_max_length = getattr(model.tokenizer, "model_max_length", None)
 
+        # Read config/env settings that are needed at run_* time.
+        embedding_cfg = get_config().get("embedding", {})
+
+        self.batch_size = embedding_cfg.get("batch_size", 32)
+        error.type_check("<NLP83501588E>", int, EMBEDDING_BATCH_SIZE=self.batch_size)
+        if self.batch_size <= 0:
+            self.batch_size = 32  # 0 or negative, use the default.
+
     @classmethod
     def load(
         cls, model_path: Union[str, ModuleConfig], *args, **kwargs
@@ -324,7 +332,7 @@ def get_text(doc):
             documents=doc_texts,
             top_k=top_n,
             return_documents=False,
-            batch_size=32,
+            batch_size=self.batch_size,
             convert_to_numpy=True,
             truncate_input_tokens=truncate_input_tokens,
         )
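The resolution logic added above can be summarized in a standalone sketch (resolve_batch_size is a hypothetical helper written for illustration, not part of the commit):

    # Mirrors the commit's order of precedence: configured value, else the
    # default of 32; non-positive values also fall back to the default.
    def resolve_batch_size(embedding_cfg: dict) -> int:
        batch_size = embedding_cfg.get("batch_size", 32)
        if not isinstance(batch_size, int):
            raise TypeError("embedding.batch_size must be an int")
        return batch_size if batch_size > 0 else 32

    assert resolve_batch_size({}) == 32                   # default
    assert resolve_batch_size({"batch_size": 64}) == 64   # config override
    assert resolve_batch_size({"batch_size": -1}) == 32   # clamped to default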
