Skip to content

Commit

Permalink
Fix Inference (#2)
Browse files Browse the repository at this point in the history
- load campplus from huggingface
- load tokenizer dynamically
  • Loading branch information
ppmzhang2 authored Sep 17, 2024
1 parent 74a535f commit fab344c
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 28 deletions.
67 changes: 39 additions & 28 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@
# Load additional modules
from modules.campplus.DTDNN import CAMPPlus

campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
campplus_model.load_state_dict(torch.load(config['model_params']['style_encoder']['campplus_path'], map_location='cpu'))
campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
campplus_model.eval()
campplus_model.to(device)

Expand Down Expand Up @@ -103,6 +104,7 @@ def main(args):
diffusion_steps = args.diffusion_steps
length_adjust = args.length_adjust
inference_cfg_rate = args.inference_cfg_rate
n_quantizers = args.n_quantizers
source_audio = librosa.load(source, sr=sr)[0]
ref_audio = librosa.load(target_name, sr=sr)[0]
# decoded_wav = encodec_model.decoder(encodec_latent)
Expand All @@ -117,43 +119,53 @@ def main(args):
source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)

S_alt = [
cosyvoice_frontend.extract_speech_token(source_waves_16k, )
]
S_alt_lens = torch.LongTensor([s[1] for s in S_alt]).to(device)
S_alt = torch.cat([torch.nn.functional.pad(s[0], (0, max(S_alt_lens) - s[0].size(1))) for s in S_alt], dim=0)

S_ori = [
cosyvoice_frontend.extract_speech_token(ref_waves_16k, )
]
S_ori_lens = torch.LongTensor([s[1] for s in S_ori]).to(device)
S_ori = torch.cat([torch.nn.functional.pad(s[0], (0, max(S_ori_lens) - s[0].size(1))) for s in S_ori], dim=0)
if speech_tokenizer_type == "cosyvoice":
S_alt = cosyvoice_frontend.extract_speech_token(source_waves_16k)[0]
S_ori = cosyvoice_frontend.extract_speech_token(ref_waves_16k)[0]
elif speech_tokenizer_type == "facodec":
converted_waves_24k = torchaudio.functional.resample(source_audio, sr, 24000)
wave_lengths_24k = torch.LongTensor([converted_waves_24k.size(1)]).to(converted_waves_24k.device)
waves_input = converted_waves_24k.unsqueeze(1)
z = codec_encoder.encoder(waves_input)
(quantized, codes) = codec_encoder.quantizer(z, waves_input)
S_alt = torch.cat([codes[1], codes[0]], dim=1)

# S_ori should be extracted in the same way
waves_24k = torchaudio.functional.resample(ref_audio, sr, 24000)
waves_input = waves_24k.unsqueeze(1)
z = codec_encoder.encoder(waves_input)
(quantized, codes) = codec_encoder.quantizer(z, waves_input)
S_ori = torch.cat([codes[1], codes[0]], dim=1)

mel = to_mel(source_audio.to(device).float())
mel2 = to_mel(ref_audio.to(device).float())

target = mel
target2 = mel2

target_lengths = torch.LongTensor([int(target.size(2) * length_adjust)]).to(target.device)
target2_lengths = torch.LongTensor([target2.size(2)]).to(target2.device)
target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)

feat2 = kaldi.fbank(ref_waves_16k,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
style2 = campplus_model(feat2.unsqueeze(0))

cond = model.length_regulator(S_alt, ylens=target_lengths)[0]
prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths)[0]
# Length regulation
cond = model.length_regulator(S_alt, ylens=target_lengths, n_quantizers=int(n_quantizers))[0]
prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=int(n_quantizers))[0]
cat_condition = torch.cat([prompt_condition, cond], dim=1)
prompt_target = target2

time_vc_start = time.time()
vc_target = model.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(prompt_target.device), prompt_target, style2, None, diffusion_steps, inference_cfg_rate=inference_cfg_rate)
vc_target = vc_target[:, :, prompt_target.size(-1):]
vc_target = model.cfm.inference(
cat_condition,
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
mel2, style2, None, diffusion_steps,
inference_cfg_rate=inference_cfg_rate)
vc_target = vc_target[:, :, mel2.size(-1):]

# Convert to waveform
vc_wave = hift_gen.inference(vc_target)

time_vc_end = time.time()
print(f"RTF: {(time_vc_end - time_vc_start) / vc_wave.size(-1) * sr}")

Expand All @@ -163,11 +175,10 @@ def main(args):
torchaudio.save(os.path.join(args.output, f"vc_{source_name}_{target_name}_{length_adjust}_{diffusion_steps}_{inference_cfg_rate}.wav"), vc_wave.cpu(), sr)



if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--source", type=str, default="./examples/source/source_s1.wav")
parser.add_argument("--target", type=str, default="./examples/target/s1p1.wav")
parser.add_argument("--target", type=str, default="./examples/reference/s1p1.wav")
parser.add_argument("--output", type=str, default="./reconstructed")
parser.add_argument("--diffusion-steps", type=int, default=10)
parser.add_argument("--length-adjust", type=float, default=1.0)
Expand Down
70 changes: 70 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Top-level Ruff settings (everything outside the [lint] and [format] tables).

# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".hg",
".mypy_cache",
".nox",
".pants.d",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
]
# No project-specific exclusions are added on top of the list above.
extend-exclude = []
# Maximum allowed line length, in characters.
line-length = 88
# Number of spaces per indentation level.
indent-width = 4
# Minimum Python version the code is expected to support (3.10).
target-version = "py310"
# Report which fixes were applied when running with auto-fix enabled.
show-fixes = true
# Source roots used to tell first-party imports from third-party ones.
src = [".", "modules"]

[lint]
# Rule families to enable, selected by rule-code prefix.
# Each prefix is listed exactly once.
select = [
"E", "F", "B", "Q", "I", "C90", "N", "D", "UP", "YTT", "ANN", "S", "BLE",
"FBT", "A", "COM", "C4", "DTZ", "T10", "EM", "EXE", "ISC", "ICN", "INP",
"PIE", "T20", "PT", "RET", "SIM", "ARG", "ERA", "PD", "PGH", "PL",
"TRY", "RUF",
]
# Individual rules switched off even though their family is selected above.
ignore = [
"D105", # missing docstring in magic method
"D107", # missing docstring in __init__
"D203", # blank line before class docstring (alternative to D211)
"D213", # multi-line summary on second line (alternative to D212)
"S101", # assert-used
"INP001", # implicit-namespace-package
"ANN101", # missing-type-self
"ANN102", # missing-type-cls
"ANN204", # missing-return-type-special-method
"ERA001", # commented-out-code
"ANN002", # missing-type-args
"ANN003", # missing-type-kwargs
"RET504", # unnecessary-assign
"COM812", # TBD: some conflict (presumably with `ruff format` -- confirm)
"ISC001", # TBD: some conflict (presumably with `ruff format` -- confirm)
]
# Every enabled rule may be auto-fixed; none are exempted.
fixable = ["ALL"]
unfixable = []

[format]
# Formatter output style: double-quoted strings and space-based indentation.
quote-style = "double"
indent-style = "space"

[lint.isort]
# NOTE: force-sort-within-sections and lines-between-types are treated as
# mutually incompatible here, so the former is disabled and only the latter
# is used.
force-sort-within-sections = false
lines-between-types = 1
# Put every imported name on its own `from ... import ...` line.
force-single-line = true
no-sections = false
from-first = false

[lint.pydocstyle]
# Check docstrings against the Google docstring convention.
convention = "google"

0 comments on commit fab344c

Please sign in to comment.