Skip to content

Commit

Permalink
Assign multiple phones based on the number of phones
Browse files Browse the repository at this point in the history
  • Loading branch information
Jim Eric Skogman authored and Jim Eric Skogman committed May 16, 2024
1 parent 3004182 commit 5b1faee
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 30 deletions.
32 changes: 29 additions & 3 deletions melo/text/test_thai.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,35 @@ def test_g2p():
text = "ฉันรักเมืองไทย"
normalized_text = text_normalize(text)
phones, tones, word2ph = g2p(normalized_text)
assert phones == ['_', 't͡ɕʰ', 'a', 'n', '', 'r', 'a', 'k̚', '', 'm', 'ɯa̯', 'ŋ', '', 'tʰ', 'aj', '', '.', 'j', 'a', '', '.', '_']
assert tones == [0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 3, 0, 0, 3, 0, 0, 0, 5, 0]
assert word2ph == [1, 0, 8, 12, 1]

print(f"Phones: {phones}")
print(f"Tones: {tones}")
print(f"Word2ph: {word2ph}")

expected_phones = ['_', 't͡ɕʰ', 'a', 'n', '', 'r', 'a', 'k̚', '', 'm', 'ɯa̯', 'ŋ', '', 'tʰ', 'aj', '', '.', 'j', 'a', '', '.', '_']
expected_tones = [0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 3, 0, 0, 3, 0, 0, 0, 5, 0]
expected_word2ph = [1, 0, 1, 1, 1]

assert phones == expected_phones
assert tones == expected_tones
assert word2ph == expected_word2ph

# Additional test case
text = "สวัสดี ประเทศไทย"
normalized_text = text_normalize(text)
phones, tones, word2ph = g2p(normalized_text)

print(f"Phones: {phones}")
print(f"Tones: {tones}")
print(f"Word2ph: {word2ph}")

expected_phones = ['_', 's', 'a', 'w', 'a', 't̚', '', 'd', 'iː', '', 'p', 'r', 'a', '', 'tʰ', 'eː', 't̚', '', 'tʰ', 'aj', '', '.', 'j', 'a', '', '.', '_']
expected_tones = [0, 0, 0, 0, 4, 2, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 5, 0]
expected_word2ph = [1, 6, 10, 14, 18, 22, 26, 1]

assert phones == expected_phones
assert tones == expected_tones
assert word2ph == expected_word2ph

def test_get_bert_feature():
text = "ฉันเข้าใจคุณค่าของงานของฉันและความหมายของสิ่งที่ฟอนเทนทำเพื่อคนทั่วไปเป็นอย่างดี"
Expand Down
78 changes: 51 additions & 27 deletions melo/text/thai.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,24 @@ def fn(m):
thai_g2p_dict[word] = phonemes.split()

def map_word_to_phonemes(word):
    """Look up the phoneme list for *word* in ``thai_g2p_dict``.

    Args:
        word: A single token to look up.

    Returns:
        The entry stored in ``thai_g2p_dict`` for *word*. When the word has
        no entry, the fallback is the word split into its individual
        characters (``list(word)``).
    """
    import logging

    phonemes = thai_g2p_dict.get(word, list(word))
    # Diagnostics are emitted at DEBUG level instead of via print() so that
    # per-word lookups do not write to stdout unless logging is enabled.
    logging.getLogger(__name__).debug(
        "map_word_to_phonemes: %r -> %r", word, phonemes
    )
    return phonemes

def thai_text_to_phonemes(text):
    """Convert text into a space-separated string of phoneme symbols.

    The text is passed through ``normalize``, split into words with
    ``word_tokenize`` using the ``newmm`` engine, and each word is mapped
    through ``map_word_to_phonemes``.

    Args:
        text: Input text to convert.

    Returns:
        The phoneme symbols of all words, in word order, joined by single
        spaces. An input that tokenizes to no words yields ``""``.
    """
    import logging

    log = logging.getLogger(__name__)

    normalized = normalize(text)
    words = word_tokenize(normalized, engine="newmm")
    # g2p calls this function once per token, so diagnostics use DEBUG
    # logging rather than print() to keep stdout quiet by default.
    log.debug(
        "thai_text_to_phonemes: text=%r normalized=%r words=%r",
        text, normalized, words,
    )

    phonemes = []
    for word in words:
        phonemes.extend(map_word_to_phonemes(word))

    log.debug("thai_text_to_phonemes: phonemes=%r", phonemes)
    return " ".join(phonemes)

def text_normalize(text):
Expand All @@ -70,58 +79,71 @@ def distribute_phone(n_phone, n_word):
tokenizer = AutoTokenizer.from_pretrained(model_id)

# g2p: turn normalized text into the triple (phones, tones, word2ph).
#
# Steps visible below: tokenize `norm_text` with the module-level `tokenizer`,
# collect phonemes for each token through `thai_text_to_phonemes`, strip the
# tone-marker characters to build `phones` (wrapped in "_" on both sides),
# derive `tones` with `extract_tones(phs)`, and build `word2ph` from per-token
# counts wrapped in a leading and trailing 1.
#
# NOTE(review): several statements in this span occur twice (the
# `startswith("▁")` check, the punctuation branch, the `phones` / `tones`
# assignments) and no indentation is present. This looks like a rendered diff
# with removed and added lines interleaved, so the real nesting cannot be
# determined from this view; confirm against the actual thai.py.
def g2p(norm_text):
print(f"Normalized text: {norm_text}")
tokenized = tokenizer.tokenize(norm_text)
print(f"Tokenized text: {tokenized}")
# phs accumulates phoneme symbols (tone markers still attached).
# The word2ph list started here is reassigned near the end of the function.
phs = []
word2ph = []
current_word = []
current_phonemes = []

for token in tokenized:
if token.startswith("▁"): # Start of a new word
print(f"Processing token: {token}")
if token.startswith("▁"):
print("Start of a new word")
# Flush the phonemes gathered for the previous word before starting a new one.
if current_word:
word_phonemes = " ".join(current_phonemes)
print(f"Word phonemes: {word_phonemes}")
phs.extend(word_phonemes.split())
word2ph.append(len(current_phonemes))
current_word = []
current_phonemes = []
current_word.append(token.replace("▁", ""))
else:
current_word.append(token)

# NOTE(review): the next two if/else groups are the same branch twice; the
# second adds print() calls. Presumably old and new versions of one block.
if token in punctuation or token in pu_symbols:
phs.append(token)
word2ph.append(1)
else:
phonemes = thai_text_to_phonemes(token.replace("▁", ""))
current_phonemes.extend(phonemes.split())
if token in punctuation or token in pu_symbols:
print(f"Punctuation or symbol: {token}")
phs.append(token)
word2ph.append(1)
else:
phonemes = thai_text_to_phonemes(token.replace("▁", ""))
print(f"Phonemes: {phonemes}")
current_phonemes.extend(phonemes.split())

# Flush whatever is still pending once the tokens are exhausted.
if current_word:
word_phonemes = " ".join(current_phonemes)
print(f"Word phonemes: {word_phonemes}")
phs.extend(word_phonemes.split())
word2ph.append(len(current_phonemes))

# Distribute phonemes to match the number of tokens
print(f"Final phs: {phs}")
print(f"Final word2ph: {word2ph}")

# Second pass over the tokens to build the per-token counts that are returned.
# `i` from enumerate() is not used in the visible loop body.
distributed_word2ph = []
for i, group in enumerate(tokenized):
if group.startswith("▁"):
group = group.replace("▁", "")
if group in punctuation or group in pu_symbols:
distributed_word2ph.append(1)
if group in punctuation or group in pu_symbols:
distributed_word2ph.append(1)
else:
phonemes = thai_text_to_phonemes(group)
distributed_word2ph.append(len(phonemes.split()))
else:
phonemes = thai_text_to_phonemes(group)
distributed_word2ph.append(len(phonemes.split()))
distributed_word2ph.append(1) # Add 1 for spaces between words

tone_markers = ['˥', '˦', '˧', '˨', '˩']
phones = ["_"] + [re.sub(f'[{"".join(tone_markers)}]', '', p) for p in phs] + ["_"] # Remove tone markers from phones
tones = extract_tones(phs) # Extract tones from the original phs list
# NOTE(review): phones gains two "_" entries while tones is computed from phs
# alone; confirm whether the two lists are meant to differ in length.
phones = ["_"] + [re.sub(f'[{"".join(tone_markers)}]', '', p) for p in phs] + ["_"]
print(f"Phones: {phones}")

tones = extract_tones(phs)
print(f"Tones: {tones}")

# Replaces the word2ph list built in the first loop; only
# distributed_word2ph feeds the returned value.
word2ph = [1] + distributed_word2ph + [1]
print(f"Final word2ph: {word2ph}")

# One entry per token plus the two padding entries. This check is an assert,
# so it is skipped when Python runs with -O.
assert len(word2ph) == len(tokenized) + 2

return phones, tones, word2ph


def extract_tones(phones):
def extract_tones(phs):
tones = []
tone_map = {
"˥": 5, # High tone
Expand All @@ -130,17 +152,19 @@ def extract_tones(phones):
"˨": 2, # Falling tone
"˩": 1, # Low tone
}

for phone in phones:
tone = 0
for ph in phs:
tone_found = False
for marker, value in tone_map.items():
if marker in phone:
tone = value
if marker in ph:
tones.append(value)
tone_found = True
break
tones.append(tone)

if not tone_found:
tones.append(0)
return tones



def get_bert_feature(text, word2ph, device='cuda', model_id='airesearch/wangchanberta-base-att-spm-uncased'):
    """Delegate to ``get_bert_feature`` in the sibling ``thai_bert`` module.

    All four arguments are forwarded unchanged (``device`` and ``model_id``
    as keyword arguments) and the delegate's result is returned as-is.

    The import happens inside the function body, so ``thai_bert`` is only
    loaded when this function is called (presumably to defer a heavy model
    import; confirm against ``thai_bert``).
    """
    from . import thai_bert as _bert_backend

    extract = _bert_backend.get_bert_feature
    return extract(text, word2ph, device=device, model_id=model_id)
Expand Down

0 comments on commit 5b1faee

Please sign in to comment.