diff --git a/machine/translation/huggingface/hugging_face_nmt_engine.py b/machine/translation/huggingface/hugging_face_nmt_engine.py
index 71ae2b1..9d43b28 100644
--- a/machine/translation/huggingface/hugging_face_nmt_engine.py
+++ b/machine/translation/huggingface/hugging_face_nmt_engine.py
@@ -4,7 +4,7 @@
 import logging
 import re
 from math import exp, prod
-from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union, cast
+from typing import Iterable, List, Optional, Sequence, Tuple, Union, cast
 
 import torch  # pyright: ignore[reportMissingImports]
 from sacremoses import MosesPunctNormalizer
@@ -12,6 +12,7 @@
     AutoConfig,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
+    M2M100Tokenizer,
     NllbTokenizer,
     NllbTokenizerFast,
     PreTrainedModel,
@@ -73,17 +74,23 @@ def __init__(
             self._pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
         else:
             additional_special_tokens = self._tokenizer.additional_special_tokens
+            if isinstance(self._tokenizer, M2M100Tokenizer):
+                src_lang_token = self._tokenizer.lang_code_to_token.get(src_lang) if src_lang is not None else None
+                tgt_lang_token = self._tokenizer.lang_code_to_token.get(tgt_lang) if tgt_lang is not None else None
+            else:
+                src_lang_token = src_lang
+                tgt_lang_token = tgt_lang
             if (
                 src_lang is not None
-                and src_lang not in cast(Any, self._tokenizer).lang_code_to_id
-                and src_lang not in additional_special_tokens
+                and src_lang_token not in self._tokenizer.added_tokens_encoder
+                and src_lang_token not in additional_special_tokens
             ):
                 raise ValueError(f"The specified model does not support the language code '{src_lang}'")
 
             if (
                 tgt_lang is not None
-                and tgt_lang not in cast(Any, self._tokenizer).lang_code_to_id
-                and tgt_lang not in additional_special_tokens
+                and tgt_lang_token not in self._tokenizer.added_tokens_encoder
+                and tgt_lang_token not in additional_special_tokens
             ):
                 raise ValueError(f"The specified model does not support the language code '{tgt_lang}'")
 
diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py
index 9b5a79c..2c80e1d 100644
--- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py
+++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py
@@ -24,6 +24,7 @@
     NllbTokenizer,
     NllbTokenizerFast,
     PreTrainedModel,
+    PreTrainedTokenizer,
     PreTrainedTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
@@ -218,30 +219,12 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
             if missing_tokens:
                 tokenizer = add_tokens(tokenizer, missing_tokens)
 
-        def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
-            if lang_code in tokenizer.lang_code_to_id:
-                return
-            tokenizer.add_special_tokens(
-                {"additional_special_tokens": tokenizer.additional_special_tokens + [lang_code]}
-            )
-            lang_id = tokenizer.convert_tokens_to_ids(lang_code)
-            tokenizer.lang_code_to_id[lang_code] = lang_id
-
-            if isinstance(tokenizer, (MBart50Tokenizer, MBartTokenizer)):
-                tokenizer.id_to_lang_code[lang_id] = lang_code
-                tokenizer.fairseq_tokens_to_ids[lang_code] = lang_id
-                tokenizer.fairseq_ids_to_tokens[lang_id] = lang_code
-            elif isinstance(tokenizer, M2M100Tokenizer):
-                tokenizer.lang_code_to_token[lang_code] = lang_code
-                tokenizer.lang_token_to_id[lang_code] = lang_id
-                tokenizer.id_to_lang_token[lang_id] = lang_code
-
         if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
             logger.info("Add new language codes as tokens")
             if self._src_lang is not None:
-                add_lang_code_to_tokenizer(tokenizer, self._src_lang)
+                _add_lang_code_to_tokenizer(tokenizer, self._src_lang)
             if self._tgt_lang is not None:
-                add_lang_code_to_tokenizer(tokenizer, self._tgt_lang)
+                _add_lang_code_to_tokenizer(tokenizer, self._tgt_lang)
 
         # We resize the embeddings only when necessary to avoid index errors.
         embedding_size = cast(Any, model.get_input_embeddings()).weight.shape[0]
@@ -413,3 +396,29 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
             if self._max_steps is None
             else ProgressStatus.from_step(state.global_step, self._max_steps)
         )
+
+
+def _add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str):
+    if isinstance(tokenizer, M2M100Tokenizer):
+        lang_token = "__" + lang_code + "__"
+    else:
+        lang_token = lang_code
+
+    if lang_token in tokenizer.added_tokens_encoder:
+        return
+
+    tokenizer.add_special_tokens(
+        {"additional_special_tokens": tokenizer.additional_special_tokens + [lang_token]}  # type: ignore
+    )
+    lang_id = cast(int, tokenizer.convert_tokens_to_ids(lang_token))
+
+    if isinstance(tokenizer, (MBart50Tokenizer, MBartTokenizer)):
+        tokenizer.lang_code_to_id[lang_code] = lang_id
+        tokenizer.id_to_lang_code[lang_id] = lang_code
+        tokenizer.fairseq_tokens_to_ids[lang_code] = lang_id
+        tokenizer.fairseq_ids_to_tokens[lang_id] = lang_code
+    elif isinstance(tokenizer, M2M100Tokenizer):
+        tokenizer.lang_code_to_id[lang_code] = lang_id
+        tokenizer.lang_code_to_token[lang_code] = lang_token
+        tokenizer.lang_token_to_id[lang_token] = lang_id
+        tokenizer.id_to_lang_token[lang_id] = lang_token
diff --git a/tests/translation/huggingface/test_hugging_face_nmt_model_trainer.py b/tests/translation/huggingface/test_hugging_face_nmt_model_trainer.py
index b843c5d..068a8c6 100644
--- a/tests/translation/huggingface/test_hugging_face_nmt_model_trainer.py
+++ b/tests/translation/huggingface/test_hugging_face_nmt_model_trainer.py
@@ -6,11 +6,23 @@
     skip("skipping Hugging Face tests on MacOS", allow_module_level=True)
 
 from tempfile import TemporaryDirectory
-
-from transformers import PreTrainedTokenizerFast, Seq2SeqTrainingArguments
+from typing import cast
+
+from transformers import (
+    M2M100Tokenizer,
+    MBart50Tokenizer,
+    MBart50TokenizerFast,
+    MBartTokenizer,
+    MBartTokenizerFast,
+    NllbTokenizer,
+    NllbTokenizerFast,
+    PreTrainedTokenizerFast,
+    Seq2SeqTrainingArguments,
+)
 
 from machine.corpora import DictionaryTextCorpus, MemoryText, TextRow
 from machine.translation.huggingface import HuggingFaceNmtEngine, HuggingFaceNmtModelTrainer
+from machine.translation.huggingface.hugging_face_nmt_model_trainer import _add_lang_code_to_tokenizer
 
 
 def test_train_non_empty_corpus() -> None:
@@ -142,10 +154,8 @@ def test_update_tokenizer_missing_char() -> None:
             "Ḻ, ḻ, Ṉ, ॽ, " + "‌ and " + "‍" + " are new characters"
         )
         finetuned_result_nochar_composite = finetuned_engine_nochar.tokenizer.encode("Ḏ is a composite character")
-        normalized_result_nochar1 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str(
-            "‌ "
-        )
-        normalized_result_nochar2 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")
+        norm_result_nochar1 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str("‌ ")
+        norm_result_nochar2 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")
 
         with HuggingFaceNmtModelTrainer(
             "hf-internal-testing/tiny-random-nllb",
@@ -167,11 +177,11 @@ def test_update_tokenizer_missing_char() -> None:
             "Ḻ, ḻ, Ṉ, ॽ, " + "‌ and " + "‍" + " are new characters"
         )
         finetuned_result_char_composite = finetuned_engine_char.tokenizer.encode("Ḏ is a composite character")
-        normalized_result_char1 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‌ ")
-        normalized_result_char2 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")
+        norm_result_char1 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‌ ")
+        norm_result_char2 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")
 
-        assert normalized_result_nochar1 != normalized_result_char1
-        assert normalized_result_nochar2 != normalized_result_char2
+        assert norm_result_nochar1 != norm_result_char1
+        assert norm_result_nochar2 != norm_result_char2
         assert finetuned_result_nochar != finetuned_result_char
         assert finetuned_result_nochar_composite != finetuned_result_char_composite
 
@@ -467,5 +477,94 @@ def test_update_tokenizer_no_missing_char() -> None:
     assert finetuned_result_nochar == finetuned_result_char
 
 
+def test_nllb_tokenizer_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(NllbTokenizer, NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M"))
+        assert "new_lang" not in tokenizer.added_tokens_encoder
+        _add_lang_code_to_tokenizer(tokenizer, "new_lang")
+        assert "new_lang" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(NllbTokenizer, NllbTokenizer.from_pretrained(temp_dir))
+        assert "new_lang" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_nllb_tokenizer_fast_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(NllbTokenizerFast, NllbTokenizerFast.from_pretrained("facebook/nllb-200-distilled-600M"))
+        assert "new_lang" not in tokenizer.added_tokens_encoder
+        _add_lang_code_to_tokenizer(tokenizer, "new_lang")
+        assert "new_lang" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(NllbTokenizerFast, NllbTokenizerFast.from_pretrained(temp_dir))
+        assert "new_lang" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_mbart_tokenizer_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(MBartTokenizer, MBartTokenizer.from_pretrained("hf-internal-testing/tiny-random-nllb"))
+        assert "nl_NS" not in tokenizer.added_tokens_encoder
+        _add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        assert "nl_NS" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(MBartTokenizer, MBartTokenizer.from_pretrained(temp_dir))
+        assert "nl_NS" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_mbart_tokenizer_fast_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(MBartTokenizerFast, MBartTokenizerFast.from_pretrained("hf-internal-testing/tiny-random-nllb"))
+        assert "nl_NS" not in tokenizer.added_tokens_encoder
+        _add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        assert "nl_NS" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(MBartTokenizerFast, MBartTokenizerFast.from_pretrained(temp_dir))
+        assert "nl_NS" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_mbart_50_tokenizer_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(MBart50Tokenizer, MBart50Tokenizer.from_pretrained("hf-internal-testing/tiny-random-mbart50"))
+        assert "nl_NS" not in tokenizer.added_tokens_encoder
+        _add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        assert "nl_NS" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(MBart50Tokenizer, MBart50Tokenizer.from_pretrained(temp_dir))
+        assert "nl_NS" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_mbart_50_tokenizer_fast_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(
+            MBart50TokenizerFast, MBart50TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-mbart50")
+        )
+        assert "nl_NS" not in tokenizer.added_tokens_encoder
+        _add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        assert "nl_NS" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(MBart50TokenizerFast, MBart50TokenizerFast.from_pretrained(temp_dir))
+        assert "nl_NS" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_m2m_100_tokenizer_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(M2M100Tokenizer, M2M100Tokenizer.from_pretrained("stas/tiny-m2m_100"))
+        assert "nc" not in tokenizer.lang_code_to_id
+        assert "__nc__" not in tokenizer.added_tokens_encoder
+        _add_lang_code_to_tokenizer(tokenizer, "nc")
+        assert "nc" in tokenizer.lang_code_to_id
+        assert "__nc__" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(M2M100Tokenizer, M2M100Tokenizer.from_pretrained(temp_dir))
+        assert "nc" in tokenizer.lang_code_to_id
+        assert "__nc__" in new_tokenizer.added_tokens_encoder
+        return
+
+
 def _row(row_ref: int, text: str) -> TextRow:
     return TextRow("text1", row_ref, segment=[text])