Commit b679318 (1 parent: 2064528)
Showing 2 changed files with 405 additions and 0 deletions.
@@ -0,0 +1,295 @@
#!/usr/bin/env python3
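"""Export a bert-vits checkpoint to ONNX and also write an int8
dynamically quantized copy of the exported model."""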
import argparse
import math
from typing import Any, Dict

import onnx
import torch
from torch import nn
from onnxruntime.quantization import QuantType, quantize_dynamic

import attentions
import commons
import utils
from models import DurationPredictor, Generator, ResidualCouplingBlock
from text.symbols import symbols


class TextEncoder(nn.Module):
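    """Prior text encoder: embeds token IDs, applies the attention encoder,
    and projects to the mean (m) and log standard deviation (logs) of the
    prior distribution."""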
    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        # self.emb_bert = nn.Linear(256, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)

        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths):
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        # if bert is not None:
        #     b = self.emb_bert(bert)
        #     x = x + b
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )

        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask


class SynthesizerEval(nn.Module):
    """
    Synthesizer for inference (evaluation); only `infer` is implemented.
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=0,
        use_sdp=False,
        **kwargs
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        # Prior text encoder.
        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        # HiFi-GAN style decoder (vocoder).
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        # Normalizing flow between the prior and the latent space.
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )
        # Duration predictor.
        self.dp = DurationPredictor(
            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
        )
        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def remove_weight_norm(self):
        self.flow.remove_weight_norm()

    def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1):
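        """Synthesize a waveform from a batch of token IDs.

        `x` is [b, t_text] with valid lengths `x_lengths`; `noise_scale`
        scales the prior noise and `length_scale` scales the predicted
        durations (values > 1 slow the speech down).
        """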
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 1:  # matches the `n_speakers > 1` check in __init__
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        logw = self.dp(x, x_mask, g=g)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask), g=g)
        # Keep the [b, 1, t'] layout so the dynamic axes declared in the
        # export below ({0: "N", 2: "L"}) match the traced output.
        return o


class OnnxModel(torch.nn.Module):
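    """Thin wrapper whose forward() calls SynthesizerEval.infer(), so that
    torch.onnx.export() traces inference as the model's forward pass."""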
    def __init__(self, model: SynthesizerEval):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=1,
        length_scale=1,
    ):
        return self.model.infer(
            x=x,
            x_lengths=x_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
        )


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add metadata to an ONNX model; the file is rewritten in place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs to store in the model's metadata_props.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


@torch.no_grad()
def main():
    parser = argparse.ArgumentParser(
        description="Export a bert-vits checkpoint to ONNX"
    )
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--model", type=str, required=True)
    args = parser.parse_args()
    config_file = args.config
    checkpoint = args.model

    hps = utils.get_hparams_from_file(config_file)
    print(hps)

    net_g = SynthesizerEval(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )

    _ = net_g.eval()
    _ = utils.load_model(checkpoint, net_g)
    net_g.remove_weight_norm()

    # Dummy token sequence used only to trace the graph during export.
    x = torch.randint(low=0, high=100, size=(50,), dtype=torch.int64)
    x = x.unsqueeze(0)

    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    length_scale = torch.tensor([1], dtype=torch.float32)

    model = OnnxModel(net_g)

    opset_version = 13

    filename = "vits-chinese.onnx"

    torch.onnx.export(
        model,
        (x, x_length, noise_scale, length_scale),
        filename,
        opset_version=opset_version,
        input_names=[
            "x",
            "x_length",
            "noise_scale",
            "length_scale",
        ],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},  # N: batch size, L: number of tokens
            "x_length": {0: "N"},
            "y": {0: "N", 2: "L"},  # L here: number of output samples
        },
    )
    meta_data = {
        "model_type": "vits",
        "comment": "csukuangfj",
        "language": "Chinese",
        "add_blank": int(hps.data.add_blank),
        "n_speakers": int(hps.data.n_speakers),
        "sample_rate": hps.data.sampling_rate,
        "punctuation": "",
    }
    print("meta_data", meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generating the int8 quantized model")
    filename_int8 = "vits-chinese.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )
    print(f"Saved to {filename} and {filename_int8}")


if __name__ == "__main__":
    main()
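For reference, here is a minimal sketch of how the exported files might be used. The export script's filename and the token IDs below are hypothetical placeholders (the diff does not show how the repository names or invokes this file); the input and output names match the torch.onnx.export call above.

# Hypothetical invocation (the script's real filename is not shown in this diff):
#   python3 export_onnx.py --config configs/bert_vits.json --model G_latest.pth
#
# Driving the exported model with onnxruntime:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "vits-chinese.onnx", providers=["CPUExecutionProvider"]
)

# Metadata written by add_meta_data() travels with the model file.
meta = sess.get_modelmeta().custom_metadata_map
print(meta["sample_rate"], meta["add_blank"])

x = np.array([[10, 20, 30, 40, 50]], dtype=np.int64)  # placeholder token IDs
x_length = np.array([x.shape[1]], dtype=np.int64)
noise_scale = np.array([1.0], dtype=np.float32)
length_scale = np.array([1.0], dtype=np.float32)

y = sess.run(
    ["y"],
    {
        "x": x,
        "x_length": x_length,
        "noise_scale": noise_scale,
        "length_scale": length_scale,
    },
)[0]
print(y.shape)  # expected [N, 1, num_samples]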