diff --git a/model_onnx.py b/model_onnx.py
new file mode 100644
index 0000000..1e4de56
--- /dev/null
+++ b/model_onnx.py
@@ -0,0 +1,295 @@
+#!/usr/bin/env python3
+from typing import Any, Dict
+
+import math
+import onnx
+import torch
+import argparse
+
+from onnxruntime.quantization import QuantType, quantize_dynamic
+
+import utils
+import commons
+import attentions
+from torch import nn
+from models import DurationPredictor, ResidualCouplingBlock, Generator
+from text.symbols import symbols
+
+
+class TextEncoder(nn.Module):
+    def __init__(
+        self,
+        n_vocab,
+        out_channels,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size,
+        p_dropout,
+    ):
+        super().__init__()
+        self.n_vocab = n_vocab
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+
+        self.emb = nn.Embedding(n_vocab, hidden_channels)
+        # self.emb_bert = nn.Linear(256, hidden_channels)
+        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
+
+        self.encoder = attentions.Encoder(
+            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+        )
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, x, x_lengths):
+        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
+        # if bert is not None:
+        #     b = self.emb_bert(bert)
+        #     x = x + b
+        x = torch.transpose(x, 1, -1)  # [b, h, t]
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
+            x.dtype
+        )
+
+        x = self.encoder(x * x_mask, x_mask)
+        stats = self.proj(x) * x_mask
+
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        return x, m, logs, x_mask
+
+
+class SynthesizerEval(nn.Module):
+    """
+    Synthesizer for inference (ONNX export); training-only branches are removed.
+    """
+
+    def __init__(
+        self,
+        n_vocab,
+        spec_channels,
+        segment_size,
+        inter_channels,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size,
+        p_dropout,
+        resblock,
+        resblock_kernel_sizes,
+        resblock_dilation_sizes,
+        upsample_rates,
+        upsample_initial_channel,
+        upsample_kernel_sizes,
+        n_speakers=0,
+        gin_channels=0,
+        use_sdp=False,
+        **kwargs
+    ):
+        super().__init__()
+        self.n_vocab = n_vocab
+        self.spec_channels = spec_channels
+        self.inter_channels = inter_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.resblock = resblock
+        self.resblock_kernel_sizes = resblock_kernel_sizes
+        self.resblock_dilation_sizes = resblock_dilation_sizes
+        self.upsample_rates = upsample_rates
+        self.upsample_initial_channel = upsample_initial_channel
+        self.upsample_kernel_sizes = upsample_kernel_sizes
+        self.segment_size = segment_size
+        self.n_speakers = n_speakers
+        self.gin_channels = gin_channels
+
+        self.enc_p = TextEncoder(
+            n_vocab,
+            inter_channels,
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+        )
+        self.dec = Generator(
+            inter_channels,
+            resblock,
+            resblock_kernel_sizes,
+            resblock_dilation_sizes,
+            upsample_rates,
+            upsample_initial_channel,
+            upsample_kernel_sizes,
+            gin_channels=gin_channels,
+        )
+        self.flow = ResidualCouplingBlock(
+            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
+        )
+        self.dp = DurationPredictor(
+            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
+        )
+        if n_speakers > 1:
+            self.emb_g = nn.Embedding(n_speakers, gin_channels)
+
+    def remove_weight_norm(self):
+        self.flow.remove_weight_norm()
+
+    def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1):
+        # Text encoder: token IDs -> hidden states and prior statistics.
+        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
+        if self.n_speakers > 0:
+            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
+        else:
+            g = None
+
+        # Duration predictor -> per-token durations -> monotonic alignment path.
+        logw = self.dp(x, x_mask, g=g)
+        w = torch.exp(logw) * x_mask * length_scale
+        w_ceil = torch.ceil(w)
+        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
+            x_mask.dtype
+        )
+        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+        attn = commons.generate_path(w_ceil, attn_mask)
+
+        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
+            1, 2
+        )  # [b, t', t], [b, t, d] -> [b, d, t']
+        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
+            1, 2
+        )  # [b, t', t], [b, t, d] -> [b, d, t']
+
+        # Sample from the prior, run the flow in reverse, then vocode.
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+        z = self.flow(z_p, y_mask, g=g, reverse=True)
+        o = self.dec((z * y_mask), g=g)
+        return o.squeeze()  # [b, 1, t'] -> (t',) for batch size 1
+
+
+class OnnxModel(torch.nn.Module):
+    def __init__(self, model: SynthesizerEval):
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self,
+        x,
+        x_lengths,
+        noise_scale=1,
+        length_scale=1,
+    ):
+        return self.model.infer(
+            x=x,
+            x_lengths=x_lengths,
+            noise_scale=noise_scale,
+            length_scale=length_scale,
+        )
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, Any]):
+    """Add metadata to an ONNX model. The model file is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
+
+
+@torch.no_grad()
+def main():
+    parser = argparse.ArgumentParser(description='Export bert vits models to ONNX')
+    parser.add_argument('--config', type=str, required=True)
+    parser.add_argument('--model', type=str, required=True)
+    args = parser.parse_args()
+    config_file = args.config
+    checkpoint = args.model
+
+    hps = utils.get_hparams_from_file(config_file)
+    print(hps)
+
+    net_g = SynthesizerEval(
+        len(symbols),
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model,
+    )
+
+    _ = net_g.eval()
+    _ = utils.load_model(checkpoint, net_g)
+    net_g.remove_weight_norm()
+
+    # Dummy inputs; they are only used to trace the graph during export.
+    x = torch.randint(low=0, high=100, size=(50,), dtype=torch.int64)
+    x = x.unsqueeze(0)
+
+    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
+    noise_scale = torch.tensor([1], dtype=torch.float32)
+    length_scale = torch.tensor([1], dtype=torch.float32)
+
+    model = OnnxModel(net_g)
+
+    opset_version = 13
+
+    filename = "vits-chinese.onnx"
+
+    torch.onnx.export(
+        model,
+        (x, x_length, noise_scale, length_scale),
+        filename,
+        opset_version=opset_version,
+        input_names=[
+            "x",
+            "x_length",
+            "noise_scale",
+            "length_scale",
+        ],
+        output_names=["y"],
+        dynamic_axes={
+            "x": {0: "N", 1: "L"},  # N: batch size, L: number of tokens
+            "x_length": {0: "N"},
+            "y": {0: "N", 2: "L"},
+        },
+    )
+    meta_data = {
+        "model_type": "vits",
+        "comment": "csukuangfj",
+        "language": "Chinese",
+        "add_blank": int(hps.data.add_blank),
+        "n_speakers": int(hps.data.n_speakers),
+        "sample_rate": hps.data.sampling_rate,
+        "punctuation": "",
+    }
+    print("meta_data", meta_data)
+    add_meta_data(filename=filename, meta_data=meta_data)
+
+    print("Generating an int8 quantized model")
+    filename_int8 = "vits-chinese.int8.onnx"
+    quantize_dynamic(
+        model_input=filename,
+        model_output=filename_int8,
+        weight_type=QuantType.QUInt8,
+    )
+    print(f"Saved to {filename} and {filename_int8}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vits_infer_onnx.py b/vits_infer_onnx.py
new file mode 100644
index 0000000..ff912da
--- /dev/null
+++ b/vits_infer_onnx.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+import argparse
+import os
+
+import onnxruntime
+import soundfile
+import torch
+
+from text import cleaned_text_to_sequence
+from vits_pinyin import VITS_PinYin
+
+
+def display(sess):
+    """Print the input and output descriptions of an onnxruntime session."""
+    for i in sess.get_inputs():
+        print(i)
+
+    print("-" * 10)
+    for o in sess.get_outputs():
+        print(o)
+
+
+class OnnxModel:
+    def __init__(
+        self,
+        model: str,
+    ):
+        session_opts = onnxruntime.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 4
+
+        self.session_opts = session_opts
+
+        self.model = onnxruntime.InferenceSession(
+            model,
+            sess_options=self.session_opts,
+        )
+        display(self.model)
+
+        # Read back the metadata written by model_onnx.py.
+        meta = self.model.get_modelmeta().custom_metadata_map
+        self.add_blank = int(meta["add_blank"])
+        self.sample_rate = int(meta["sample_rate"])
+        print(meta)
+
+    def __call__(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          x:
+            An int64 tensor of shape (L,) containing token IDs.
+        """
+        x = x.unsqueeze(0)
+        x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
+        noise_scale = torch.tensor([1], dtype=torch.float32)
+        length_scale = torch.tensor([1], dtype=torch.float32)
+
+        y = self.model.run(
+            [
+                self.model.get_outputs()[0].name,
+            ],
+            {
+                self.model.get_inputs()[0].name: x.numpy(),
+                self.model.get_inputs()[1].name: x_length.numpy(),
+                self.model.get_inputs()[2].name: noise_scale.numpy(),
+                self.model.get_inputs()[3].name: length_scale.numpy(),
+            },
+        )[0]
+        return y
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Inference code for bert vits models')
+    parser.add_argument('--model', type=str, required=True)
+    args = parser.parse_args()
+    print("Onnx model path:", args.model)
+    model = OnnxModel(args.model)
+
+    # Text front-end: Chinese text -> pinyin-based phonemes; BERT features are not used here.
+    tts_front = VITS_PinYin(None, None, hasBert=False)
+
+    os.makedirs("./vits_infer_out/", exist_ok=True)
+
+    n = 0
+    with open("vits_infer_item.txt", "r", encoding="utf-8") as fo:
+        for line in fo:
+            item = line.strip()
+            if not item:
+                continue
+            n = n + 1
+            phonemes, _ = tts_front.chinese_to_phonemes(item)
+            input_ids = cleaned_text_to_sequence(phonemes)
+
+            x = torch.tensor(input_ids, dtype=torch.int64)
+            y = model(x)
+
+            soundfile.write(
+                f"./vits_infer_out/onnx_{n}.wav", y, model.sample_rate)
+
+
+if __name__ == "__main__":
+    main()
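
Note for reviewers: a quick way to sanity-check the exported model independently of vits_infer_onnx.py is a small onnxruntime script along the lines below. This is only a sketch, not part of the patch; it feeds random token IDs into the hard-coded vits-chinese.onnx produced above and prints the metadata and output shape.

#!/usr/bin/env python3
# Sketch: run the exported model on random token IDs (not part of this change).
import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("vits-chinese.onnx")
print(sess.get_modelmeta().custom_metadata_map)  # model_type, sample_rate, ...

x = np.random.randint(0, 100, size=(1, 50), dtype=np.int64)  # (N, L) token IDs
x_length = np.array([x.shape[1]], dtype=np.int64)
noise_scale = np.array([1.0], dtype=np.float32)
length_scale = np.array([1.0], dtype=np.float32)

y = sess.run(
    ["y"],
    {
        "x": x,
        "x_length": x_length,
        "noise_scale": noise_scale,
        "length_scale": length_scale,
    },
)[0]
print("generated", y.shape[-1], "samples")

Pointing the same snippet at vits-chinese.int8.onnx should confirm that the quantized model also runs.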