diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..7c94708
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,27 @@
+on:
+  pull_request:
+    types: [opened, synchronize, reopened, ready_for_review]
+  push:
+    branches: [main]
+
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Check out GitHub repo
+        uses: actions/checkout@v4
+        with:
+          lfs: true
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.9
+
+      - name: Install package
+        run: pip install '.[test]'
+
+      - name: Run tests
+        run: pytest tests
diff --git a/README.md b/README.md
index 7a3f85f..3bf3e28 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,11 @@ The package implements metrics designed to work well with lyrics formatted accor
 - Line breaks
 - Section breaks (i.e. double line breaks)
 
+Under the hood, the text is pre-processed using the [`sacremoses`](https://github.com/hplt-project/sacremoses) tokenizer and punctuation normalizer.
+Note that apostrophes and single quotes are never treated as quotation marks, but as part of a word, marking an elision or a contraction.
+For writing systems that do not use spaces to separate words (Chinese, Japanese, Thai, Lao, Burmese, …), each character is treated as a separate word, as per [Radford et al. (2022)](https://arxiv.org/abs/2212.04356).
+See the [test cases](./tests/test_tokenizer.py) for examples of how different languages are tokenized.
+
 ## Usage
 
 Install the package with `pip install alt-eval`.
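Note (not part of the patch): to make the README addition above concrete, here is a small usage sketch; the inputs and expected outputs are copied from the new test cases in `tests/test_tokenizer.py` further down in this diff.

```python
from alt_eval.tokenizer import LyricsTokenizer

tokenizer = LyricsTokenizer()

# Apostrophes mark elisions/contractions and stay attached to a word;
# they are never interpreted as quotation marks.
print([t.text for t in tokenizer("I ain't got nothin' but the blues", language="en")])
# ["I", "ain", "'t", "got", "nothin'", "but", "the", "blues"]

# Writing systems without spaces between words are split into single characters.
print([t.text for t in tokenizer("私は日本語を話せません(ラララ)", language="ja")])
# ["私", "は", "日", "本", "語", "を", "話", "せ", "ま", "せ", "ん", "(", "ラ", "ラ", "ラ", ")"]
```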
"Hebr", "Hira", "Hluw", "Hmng", "Hmnp", "Hung", "Ital", "Java", "Kali", "Kana", + "Kawi", "Khar", "Khmr", "Khoj", "Kits", "Knda", "Kthi", "Lana", "Laoo", "Latn", "Lepc", "Limb", + "Lina", "Linb", "Lisu", "Lyci", "Lydi", "Mahj", "Maka", "Mand", "Mani", "Marc", "Medf", "Mend", + "Merc", "Mero", "Mlym", "Modi", "Mong", "Mroo", "Mtei", "Mult", "Mymr", "Nagm", "Nand", "Narb", + "Nbat", "Newa", "Nkoo", "Nshu", "Ogam", "Olck", "Orkh", "Orya", "Osge", "Osma", "Ougr", "Palm", + "Pauc", "Perm", "Phag", "Phli", "Phlp", "Phnx", "Plrd", "Prti", "Rjng", "Rohg", "Runr", "Samr", + "Sarb", "Saur", "Sgnw", "Shaw", "Shrd", "Sidd", "Sind", "Sinh", "Sogd", "Sogo", "Sora", "Soyo", + "Sund", "Sylo", "Syrc", "Tagb", "Takr", "Tale", "Talu", "Taml", "Tang", "Tavt", "Telu", "Tfng", + "Tglg", "Thaa", "Thai", "Tibt", "Tirh", "Tnsa", "Toto", "Ugar", "Vaii", "Vith", "Wara", "Wcho", + "Xpeo", "Xsux", "Yezi", "Yiii", "Zanb", +] +UNICODE_SCRIPTS_NO_SPACES = [ + "Egyp", "Hani", "Hira", "Hluw", "Lina", "Linb", "Xsux", "Kana", "Khmr", "Laoo", "Mymr", "Phag", + "Lana", "Thai", "Tibt", +] +# fmt: on + + class LyricsTokenizer: """A Moses-based tokenizer for lyrics. @@ -80,24 +105,34 @@ def __init__(self) -> None: r"(?P)(?P's)\b|\b(?Pwie|für)(?P'n)\b", flags=re.IGNORECASE ) + # A regex to match the boundary between two letters from two different scripts, or between a + # number and a letter from a script that does not use spaces between words. + self._different_scripts_re = re.compile( + r"|".join( + [rf"(?<=[\p{{L}}&&\p{{{s}}}])(?=[\p{{L}}--\p{{{s}}}])" for s in UNICODE_SCRIPTS] + + [rf"(?<=[\p{{L}}&&\p{{{s}}}])(?=[0-9])" for s in UNICODE_SCRIPTS_NO_SPACES] + + [rf"(?<=[0-9])(?=[\p{{L}}&&\p{{{s}}}])" for s in UNICODE_SCRIPTS_NO_SPACES] + ), + flags=re.VERSION1, + ) + + # A regex to match a character in a script that does not use spaces between words. + self._no_spaces_re = re.compile( + r"(" + r"|".join([rf"\p{{{s}}}" for s in UNICODE_SCRIPTS_NO_SPACES]) + r")", + flags=re.VERSION1, + ) + def __call__(self, text: str, language: str = "en") -> list[Token]: """ Tokenize the given text. Args: text: A string to tokenize. - language: A language code supported by `sacremoses`: either an ISO 639-1 language code, - or "cjk" for Chinese, Japanese and Korean. + language: An ISO 639-1 language code. Returns: A list of `Token` objects. """ - if language not in self._tokenizers: - self._tokenizers[language] = MosesTokenizer(lang=language) - self._punct_normalizers[language] = MosesPunctNormalizer(lang=language) - tokenizer = self._tokenizers[language] - punct_normalizer = self._punct_normalizers[language] - text = self._non_text_re.sub(" ", text) text = unicodedata.normalize("NFC", text) text = text.rstrip("\n") @@ -111,42 +146,63 @@ def __call__(self, text: str, language: str = "en") -> list[Token]: if line.count("\n") >= 2: result.append("\n\n") elif line.strip(): - # Ensure the line ends with punctuation to make the tokenizer treat it as - # a sentence - remove_last = False - if not self._end_punctuation_re.search(line): - remove_last = True - line += " ." 
diff --git a/src/alt_eval/tokenizer.py b/src/alt_eval/tokenizer.py
index 67c93a2..4c82364 100644
--- a/src/alt_eval/tokenizer.py
+++ b/src/alt_eval/tokenizer.py
@@ -1,5 +1,6 @@
 import copy
 from dataclasses import dataclass, field
+import functools
 import unicodedata
 
 import regex as re
@@ -59,6 +60,30 @@ def tokens_as_words(tokens: list[Token]) -> list[Token]:
     return result
 
 
+# fmt: off
+UNICODE_SCRIPTS = [
+    "Adlm", "Aghb", "Ahom", "Arab", "Armi", "Armn", "Avst", "Bali", "Bamu", "Bass", "Batk", "Beng",
+    "Bhks", "Bopo", "Brah", "Brai", "Bugi", "Buhd", "Cakm", "Cans", "Cari", "Cham", "Cher", "Chrs",
+    "Copt", "Cpmn", "Cprt", "Cyrl", "Deva", "Diak", "Dogr", "Dsrt", "Dupl", "Egyp", "Elba", "Elym",
+    "Ethi", "Geor", "Glag", "Gong", "Gonm", "Goth", "Gran", "Grek", "Gujr", "Guru", "Hang", "Hani",
+    "Hano", "Hatr", "Hebr", "Hira", "Hluw", "Hmng", "Hmnp", "Hung", "Ital", "Java", "Kali", "Kana",
+    "Kawi", "Khar", "Khmr", "Khoj", "Kits", "Knda", "Kthi", "Lana", "Laoo", "Latn", "Lepc", "Limb",
+    "Lina", "Linb", "Lisu", "Lyci", "Lydi", "Mahj", "Maka", "Mand", "Mani", "Marc", "Medf", "Mend",
+    "Merc", "Mero", "Mlym", "Modi", "Mong", "Mroo", "Mtei", "Mult", "Mymr", "Nagm", "Nand", "Narb",
+    "Nbat", "Newa", "Nkoo", "Nshu", "Ogam", "Olck", "Orkh", "Orya", "Osge", "Osma", "Ougr", "Palm",
+    "Pauc", "Perm", "Phag", "Phli", "Phlp", "Phnx", "Plrd", "Prti", "Rjng", "Rohg", "Runr", "Samr",
+    "Sarb", "Saur", "Sgnw", "Shaw", "Shrd", "Sidd", "Sind", "Sinh", "Sogd", "Sogo", "Sora", "Soyo",
+    "Sund", "Sylo", "Syrc", "Tagb", "Takr", "Tale", "Talu", "Taml", "Tang", "Tavt", "Telu", "Tfng",
+    "Tglg", "Thaa", "Thai", "Tibt", "Tirh", "Tnsa", "Toto", "Ugar", "Vaii", "Vith", "Wara", "Wcho",
+    "Xpeo", "Xsux", "Yezi", "Yiii", "Zanb",
+]
+UNICODE_SCRIPTS_NO_SPACES = [
+    "Egyp", "Hani", "Hira", "Hluw", "Lina", "Linb", "Xsux", "Kana", "Khmr", "Laoo", "Mymr", "Phag",
+    "Lana", "Thai", "Tibt",
+]
+# fmt: on
+
+
 class LyricsTokenizer:
     """A Moses-based tokenizer for lyrics.
 
@@ -80,24 +105,34 @@ def __init__(self) -> None:
             r"(?P<a>)(?P<b>'s)\b|\b(?P<a>wie|für)(?P<b>'n)\b", flags=re.IGNORECASE
         )
 
+        # A regex to match the boundary between two letters from two different scripts, or between a
+        # number and a letter from a script that does not use spaces between words.
+        self._different_scripts_re = re.compile(
+            r"|".join(
+                [rf"(?<=[\p{{L}}&&\p{{{s}}}])(?=[\p{{L}}--\p{{{s}}}])" for s in UNICODE_SCRIPTS]
+                + [rf"(?<=[\p{{L}}&&\p{{{s}}}])(?=[0-9])" for s in UNICODE_SCRIPTS_NO_SPACES]
+                + [rf"(?<=[0-9])(?=[\p{{L}}&&\p{{{s}}}])" for s in UNICODE_SCRIPTS_NO_SPACES]
+            ),
+            flags=re.VERSION1,
+        )
+
+        # A regex to match a character in a script that does not use spaces between words.
+        self._no_spaces_re = re.compile(
+            r"(" + r"|".join([rf"\p{{{s}}}" for s in UNICODE_SCRIPTS_NO_SPACES]) + r")",
+            flags=re.VERSION1,
+        )
+
     def __call__(self, text: str, language: str = "en") -> list[Token]:
         """
         Tokenize the given text.
 
         Args:
             text: A string to tokenize.
-            language: A language code supported by `sacremoses`: either an ISO 639-1 language code,
-                or "cjk" for Chinese, Japanese and Korean.
+            language: An ISO 639-1 language code.
 
         Returns:
             A list of `Token` objects.
         """
-        if language not in self._tokenizers:
-            self._tokenizers[language] = MosesTokenizer(lang=language)
-            self._punct_normalizers[language] = MosesPunctNormalizer(lang=language)
-        tokenizer = self._tokenizers[language]
-        punct_normalizer = self._punct_normalizers[language]
-
         text = self._non_text_re.sub(" ", text)
         text = unicodedata.normalize("NFC", text)
         text = text.rstrip("\n")
@@ -111,42 +146,63 @@ def __call__(self, text: str, language: str = "en") -> list[Token]:
             if line.count("\n") >= 2:
                 result.append("\n\n")
             elif line.strip():
-                # Ensure the line ends with punctuation to make the tokenizer treat it as
-                # a sentence
-                remove_last = False
-                if not self._end_punctuation_re.search(line):
-                    remove_last = True
-                    line += " ."
-
-                line = punct_normalizer.normalize(line)
-
-                if language in ["en", "fr", "it"]:
-                    # Protect apostrophes at word boundaries to prevent the tokenizer from
-                    # interpreting them as quotes
-                    line = self._word_boundary_apos_re.sub("@@apos@@", line)
-                else:
-                    # For languages where the tokenizer doesn't handle apostrophes within words,
-                    # protect all apostrophes
-                    line = line.replace("'", "@@apos@@")
-
-                line = tokenizer.tokenize(
-                    line.strip(),
-                    return_str=True,
-                    escape=False,
-                    aggressive_dash_splits=True,
-                    protected_patterns=[r"\*+", r"@@apos@@"],
-                )
-
-                if remove_last:
-                    assert line.endswith(" ."), line
-                    line = line[:-2]
-
-                # Post-process apostrophes
-                line = line.replace("@@apos@@", "'")
-                if language == "de":
-                    # Split contractions
-                    line = self._contraction_de_re.sub(r"\g<a> \g<b>", line)
+                # Tokenize using sacremoses
+                line = self._tokenize_moses(line, language)
+
+                # In languages that do not use spaces to separate words, treat each
+                # character as a separate word
+                line = self._no_spaces_re.sub(r" \1 ", line)
+
+                # Insert spaces between characters from different scripts
+                line = self._different_scripts_re.sub(" ", line)
 
                 result.extend(line.strip().split())
 
         return to_rich_tokens(result)
+
+    @functools.lru_cache(maxsize=200)
+    def _get_moses_tokenizer(self, language: str) -> MosesTokenizer:
+        return MosesTokenizer(lang=language)
+
+    @functools.lru_cache(maxsize=200)
+    def _get_moses_punct_normalizer(self, language: str) -> MosesPunctNormalizer:
+        return MosesPunctNormalizer(lang=language)
+
+    def _tokenize_moses(self, line: str, language: str) -> str:
+        # Ensure the line ends with punctuation to make the tokenizer treat it as
+        # a sentence
+        remove_last = False
+        if not self._end_punctuation_re.search(line):
+            remove_last = True
+            line += " ."
+
+        line = self._get_moses_punct_normalizer(language).normalize(line)
+
+        if language in ["en", "fr", "it"]:
+            # Protect apostrophes at word boundaries to prevent the tokenizer from
+            # interpreting them as quotes
+            line = self._word_boundary_apos_re.sub("@@apos@@", line)
+        else:
+            # For languages where the tokenizer doesn't handle apostrophes within words,
+            # protect all apostrophes
+            line = line.replace("'", "@@apos@@")
+
+        line = self._get_moses_tokenizer(language).tokenize(
+            line.strip(),
+            return_str=True,
+            escape=False,
+            aggressive_dash_splits=True,
+            protected_patterns=[r"\*+", r"@@apos@@"],
+        )
+
+        if remove_last:
+            assert line.endswith(" ."), line
+            line = line[:-2]
+
+        # Post-process apostrophes
+        line = line.replace("@@apos@@", "'")
+        if language == "de":
+            # Split contractions
+            line = self._contraction_de_re.sub(r"\g<a> \g<b>", line)
+
+        return line
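Note (not part of the patch): a reduced sketch of the `_different_scripts_re` idea above, hard-coding a single script (`Hani`) instead of building the full alternation over `UNICODE_SCRIPTS`. In the real pattern, iterating over every script covers both directions of each boundary, so the explicit reverse lookaround below is only needed because this sketch handles one script.

```python
import regex as re

# Zero-width boundary between a Han letter and a letter from any other script,
# in both directions.
boundary_re = re.compile(
    r"(?<=[\p{L}&&\p{Hani}])(?=[\p{L}--\p{Hani}])"
    r"|(?<=[\p{L}--\p{Hani}])(?=[\p{L}&&\p{Hani}])",
    flags=re.VERSION1,  # VERSION1 enables character-class set operations (&&, --)
)

print(boundary_re.sub(" ", "中文mixed中文"))  # -> "中文 mixed 中文"
```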
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
new file mode 100644
index 0000000..c39eb61
--- /dev/null
+++ b/tests/test_tokenizer.py
@@ -0,0 +1,56 @@
+import pytest
+
+from alt_eval.tokenizer import LyricsTokenizer
+
+
+# fmt: off
+@pytest.mark.parametrize(
+    "language, text, expected_tokens",
+    [
+        (
+            "en",
+            "I ain't got nothin' but the blues",
+            ["I", "ain", "'t", "got", "nothin'", "but", "the", "blues"],
+        ),
+        (
+            "en",
+            "It'll be fun (ha!)",
+            ["It", "'ll", "be", "fun", "(", "ha", "!", ")"],
+        ),
+        (
+            "en",
+            "Just like 2Pac",
+            ["Just", "like", "2Pac"],
+        ),
+        (
+            "de",
+            "Sei's Melancholie",
+            ["Sei", "'s", "Melancholie"],
+        ),
+        (
+            "de",
+            "Könnt' ich dir Schmerz erspar'n",
+            ["Könnt'", "ich", "dir", "Schmerz", "erspar'n"],
+        ),
+        (
+            "fr",
+            "T'avais fait l'amour deux fois sans penser qu'avec cette fille-là",
+            ["T'", "avais", "fait", "l'", "amour", "deux", "fois", "sans", "penser", "qu'", "avec", "cette", "fille", "-", "là"],
+        ),
+        (
+            "ja",
+            "私は日本語を話せません(ラララ)",
+            ["私", "は", "日", "本", "語", "を", "話", "せ", "ま", "せ", "ん", "(", "ラ", "ラ", "ラ", ")"],
+        ),
+        (
+            "zh",
+            "我不会说中文。(哈哈)",
+            ["我", "不", "会", "说", "中", "文", "。", "(", "哈", "哈", ")"],
+        ),
+    ],
+)
+# fmt: on
+def test_lyrics_tokenizer(language, text, expected_tokens):
+    tokenizer = LyricsTokenizer()
+    tokens = [t.text for t in tokenizer(text, language=language)]
+    assert tokens == expected_tokens
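Note (not part of the patch): with the new `test` extra, the suite can be run locally the same way the workflow does: `pip install '.[test]'` followed by `pytest tests` from the repository root.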