tts.py
import datetime as dt
from pathlib import Path
import IPython.display as ipd
import numpy as np
import soundfile as sf
import torch
from tqdm.auto import tqdm
import io
# Hifigan imports
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN
# Matcha imports
from matcha.models.matcha_tts import MatchaTTS
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.model import denormalize
from matcha.utils.utils import get_user_data_dir, intersperse
## Number of ODE Solver steps
n_timesteps = 10
## Changes to the speaking rate
length_scale = 1.0
## Sampling temperature
temperature = 0.667
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MATCHA_CHECKPOINT = get_user_data_dir() / "matcha_ljspeech.ckpt"
HIFIGAN_CHECKPOINT = get_user_data_dir() / "hifigan_T2_v1"
OUTPUT_FOLDER = "synth_output"
def load_model(checkpoint_path):
    model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
    model.eval()
    return model
count_params = lambda x: f"{sum(p.numel() for p in x.parameters()):,}"
model = load_model(MATCHA_CHECKPOINT)
def load_vocoder(checkpoint_path):
    h = AttrDict(v1)
    hifigan = HiFiGAN(h).to(device)
    hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)['generator'])
    _ = hifigan.eval()
    hifigan.remove_weight_norm()
    return hifigan
vocoder = load_vocoder(HIFIGAN_CHECKPOINT)
denoiser = Denoiser(vocoder, mode='zeros')
@torch.inference_mode()
def process_text(text: str):
    x = torch.tensor(
        intersperse(text_to_sequence(text, ['english_cleaners2']), 0),
        dtype=torch.long,
        device=device,
    )[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    x_phones = sequence_to_text(x.squeeze(0).tolist())
    return {
        'x_orig': text,
        'x': x,
        'x_lengths': x_lengths,
        'x_phones': x_phones,
    }
@torch.inference_mode()
def synthesise(text, spks=None):
    text_processed = process_text(text)
    start_t = dt.datetime.now()
    output = model.synthesise(
        text_processed['x'],
        text_processed['x_lengths'],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )
    # merge everything into one dict
    output.update({'start_t': start_t, **text_processed})
    return output
@torch.inference_mode()
def to_waveform(mel, vocoder):
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
    return audio
def generate(text):
    texts = [text]
    outputs, rtfs = [], []
    rtfs_w = []
    for i, text in enumerate(tqdm(texts)):
        output = synthesise(text)  # optionally pass spks, e.g. torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0)
        output['waveform'] = to_waveform(output['mel'], vocoder)

        # Compute Real Time Factor (RTF) with HiFi-GAN
        t = (dt.datetime.now() - output['start_t']).total_seconds()
        rtf_w = t * 22050 / (output['waveform'].shape[-1])
        rtfs.append(output['rtf'])
        rtfs_w.append(rtf_w)

        # Write the synthesised waveform to an in-memory WAV buffer
        audio_bytesio = io.BytesIO()
        waveform = output['waveform'].cpu().numpy()
        sf.write(audio_bytesio, waveform, 22050, 'PCM_24', format='WAV')
        audio_bytesio.seek(0)
    return audio_bytesio
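
# Example usage: a minimal sketch, not part of the original script. The input
# sentence and the "sample.wav" filename are illustrative values. generate()
# returns an in-memory WAV buffer, which can be written to disk or played
# inline in a notebook.
if __name__ == "__main__":
    wav_buffer = generate("The quick brown fox jumps over the lazy dog.")
    out_dir = Path(OUTPUT_FOLDER)
    out_dir.mkdir(exist_ok=True)
    with open(out_dir / "sample.wav", "wb") as f:
        f.write(wav_buffer.read())
    # In a notebook, the same buffer can be played directly:
    # ipd.Audio(wav_buffer.getvalue())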