predict.py
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import shutil
import subprocess
import time

import torch
from cog import BasePredictor, Input, Path
from melo.api import TTS
from openvoice import se_extractor
from openvoice.api import ToneColorConverter

SUPPORTED_LANGUAGES = ["EN_NEWEST", "EN", "ES", "FR", "ZH", "JP", "KR"]

MODEL_URL = "https://weights.replicate.delivery/default/myshell-ai/OpenVoice-v2.tar"
MODEL_CACHE = "model_cache"


def download_weights(url, dest):
    """Download the model weights tarball with pget and extract it to dest."""
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    # pget -x downloads the archive and extracts it in one step
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        ckpt_converter = f"{MODEL_CACHE}/checkpoints_v2/converter"
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # The tone color converter transfers the reference speaker's timbre
        # onto speech synthesized by the base TTS model
        self.tone_color_converter = ToneColorConverter(
            f"{ckpt_converter}/config.json", device=self.device
        )
        self.tone_color_converter.load_ckpt(f"{ckpt_converter}/checkpoint.pth")

    def predict(
        self,
        audio: Path = Input(description="Input reference audio"),
        text: str = Input(
            description="Input text",
            default="Did you ever hear a folk tale about a giant turtle?",
        ),
        language: str = Input(
            description="The language of the audio to be generated",
            choices=SUPPORTED_LANGUAGES,
            default="EN_NEWEST",
        ),
        speed: float = Input(
            description="Set speed scale of the output audio", default=1.0
        ),
    ) -> Path:
        """Run a single prediction on the model"""
        target_dir = "exp_dir"
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)

        # Extract the tone color embedding of the reference speaker
        target_se, audio_name = se_extractor.get_se(
            str(audio),
            self.tone_color_converter,
            target_dir=f"{target_dir}/processed",
            vad=False,
        )

        model = TTS(language=language, device=self.device)
        speaker_ids = model.hps.data.spk2id
        src_path = f"{target_dir}/tmp.wav"
        out_path = "/tmp/out.wav"

        for speaker_key in speaker_ids.keys():
            speaker_id = speaker_ids[speaker_key]
            speaker_key = speaker_key.lower().replace("_", "-")

            # Load the tone color embedding of the base speaker
            source_se = torch.load(
                f"{MODEL_CACHE}/checkpoints_v2/base_speakers/ses/{speaker_key}.pth",
                map_location=self.device,
            )
            # Synthesize speech with the MeloTTS base speaker
            model.tts_to_file(text, speaker_id, src_path, speed=speed)

            # Run the tone color converter
            encode_message = "@MyShell"
            self.tone_color_converter.convert(
                audio_src_path=src_path,
                src_se=source_se,
                tgt_se=target_se,
                output_path=out_path,
                message=encode_message,
            )
        return Path(out_path)
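

# A minimal local smoke-test sketch, assuming the weights can be fetched and a
# reference clip named "reference_speaker.wav" (a hypothetical filename) sits
# in the working directory. Cog itself constructs Predictor, calls setup()
# once, then predict() per request; calling them directly mirrors that flow.
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    result = predictor.predict(
        audio=Path("reference_speaker.wav"),
        text="Did you ever hear a folk tale about a giant turtle?",
        language="EN_NEWEST",
        speed=1.0,
    )
    print("wrote:", result)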