From dd57e9ba0a0c0b7b445a6e8b73a277489cac1d07 Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Sat, 11 Jan 2025 17:01:04 -0500 Subject: [PATCH] Make attn_implementation configurable for huggingface models --- machine/jobs/huggingface/hugging_face_nmt_model_factory.py | 6 +++++- machine/jobs/settings.yaml | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py index 81d66c9..d1c9af9 100644 --- a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py +++ b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py @@ -65,7 +65,11 @@ def init(self) -> None: ) self._model = cast( PreTrainedModel, - AutoModelForSeq2SeqLM.from_pretrained(self._config.huggingface.parent_model_name, config=config), + AutoModelForSeq2SeqLM.from_pretrained( + self._config.huggingface.parent_model_name, + config=config, + attn_implementation=self._config.huggingface.attn_implementation, + ), ) def create_source_tokenizer_trainer(self, corpus: TextCorpus) -> Trainer: diff --git a/machine/jobs/settings.yaml b/machine/jobs/settings.yaml index 00d9517..663b942 100644 --- a/machine/jobs/settings.yaml +++ b/machine/jobs/settings.yaml @@ -27,6 +27,7 @@ default: tokenizer: add_unk_src_tokens: true add_unk_trg_tokens: true + attn_implementation: sdpa thot_mt: word_alignment_model_type: hmm tokenizer: latin