Skip to content

Commit

Permalink
Fix tones list from g2p function being initliazed to zeroes and adjus…
Browse files Browse the repository at this point in the history
…t test case
  • Loading branch information
Jim Eric Skogman authored and Jim Eric Skogman committed May 16, 2024
1 parent 11c55ef commit 3004182
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 35 deletions.
6 changes: 3 additions & 3 deletions melo/text/test_thai.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def test_g2p():
text = "ฉันรักเมืองไทย"
normalized_text = text_normalize(text)
phones, tones, word2ph = g2p(normalized_text)
assert phones == ['_', 't͡ɕʰ', 'a', 'n', '˩˩˦', 'r', 'a', 'k̚', '˦˥', 'm', 'ɯa̯', 'ŋ', '˧', 'tʰ', 'aj', '˧', '.', 'j', 'a', '˦˥', '.', '_']
assert tones == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
assert word2ph == [1, 7, 7, 6, 1]
assert phones == ['_', 't͡ɕʰ', 'a', 'n', '', 'r', 'a', 'k̚', '', 'm', 'ɯa̯', 'ŋ', '', 'tʰ', 'aj', '', '.', 'j', 'a', '', '.', '_']
assert tones == [0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 3, 0, 0, 3, 0, 0, 0, 5, 0]
assert word2ph == [1, 0, 8, 12, 1]

def test_get_bert_feature():
text = "ฉันเข้าใจคุณค่าของงานของฉันและความหมายของสิ่งที่ฟอนเทนทำเพื่อคนทั่วไปเป็นอย่างดี"
Expand Down
90 changes: 59 additions & 31 deletions melo/text/thai.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,47 +72,75 @@ def distribute_phone(n_phone, n_word):
def g2p(norm_text):
tokenized = tokenizer.tokenize(norm_text)
phs = []
ph_groups = []
current_group = [] # Track the current group of tokens
word2ph = []
current_word = []
current_phonemes = []

for token in tokenized:
if token.startswith("▁"): # Start of a new word
if current_word:
word_phonemes = " ".join(current_phonemes)
phs.extend(word_phonemes.split())
word2ph.append(len(current_phonemes))
current_word = []
current_phonemes = []
current_word.append(token.replace("▁", ""))
else:
current_word.append(token)

for t in tokenized:
if t in punctuation or t in pu_symbols: # Check if the token is a special character
phs.append(t)
if token in punctuation or token in pu_symbols:
phs.append(token)
word2ph.append(1)
else:
if t.startswith("▁"): # Start of a new word or phrase
if current_group: # Append current group to ph_groups if not empty
ph_groups.append(current_group)
current_group = [] # Reset current_group for the new word or phrase
current_group.append(t.replace("▁", "")) # Add token to current_group

if current_group: # Append the last group if not empty
ph_groups.append(current_group)

for group in ph_groups:
text = "".join(group) # Concatenate tokens in the group to form the word or phrase
if text == '[UNK]': # handle special cases like unknown tokens ("[UNK]")
phs.append('_')
word2ph.append(1)
continue
phonemes = thai_text_to_phonemes(text)
phone_len = len(phonemes.split())
word_len = len(group)
aaa = distribute_phone(phone_len, word_len)
assert len(aaa) == word_len
word2ph.extend(aaa)
phs.extend(phonemes.split())

phones = ["_"] + phs + ["_"]
tones = [0 for _ in phones]
word2ph = [1] + word2ph + [1]
phonemes = thai_text_to_phonemes(token.replace("▁", ""))
current_phonemes.extend(phonemes.split())

if current_word:
word_phonemes = " ".join(current_phonemes)
phs.extend(word_phonemes.split())
word2ph.append(len(current_phonemes))

# Distribute phonemes to match the number of tokens
distributed_word2ph = []
for i, group in enumerate(tokenized):
if group.startswith("▁"):
group = group.replace("▁", "")
if group in punctuation or group in pu_symbols:
distributed_word2ph.append(1)
else:
phonemes = thai_text_to_phonemes(group)
distributed_word2ph.append(len(phonemes.split()))

tone_markers = ['˥', '˦', '˧', '˨', '˩']
phones = ["_"] + [re.sub(f'[{"".join(tone_markers)}]', '', p) for p in phs] + ["_"] # Remove tone markers from phones
tones = extract_tones(phs) # Extract tones from the original phs list
word2ph = [1] + distributed_word2ph + [1]

assert len(word2ph) == len(tokenized) + 2

return phones, tones, word2ph


def extract_tones(phones):
tones = []
tone_map = {
"˥": 5, # High tone
"˦": 4, # Rising tone
"˧": 3, # Mid tone
"˨": 2, # Falling tone
"˩": 1, # Low tone
}

for phone in phones:
tone = 0
for marker, value in tone_map.items():
if marker in phone:
tone = value
break
tones.append(tone)

return tones

def get_bert_feature(text, word2ph, device='cuda', model_id='airesearch/wangchanberta-base-att-spm-uncased'):
from . import thai_bert
return thai_bert.get_bert_feature(text, word2ph, device=device, model_id=model_id)
Expand Down
2 changes: 1 addition & 1 deletion melo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
if language_str == "ZH":
bert = bert
ja_bert = torch.zeros(768, len(phone))
elif language_str in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU']:
elif language_str in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU', 'TH']:
ja_bert = bert
bert = torch.zeros(1024, len(phone))
else:
Expand Down

0 comments on commit 3004182

Please sign in to comment.