From 5b1faee01dc034f7b08aa53aa992b0e935e403fa Mon Sep 17 00:00:00 2001 From: Jim Eric Skogman Date: Thu, 16 May 2024 18:29:21 +0700 Subject: [PATCH] Assign multiple phones based on the number of phones --- melo/text/test_thai.py | 32 +++++++++++++++-- melo/text/thai.py | 78 +++++++++++++++++++++++++++--------------- 2 files changed, 80 insertions(+), 30 deletions(-) diff --git a/melo/text/test_thai.py b/melo/text/test_thai.py index 1d56e867..b27daca0 100644 --- a/melo/text/test_thai.py +++ b/melo/text/test_thai.py @@ -33,9 +33,35 @@ def test_g2p(): text = "ฉันรักเมืองไทย" normalized_text = text_normalize(text) phones, tones, word2ph = g2p(normalized_text) - assert phones == ['_', 't͡ɕʰ', 'a', 'n', '', 'r', 'a', 'k̚', '', 'm', 'ɯa̯', 'ŋ', '', 'tʰ', 'aj', '', '.', 'j', 'a', '', '.', '_'] - assert tones == [0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 3, 0, 0, 3, 0, 0, 0, 5, 0] - assert word2ph == [1, 0, 8, 12, 1] + + print(f"Phones: {phones}") + print(f"Tones: {tones}") + print(f"Word2ph: {word2ph}") + + expected_phones = ['_', 't͡ɕʰ', 'a', 'n', '', 'r', 'a', 'k̚', '', 'm', 'ɯa̯', 'ŋ', '', 'tʰ', 'aj', '', '.', 'j', 'a', '', '.', '_'] + expected_tones = [0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 3, 0, 0, 3, 0, 0, 0, 5, 0] + expected_word2ph = [1, 0, 1, 1, 1] + + assert phones == expected_phones + assert tones == expected_tones + assert word2ph == expected_word2ph + + # Additional test case + text = "สวัสดี ประเทศไทย" + normalized_text = text_normalize(text) + phones, tones, word2ph = g2p(normalized_text) + + print(f"Phones: {phones}") + print(f"Tones: {tones}") + print(f"Word2ph: {word2ph}") + + expected_phones = ['_', 's', 'a', 'w', 'a', 't̚', '', 'd', 'iː', '', 'p', 'r', 'a', '', 'tʰ', 'eː', 't̚', '', 'tʰ', 'aj', '', '.', 'j', 'a', '', '.', '_'] + expected_tones = [0, 0, 0, 0, 4, 2, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 5, 0] + expected_word2ph = [1, 6, 10, 14, 18, 22, 26, 1] + + assert phones == expected_phones + assert tones == expected_tones + assert word2ph == expected_word2ph def test_get_bert_feature(): text = "ฉันเข้าใจคุณค่าของงานของฉันและความหมายของสิ่งที่ฟอนเทนทำเพื่อคนทั่วไปเป็นอย่างดี" diff --git a/melo/text/thai.py b/melo/text/thai.py index 92ac6734..45760d9c 100644 --- a/melo/text/thai.py +++ b/melo/text/thai.py @@ -43,15 +43,24 @@ def fn(m): thai_g2p_dict[word] = phonemes.split() def map_word_to_phonemes(word): - return thai_g2p_dict.get(word, list(word)) + print(f"Mapping word to phonemes: {word}") + phonemes = thai_g2p_dict.get(word, list(word)) + print(f"Phonemes for the word: {phonemes}") + return phonemes def thai_text_to_phonemes(text): + print(f"Original text: {text}") text = normalize(text) + print(f"Normalized text: {text}") words = word_tokenize(text, engine="newmm") + print(f"Tokenized words: {words}") phonemes = [] for word in words: + print(f"Processing word: {word}") word_phonemes = map_word_to_phonemes(word) + print(f"Word phonemes: {word_phonemes}") phonemes.extend(word_phonemes) + print(f"Final phonemes: {phonemes}") return " ".join(phonemes) def text_normalize(text): @@ -70,16 +79,20 @@ def distribute_phone(n_phone, n_word): tokenizer = AutoTokenizer.from_pretrained(model_id) def g2p(norm_text): + print(f"Normalized text: {norm_text}") tokenized = tokenizer.tokenize(norm_text) + print(f"Tokenized text: {tokenized}") phs = [] word2ph = [] current_word = [] current_phonemes = [] - for token in tokenized: - if token.startswith("▁"): # Start of a new word + print(f"Processing token: {token}") + if token.startswith("▁"): + print("Start of a new word") if current_word: word_phonemes = " ".join(current_phonemes) + print(f"Word phonemes: {word_phonemes}") phs.extend(word_phonemes.split()) word2ph.append(len(current_phonemes)) current_word = [] @@ -87,41 +100,50 @@ def g2p(norm_text): current_word.append(token.replace("▁", "")) else: current_word.append(token) - - if token in punctuation or token in pu_symbols: - phs.append(token) - word2ph.append(1) - else: - phonemes = thai_text_to_phonemes(token.replace("▁", "")) - current_phonemes.extend(phonemes.split()) + if token in punctuation or token in pu_symbols: + print(f"Punctuation or symbol: {token}") + phs.append(token) + word2ph.append(1) + else: + phonemes = thai_text_to_phonemes(token.replace("▁", "")) + print(f"Phonemes: {phonemes}") + current_phonemes.extend(phonemes.split()) if current_word: word_phonemes = " ".join(current_phonemes) + print(f"Word phonemes: {word_phonemes}") phs.extend(word_phonemes.split()) word2ph.append(len(current_phonemes)) - # Distribute phonemes to match the number of tokens + print(f"Final phs: {phs}") + print(f"Final word2ph: {word2ph}") + distributed_word2ph = [] for i, group in enumerate(tokenized): if group.startswith("▁"): group = group.replace("▁", "") - if group in punctuation or group in pu_symbols: - distributed_word2ph.append(1) + if group in punctuation or group in pu_symbols: + distributed_word2ph.append(1) + else: + phonemes = thai_text_to_phonemes(group) + distributed_word2ph.append(len(phonemes.split())) else: - phonemes = thai_text_to_phonemes(group) - distributed_word2ph.append(len(phonemes.split())) + distributed_word2ph.append(1) # Add 1 for spaces between words tone_markers = ['˥', '˦', '˧', '˨', '˩'] - phones = ["_"] + [re.sub(f'[{"".join(tone_markers)}]', '', p) for p in phs] + ["_"] # Remove tone markers from phones - tones = extract_tones(phs) # Extract tones from the original phs list + phones = ["_"] + [re.sub(f'[{"".join(tone_markers)}]', '', p) for p in phs] + ["_"] + print(f"Phones: {phones}") + + tones = extract_tones(phs) + print(f"Tones: {tones}") + word2ph = [1] + distributed_word2ph + [1] + print(f"Final word2ph: {word2ph}") assert len(word2ph) == len(tokenized) + 2 - return phones, tones, word2ph - -def extract_tones(phones): +def extract_tones(phs): tones = [] tone_map = { "˥": 5, # High tone @@ -130,17 +152,19 @@ def extract_tones(phones): "˨": 2, # Falling tone "˩": 1, # Low tone } - - for phone in phones: - tone = 0 + for ph in phs: + tone_found = False for marker, value in tone_map.items(): - if marker in phone: - tone = value + if marker in ph: + tones.append(value) + tone_found = True break - tones.append(tone) - + if not tone_found: + tones.append(0) return tones + + def get_bert_feature(text, word2ph, device='cuda', model_id='airesearch/wangchanberta-base-att-spm-uncased'): from . import thai_bert return thai_bert.get_bert_feature(text, word2ph, device=device, model_id=model_id)