🇹🇭 Add Thai language support #15

Merged · 6 commits · Jan 22, 2025 · Changes from all commits
2 changes: 1 addition & 1 deletion meloplus/data_utils.py
@@ -157,7 +157,7 @@ def get_text(self, text, word2ph, phone, tone, language_str, wav_path):
         if language_str in ["ZH"]:
             bert = bert
             ja_bert = torch.zeros(768, len(phone))
-        elif language_str in ["JP", "EN", "ZH_MIX_EN", "KR", 'SP', 'ES', 'FR', 'DE', 'RU']:
+        elif language_str in ["JP", "EN", "ZH_MIX_EN", "KR", 'SP', 'ES', 'FR', 'DE', 'RU', 'TH']:
             ja_bert = bert
             bert = torch.zeros(1024, len(phone))
         else:
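The routing rule this hunk extends: ZH keeps its 1024-dim BERT features in bert, while every other language, now including "TH", carries its 768-dim features in ja_bert and zero-fills the 1024-dim slot. A minimal self-contained sketch of the same branch (the helper name route_bert_features is hypothetical, not part of the repo):

    import torch

    def route_bert_features(bert, phone, language_str):
        # ZH: keep the 1024-dim features in `bert`, zero-fill the 768-dim slot.
        # All other supported languages (now including "TH"): move the 768-dim
        # features into `ja_bert`, zero-fill the 1024-dim slot.
        if language_str == "ZH":
            ja_bert = torch.zeros(768, len(phone))
        elif language_str in ["JP", "EN", "ZH_MIX_EN", "KR", "SP", "ES", "FR", "DE", "RU", "TH"]:
            ja_bert = bert
            bert = torch.zeros(1024, len(phone))
        else:
            raise ValueError(f"unsupported language: {language_str}")
        return bert, ja_bert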
35 changes: 35 additions & 0 deletions meloplus/models.py
@@ -762,6 +762,24 @@ def __init__(
             num_languages=num_languages,
             num_tones=num_tones,
         )
+        self.enc_p = TextEncoder(
+            219,  # Initialize with the original symbol size
+            inter_channels,
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+            gin_channels=self.enc_gin_channels,
+            num_languages=num_languages,
+            num_tones=num_tones,
+        )
+        if n_vocab != 219:
+            old_embeddings = self.enc_p.emb
+            new_num_tokens = n_vocab
+            self.enc_p.emb = self.get_resized_embeddings(old_embeddings, new_num_tokens)
+
         self.dec = Generator(
             inter_channels,
             resblock,
@@ -812,6 +830,23 @@ def __init__(
         self.ref_enc = ReferenceEncoder(spec_channels, gin_channels, layernorm=norm_refenc)
         self.use_vc = use_vc

+    def get_resized_embeddings(self, old_embeddings, new_num_tokens):
+        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+        if old_num_tokens == new_num_tokens:
+            return old_embeddings
+
+        if not isinstance(old_embeddings, nn.Embedding):
+            raise TypeError(
+                f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. "
+                f"You should either use a different resize function or make sure that `old_embeddings` are an instance of {nn.Embedding}."
+            )
+
+        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(
+            device=old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)
+        new_embeddings.weight.data[:old_num_tokens, :] = old_embeddings.weight.data[:old_num_tokens, :]
+
+        return new_embeddings
+
     def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert, ja_bert):
         if self.n_speakers > 0:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
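The net effect of these two hunks: enc_p is always built against the 219-symbol table that existing checkpoints expect, then its embedding matrix is grown to n_vocab, so pretrained rows are kept and only the new Thai symbols start from random initialization. A standalone sketch of the resize-and-copy pattern (dimensions here are illustrative, not the model's actual hidden size; note the diff copies rows with [:old_num_tokens], which implicitly assumes the vocabulary only grows, whereas this sketch copies min(old, new) rows so shrinking works too):

    import torch.nn as nn

    def resize_embeddings(old: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
        old_num_tokens, dim = old.weight.size()
        if old_num_tokens == new_num_tokens:
            return old
        new = nn.Embedding(new_num_tokens, dim).to(device=old.weight.device, dtype=old.weight.dtype)
        n = min(old_num_tokens, new_num_tokens)
        new.weight.data[:n, :] = old.weight.data[:n, :]  # pretrained rows survive
        return new

    emb = resize_embeddings(nn.Embedding(219, 192), 300)  # 219 = original symbol count; 300 and 192 are illustrative
    assert emb.weight.shape == (300, 192)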
4 changes: 3 additions & 1 deletion meloplus/text/__init__.py
@@ -27,6 +27,7 @@ def get_bert(norm_text, word2ph, language, device):
     from .japanese_bert import get_bert_feature as jp_bert
     from .korean import get_bert_feature as kr_bert
     from .spanish_bert import get_bert_feature as sp_bert
+    from .thai import get_bert_feature as th_bert

     lang_bert_func_map = {
         "ZH": zh_bert,
@@ -36,7 +37,8 @@ def get_bert(norm_text, word2ph, language, device):
         'FR': fr_bert,
         'SP': sp_bert,
         'ES': sp_bert,
-        "KR": kr_bert
+        "KR": kr_bert,
+        "TH": th_bert
     }
     bert = lang_bert_func_map[language](norm_text, word2ph, device)
     return bert
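With "TH" registered, Thai text goes through the same get_bert entry point as every other language. A hedged usage sketch (the word2ph alignment is borrowed from the test_g2p case in test_thai.py further down):

    from meloplus.text import get_bert

    norm_text = "กงล้อ"
    word2ph = [1, 3, 2, 1]  # illustrative alignment, taken from test_thai.py below
    bert = get_bert(norm_text, word2ph, "TH", "cpu")
    # th_bert is expected to return a (768, sum(word2ph)) feature matrix,
    # which data_utils.py then routes into the ja_bert slot.
    print(bert.shape)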
5 changes: 3 additions & 2 deletions meloplus/text/cleaner.py
@@ -1,6 +1,6 @@
 import copy

-from . import chinese, chinese_mix, cleaned_text_to_sequence, english, french, japanese, korean, spanish
+from . import chinese, chinese_mix, cleaned_text_to_sequence, english, french, japanese, korean, spanish, thai

 language_module_map = {
     "ZH": chinese,
@@ -10,7 +10,8 @@
     'KR': korean,
     'FR': french,
     'SP': spanish,
-    'ES': spanish
+    'ES': spanish,
+    'TH': thai,
 }


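Registering the thai module in language_module_map is what lets the cleaner dispatch on "TH". Assuming meloplus keeps the MeloTTS-style clean_text(text, language) entry point in this file (an assumption, since the function body is not shown in this diff), usage would look like:

    from meloplus.text.cleaner import clean_text  # entry point assumed from upstream MeloTTS

    # language_module_map routes "TH" to the thai module, whose text_normalize
    # and g2p produce the values consumed downstream.
    norm_text, phones, tones, word2ph = clean_text("สวัสดีครับ", "TH")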
23 changes: 20 additions & 3 deletions meloplus/text/symbols.py
@@ -192,19 +192,35 @@
 ru_symbols = ["ɭ", "ʲ", "ɕ", "\"", "ɵ", "^", "ɬ"]
 num_ru_tones = 1

+# Thai
+th_symbols = [
+    "ก", "ข", "ฃ", "ค", "ฅ", "ฆ", "ง", "จ", "ฉ", "ช", "ซ", "ฌ", "ญ", "ฎ", "ฏ", "ฐ", "ฑ", "ฒ", "ณ", "ด", "ต",
+    "ถ", "ท", "ธ", "น", "บ", "ป", "ผ", "ฝ", "พ", "ฟ", "ภ", "ม", "ย", "ร", "ล", "ว", "ศ", "ษ", "ส", "ห", "ฬ",
+    "อ", "ฮ", "ะ", "ั", "า", "ำ", "ิ", "ี", "ึ", "ื", "ุ", "ู", "เ", "แ", "โ", "ใ", "ไ", "ๅ", "็", "่", "้",
+    "์", "๑", "๒", "๓", "๔", "๕", "๖", "๗", "๘", "๙", "๐", 'kʰ', 'aj', '˧', 'pʰ', 'uː', '˥˩', 'p̚', '˦˥',
+    't͡ɕʰ', 'ʔ', 'aː', '˩˩˦', 'tʰ', 'eː', 'k̚', '˨˩', 't͡ɕ', 'aʔ', 'iː', 'ɤː', 't̚', 'ɛːw', 'ɯː', 'ia̯',
+    'ua̯', 'ɛː', 'aːj', 'ua̯j', 'ɤːj', 'ɔː', 'rɯ', 'a̯', 'ɤ', 'oː', 'aːw', 'ɔːj', 'oʔ', 'lɯ', 'ɯa̯', 'ɛʔ',
+    'ia̯w', 'lɯː', 'rɯː', 'oːj', 'ɔʔ', 'ๆ', 'ɔj', ';', 'ew', 'ɤʔ', 'iw', '๊', '”', 'eʔ', 'uj', '“', '๋', 'ฤ',
+    'ɨ', 'eːw', 'a̯j', 'ɛw', '‘', '’', '—', 'ia̯ʔ', 'ํ', 'p̚', 'ɨ', 'ŋ', '˥', 'd', 'ʰ', 'ɗ', 'kʰ', 'a', '.',
+    'ʔ', 'j', 'b', 'ɛ', 'ǐ', 'i', 'ᵊ', 'f', 'h', 'rɯ', 'ì', '̯', '̚', '˦', 's', 'ɯ', 'k', 'u', '˩', 'ɕ', 'e',
+    '˨', 'r', 'ɓ', 'ɤ', 'cʰ', 'æ', 'p', 'm', 'ɔ', 'o', 'w', 't', 'c', '̌', 'à', 'ː', ' ', 'n', 'ia̯', '˧',
+    'l', 'ə', 'æː', 'i̯', '▁', 'am'
+]
+num_th_tones = 5
+
 # combine all symbols
 normal_symbols = sorted(
     set(
         zh_symbols + ja_symbols + en_symbols + kr_symbols + es_symbols + fr_symbols + de_symbols +
         ru_symbols))
-symbols = [pad] + normal_symbols + pu_symbols
+symbols = [pad] + normal_symbols + pu_symbols + th_symbols
 sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]

 # combine all tones
-num_tones = num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones + num_es_tones + num_fr_tones + num_de_tones + num_ru_tones
+num_tones = num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones + num_es_tones + num_fr_tones + num_de_tones + num_ru_tones + num_th_tones

 # language maps
-language_id_map = {"ZH": 0, "JP": 1, "EN": 2, "ZH_MIX_EN": 3, 'KR': 4, 'ES': 5, 'SP': 5, 'FR': 6}
+language_id_map = {"ZH": 0, "JP": 1, "EN": 2, "ZH_MIX_EN": 3, 'KR': 4, 'ES': 5, 'SP': 5, 'FR': 6, 'TH': 7}
 num_languages = len(language_id_map.keys())

 language_tone_start_map = {
@@ -216,6 +232,7 @@
     "ES": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones,
     "SP": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones,
     "FR": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones + num_es_tones,
+    "TH": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones + num_es_tones + num_fr_tones,
 }

 if __name__ == "__main__":
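Tone IDs are global: following the convention used by cleaned_text_to_sequence (a language's local tone index is shifted by its entry in language_tone_start_map), Thai's five tones occupy the contiguous block right after French's. A small sketch of that arithmetic:

    from meloplus.text.symbols import language_tone_start_map, num_th_tones

    th_start = language_tone_start_map["TH"]
    thai_tone_ids = [th_start + t for t in range(num_th_tones)]  # 5 consecutive global tone IDs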
150 changes: 150 additions & 0 deletions meloplus/text/test_thai.py
@@ -0,0 +1,150 @@
import re
import pytest
import torch
from meloplus.text.thai import (
    normalize,
    word_tokenize,
    thai_text_to_phonemes,
    text_normalize,
    g2p,
    get_bert_feature,
)
from meloplus.text.korean import (
    text_normalize as k_text_normalize,
    get_bert_feature as k_get_bert_feature,
    g2p as k_g2p,
)


def test_normalize():
    text = " ข้อความ ภาษา ไทย 123 ABC "
    normalized_text = normalize(text)
    assert normalized_text == "ข้อความ ภาษา ไทย หนึ่งร้อยยี่สิบสาม เอบีซี"


# def test_word_tokenize():
#     text = "ฉันเข้าใจคุณค่าของงานของฉันและความหมายของสิ่งที่ฟอนเทนทำเพื่อคนทั่วไปเป็นอย่างดี"
#     tokenized_text = word_tokenize(text, engine="newmm")
#     assert tokenized_text == ['ฉัน', 'เข้าใจ', 'คุณค่า', 'ของ', 'งาน', 'ของ', 'ฉัน', 'และ', 'ความหมาย', 'ของ', 'สิ่ง', 'ที่', 'ฟอน', 'เท', 'น', 'ทำ', 'เพื่อ', 'คน', 'ทั่วไป', 'เป็น', 'อย่าง', 'ดี']

# def test_thai_text_to_phonemes():
#     text = "สวัสดีครับ"
#     phonemes = thai_text_to_phonemes(text)
#     assert phonemes == "s a ˨˩ . w a t̚ ˨˩ . d iː ˧ kʰ r a p̚ ˦˥"

# def test_g2p():
#     text = "กงล้อ"
#     normalized_text = text_normalize(text)
#     phones, tones, word2ph = g2p(normalized_text)

#     print(f"Phones: {phones}")
#     print(f"Tones: {tones}")
#     print(f"Word2ph: {word2ph}")

#     expected_phones = ['_', 't͡ɕʰ', 'a', 'n', '', 'r', 'a', 'k̚', '', 'm', 'ɯa̯', 'ŋ', '', 'tʰ', 'aj', '', '.', 'j', 'a', '', '.', '_']
#     expected_tones = [0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 3, 0, 0, 3, 0, 0, 0, 5, 0]
#     expected_word2ph = [1, 0, 1, 1, 1]

#     assert phones == expected_phones
#     assert tones == expected_tones
#     assert word2ph == expected_word2ph

#     # Additional test case
#     text = "สวัสดี ประเทศไทย"
#     normalized_text = text_normalize(text)
#     phones, tones, word2ph = g2p(normalized_text)

#     print(f"Phones: {phones}")
#     print(f"Tones: {tones}")
#     print(f"Word2ph: {word2ph}")

#     expected_phones = ['_', 's', 'a', 'w', 'a', 't̚', '', 'd', 'iː', '', 'p', 'r', 'a', '', 'tʰ', 'eː', 't̚', '', 'tʰ', 'aj', '', '.', 'j', 'a', '', '.', '_']
#     expected_tones = [0, 0, 0, 0, 4, 2, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 5, 0]
#     expected_word2ph = [1, 6, 10, 14, 18, 22, 26, 1]

#     assert phones == expected_phones
#     assert tones == expected_tones
#     assert word2ph == expected_word2ph

# def test_get_bert_feature():
#     text = "ฉันเข้าใจคุณค่าของงานของฉันและความหมายของสิ่งที่ฟอนเทนทำเพื่อคนทั่วไปเป็นอย่างดี"
#     normalized_text = text_normalize(text)
#     phones, tones, word2ph = g2p(normalized_text)

#     bert_features = get_bert_feature(normalized_text, word2ph, device='cpu')

#     assert isinstance(bert_features, torch.Tensor), "bert_features should be a torch.Tensor"
#     assert bert_features.shape[0] == 768, f"Expected bert_features.shape[0] to be 768, but got {bert_features.shape[0]}"

#     # Modify the assertion to check the number of phones instead of the length of word2ph
#     num_phones = sum(word2ph)
#     assert bert_features.shape[1] == num_phones, f"Expected bert_features.shape[1] to be {num_phones}, but got {bert_features.shape[1]}"

#     # Additional assertions to check the values of bert_features
#     assert not torch.isnan(bert_features).any(), "bert_features should not contain any NaN values"
#     assert not torch.isinf(bert_features).any(), "bert_features should not contain any infinity values"


def extract_word_and_phonemes(line):
    parts = line.strip().split("\t")
    if len(parts) == 2:
        word, phonemes = parts
        return word, phonemes.split()
    return None


def test_g2p():
    # Test case for the word "กงล้อ"
    text = "กงล้อ"
    normalized_text = text_normalize(text)
    phones, tones, word2ph = g2p(normalized_text)

    # Expected output based on the Wiktionary entry
    expected_phones = ['_', 'k', 'o', 'ŋ', 'l', 'ɔː', '_']
    expected_tones = [1, 2, 2, 2, 3, 3, 1]
    expected_word2ph = [1, 3, 2, 1]

    # Compare the actual output with the expected output
    assert phones == expected_phones
    assert tones == expected_tones
    assert word2ph == expected_word2ph


def test_get_bert_feature_thai():
    text = "กงล้อ"
    normalized_text = text_normalize(text)
    phones, tones, word2ph = g2p(normalized_text)
    bert_features = get_bert_feature(normalized_text, word2ph, device='cpu')
    assert isinstance(bert_features, torch.Tensor), "bert_features should be a torch.Tensor"
    assert bert_features.shape[0] == 768, f"Expected bert_features.shape[0] to be 768, but got {bert_features.shape[0]}"

    # Check against the total number of phones (sum of word2ph), not len(word2ph)
    num_phones = sum(word2ph)
    assert bert_features.shape[1] == num_phones, f"Expected bert_features.shape[1] to be {num_phones}, but got {bert_features.shape[1]}"

    assert not torch.isnan(bert_features).any(), "bert_features should not contain any NaN values"
    assert not torch.isinf(bert_features).any(), "bert_features should not contain any infinity values"


# Compare with Korean
def test_get_bert_feature_korean():
    text = "저는 제 일의 가치와 의미를 잘 알고 있습니다. 앞으로도 저는 제 일에 자부심을 갖고 살아갈 것입니다."
    normalized_text = k_text_normalize(text)
    phones, tones, word2ph = k_g2p(normalized_text)

    bert_features = k_get_bert_feature(normalized_text, word2ph, device='cpu')

    assert isinstance(bert_features, torch.Tensor), "bert_features should be a torch.Tensor"
    assert bert_features.shape[0] == 768, f"Expected bert_features.shape[0] to be 768, but got {bert_features.shape[0]}"

    # Check against the total number of phones (sum of word2ph), not len(word2ph)
    num_phones = sum(word2ph)
    assert bert_features.shape[1] == num_phones, f"Expected bert_features.shape[1] to be {num_phones}, but got {bert_features.shape[1]}"

    # Additional assertions on the values of bert_features
    assert not torch.isnan(bert_features).any(), "bert_features should not contain any NaN values"
    assert not torch.isinf(bert_features).any(), "bert_features should not contain any infinity values"
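These tests encode one invariant worth calling out: word2ph must sum to the total phone count, including the '_' padding at both ends, which is exactly what the BERT-shape assertions verify. A quick arithmetic check using the กงล้อ case, runnable without the model:

    expected_phones = ['_', 'k', 'o', 'ŋ', 'l', 'ɔː', '_']
    expected_word2ph = [1, 3, 2, 1]
    assert sum(expected_word2ph) == len(expected_phones) == 7

The suite runs under a standard pytest invocation, e.g. pytest meloplus/text/test_thai.py.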