Add demo ipynb #306

Merged
merged 1 commit on Oct 30, 2024
1 change: 0 additions & 1 deletion .gitignore
@@ -56,7 +56,6 @@ processed_data
data
model_ckpt
logs
-*.ipynb
*.lst
source_audio
result
314 changes: 314 additions & 0 deletions models/tts/maskgct/maskgct_demo.ipynb
@@ -0,0 +1,314 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"import librosa\n",
"import safetensors\n",
"from utils.util import load_config\n",
"\n",
"from models.codec.kmeans.repcodec_model import RepCodec\n",
"from models.tts.maskgct.maskgct_s2a import MaskGCT_S2A\n",
"from models.tts.maskgct.maskgct_t2s import MaskGCT_T2S\n",
"from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder\n",
"from transformers import Wav2Vec2BertModel\n",
"\n",
"from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import SeamlessM4TFeatureExtractor\n",
"processor = SeamlessM4TFeatureExtractor.from_pretrained(\"facebook/w2v-bert-2.0\")"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"def g2p_(text, language):\n",
" if language in [\"zh\", \"en\"]:\n",
" return chn_eng_g2p(text)\n",
" else:\n",
" return g2p(text, sentence=None, language=language)\n",
"\n",
"def build_t2s_model(cfg, device):\n",
" t2s_model = MaskGCT_T2S(cfg=cfg)\n",
" t2s_model.eval()\n",
" t2s_model.to(device)\n",
" return t2s_model\n",
"\n",
"def build_s2a_model(cfg, device):\n",
" soundstorm_model = MaskGCT_S2A(cfg=cfg)\n",
" soundstorm_model.eval()\n",
" soundstorm_model.to(device)\n",
" return soundstorm_model\n",
"\n",
"def build_semantic_model(device):\n",
" semantic_model = Wav2Vec2BertModel.from_pretrained(\"facebook/w2v-bert-2.0\")\n",
" semantic_model.eval()\n",
" semantic_model.to(device)\n",
" stat_mean_var = torch.load(\"./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt\")\n",
" semantic_mean = stat_mean_var[\"mean\"]\n",
" semantic_std = torch.sqrt(stat_mean_var[\"var\"])\n",
" semantic_mean = semantic_mean.to(device)\n",
" semantic_std = semantic_std.to(device)\n",
" return semantic_model, semantic_mean, semantic_std\n",
"\n",
"def build_semantic_codec(cfg, device):\n",
" semantic_codec = RepCodec(cfg=cfg)\n",
" semantic_codec.eval()\n",
" semantic_codec.to(device)\n",
" return semantic_codec\n",
"\n",
"def build_acoustic_codec(cfg, device):\n",
" codec_encoder = CodecEncoder(cfg=cfg.encoder)\n",
" codec_decoder = CodecDecoder(cfg=cfg.decoder)\n",
" codec_encoder.eval()\n",
" codec_decoder.eval()\n",
" codec_encoder.to(device)\n",
" codec_decoder.to(device)\n",
" return codec_encoder, codec_decoder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def extract_features(speech, processor):\n",
" inputs = processor(speech, sampling_rate=16000, return_tensors=\"pt\")\n",
" input_features = inputs[\"input_features\"][0]\n",
" attention_mask = inputs[\"attention_mask\"][0]\n",
" return input_features, attention_mask\n",
"\n",
"@torch.no_grad()\n",
"def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask):\n",
" vq_emb = semantic_model(\n",
" input_features=input_features,\n",
" attention_mask=attention_mask,\n",
" output_hidden_states=True,\n",
" )\n",
" feat = vq_emb.hidden_states[17] # (B, T, C)\n",
" feat = (feat - semantic_mean.to(feat)) / semantic_std.to(feat)\n",
"\n",
" semantic_code, rec_feat = semantic_codec.quantize(feat) # (B, T)\n",
" return semantic_code, rec_feat\n",
"\n",
"@torch.no_grad()\n",
"def extract_acoustic_code(speech):\n",
" vq_emb = codec_encoder(speech.unsqueeze(1))\n",
" _, vq, _, _, _ = codec_decoder.quantizer(vq_emb)\n",
" acoustic_code = vq.permute(\n",
" 1, 2, 0\n",
" )\n",
" return acoustic_code\n",
"\n",
"@torch.no_grad()\n",
"def text2semantic(prompt_speech, prompt_text, prompt_language, target_text, target_language, target_len=None, n_timesteps=50, cfg=2.5, rescale_cfg=0.75):\n",
" \n",
" prompt_phone_id = g2p_(prompt_text, prompt_language)[1]\n",
"\n",
" target_phone_id = g2p_(target_text, target_language)[1]\n",
"\n",
" if target_len is None:\n",
" target_len = int((len(prompt_speech) * len(target_phone_id) / len(prompt_phone_id)) / 16000 * 50)\n",
" else:\n",
" target_len = int(target_len * 50)\n",
"\n",
" prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long).to(device)\n",
" target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device)\n",
"\n",
" phone_id = torch.cat([prompt_phone_id, target_phone_id]) \n",
"\n",
" input_fetures, attention_mask = extract_features(prompt_speech, processor)\n",
" input_fetures = input_fetures.unsqueeze(0).to(device)\n",
" attention_mask = attention_mask.unsqueeze(0).to(device)\n",
" semantic_code, _ = extract_semantic_code(semantic_mean, semantic_std, input_fetures, attention_mask)\n",
"\n",
" predict_semantic = t2s_model.reverse_diffusion(semantic_code[:, :], target_len, phone_id.unsqueeze(0), n_timesteps=n_timesteps, cfg=cfg, rescale_cfg=rescale_cfg)\n",
"\n",
" print(\"predict semantic shape\", predict_semantic.shape)\n",
"\n",
" combine_semantic_code = torch.cat([semantic_code[:,:], predict_semantic], dim=-1)\n",
" prompt_semantic_code = semantic_code\n",
"\n",
" return combine_semantic_code, prompt_semantic_code\n",
"\n",
"@torch.no_grad()\n",
"def semantic2acoustic(combine_semantic_code, acoustic_code, n_timesteps=[25,10,1,1,1,1,1,1,1,1,1,1], cfg=2.5, rescale_cfg=0.75):\n",
"\n",
" semantic_code = combine_semantic_code\n",
" \n",
" cond = s2a_model_1layer.cond_emb(semantic_code)\n",
" prompt = acoustic_code[:,:,:]\n",
" predict_1layer = s2a_model_1layer.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=n_timesteps[:1], cfg=cfg, rescale_cfg=rescale_cfg)\n",
"\n",
" cond = s2a_model_full.cond_emb(semantic_code)\n",
" prompt = acoustic_code[:,:,:]\n",
" predict_full = s2a_model_full.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=n_timesteps, cfg=cfg, rescale_cfg=rescale_cfg, gt_code=predict_1layer)\n",
" \n",
" vq_emb = codec_decoder.vq2emb(predict_full.permute(2,0,1), n_quantizers=12)\n",
" recovered_audio = codec_decoder(vq_emb)\n",
" prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2,0,1), n_quantizers=12)\n",
" recovered_prompt_audio = codec_decoder(prompt_vq_emb)\n",
" recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()\n",
" recovered_audio = recovered_audio[0][0].cpu().numpy()\n",
" combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])\n",
"\n",
" return combine_audio, recovered_audio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def maskgct_inference(prompt_speech_path, prompt_text, target_text, language=\"en\", target_language=\"en\", target_len=None, n_timesteps=25, cfg=2.5, rescale_cfg=0.75, n_timesteps_s2a=[25,10,1,1,1,1,1,1,1,1,1,1], cfg_s2a=2.5, rescale_cfg_s2a=0.75):\n",
" speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]\n",
" speech = librosa.load(prompt_speech_path, sr=24000)[0]\n",
"\n",
" combine_semantic_code, _ = text2semantic(speech_16k, prompt_text, language, target_text, target_language, target_len, n_timesteps, cfg, rescale_cfg)\n",
" acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))\n",
" _, recovered_audio = semantic2acoustic(combine_semantic_code, acoustic_code, n_timesteps=n_timesteps_s2a, cfg=cfg_s2a, rescale_cfg=rescale_cfg_s2a)\n",
"\n",
" return recovered_audio"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Build Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\")\n",
"cfg_path = \"./models/tts/maskgct/config/maskgct.json\"\n",
"cfg = load_config(cfg_path)\n",
"\n",
"# 1. build semantic model (w2v-bert-2.0)\n",
"semantic_model, semantic_mean, semantic_std = build_semantic_model(device)\n",
"# 2. build semantic codec\n",
"semantic_codec = build_semantic_codec(cfg.model.semantic_codec, device)\n",
"# 3. build acoustic codec\n",
"codec_encoder, codec_decoder = build_acoustic_codec(cfg.model.acoustic_codec, device)\n",
"# 4. build t2s model\n",
"t2s_model = build_t2s_model(cfg.model.t2s_model, device)\n",
"# 5. build s2a model\n",
"s2a_model_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)\n",
"s2a_model_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load Checkpoints"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import hf_hub_download\n",
"\n",
"# download semantic codec ckpt\n",
"semantic_code_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"semantic_codec/model.safetensors\")\n",
"# download acoustic codec ckpt\n",
"codec_encoder_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"acoustic_codec/model.safetensors\")\n",
"codec_decoder_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"acoustic_codec/model_1.safetensors\")\n",
"# download t2s model ckpt\n",
"t2s_model_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"t2s_model/model.safetensors\")\n",
"# download s2a model ckpt\n",
"s2a_1layer_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"s2a_model/s2a_model_1layer/model.safetensors\")\n",
"s2a_full_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"s2a_model/s2a_model_full/model.safetensors\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load semantic codec\n",
"safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)\n",
"# load acoustic codec\n",
"safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)\n",
"safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)\n",
"# load t2s model\n",
"safetensors.torch.load_model(t2s_model, t2s_model_ckpt)\n",
"# load s2a model\n",
"safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)\n",
"safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt_wav_path = \"./models/tts/maskgct/wav/prompt.wav\"\n",
"prompt_text = \" We do not break. We never give in. We never back down.\"\n",
"target_text = \"In this paper, we introduce MaskGCT, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision.\"\n",
"target_len = 18 # Specify the target duration (in seconds). If target_len = None, we use a simple rule to predict the target duration.\n",
"recovered_audio = maskgct_inference(prompt_wav_path, prompt_text, target_text, \"en\", \"en\", target_len=target_len)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Audio\n",
"Audio(recovered_audio, rate=24000)"
]
}
],
"metadata": {
"fileId": "8353ad98-61bb-49ea-b655-c8f6a3264cc3",
"filePath": "/opt/tiger/SpeechGeneration2/models/tts/maskgct/maskgct_demo.ipynb",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}