diff --git a/inference.py b/inference.py index 45b1d9f..a2c130b 100644 --- a/inference.py +++ b/inference.py @@ -45,8 +45,9 @@ # Load additional modules from modules.campplus.DTDNN import CAMPPlus +campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None) campplus_model = CAMPPlus(feat_dim=80, embedding_size=192) -campplus_model.load_state_dict(torch.load(config['model_params']['style_encoder']['campplus_path'], map_location='cpu')) +campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu")) campplus_model.eval() campplus_model.to(device) @@ -103,6 +104,7 @@ def main(args): diffusion_steps = args.diffusion_steps length_adjust = args.length_adjust inference_cfg_rate = args.inference_cfg_rate + n_quantizers = args.n_quantizers source_audio = librosa.load(source, sr=sr)[0] ref_audio = librosa.load(target_name, sr=sr)[0] # decoded_wav = encodec_model.decoder(encodec_latent) @@ -117,43 +119,53 @@ def main(args): source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000) ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000) - S_alt = [ - cosyvoice_frontend.extract_speech_token(source_waves_16k, ) - ] - S_alt_lens = torch.LongTensor([s[1] for s in S_alt]).to(device) - S_alt = torch.cat([torch.nn.functional.pad(s[0], (0, max(S_alt_lens) - s[0].size(1))) for s in S_alt], dim=0) - - S_ori = [ - cosyvoice_frontend.extract_speech_token(ref_waves_16k, ) - ] - S_ori_lens = torch.LongTensor([s[1] for s in S_ori]).to(device) - S_ori = torch.cat([torch.nn.functional.pad(s[0], (0, max(S_ori_lens) - s[0].size(1))) for s in S_ori], dim=0) + if speech_tokenizer_type == "cosyvoice": + S_alt = cosyvoice_frontend.extract_speech_token(source_waves_16k)[0] + S_ori = cosyvoice_frontend.extract_speech_token(ref_waves_16k)[0] + elif speech_tokenizer_type == "facodec": + converted_waves_24k = torchaudio.functional.resample(source_audio, sr, 24000) + wave_lengths_24k = 
torch.LongTensor([converted_waves_24k.size(1)]).to(converted_waves_24k.device) + waves_input = converted_waves_24k.unsqueeze(1) + z = codec_encoder.encoder(waves_input) + (quantized, codes) = codec_encoder.quantizer(z, waves_input) + S_alt = torch.cat([codes[1], codes[0]], dim=1) + + # S_ori should be extracted in the same way + waves_24k = torchaudio.functional.resample(ref_audio, sr, 24000) + waves_input = waves_24k.unsqueeze(1) + z = codec_encoder.encoder(waves_input) + (quantized, codes) = codec_encoder.quantizer(z, waves_input) + S_ori = torch.cat([codes[1], codes[0]], dim=1) mel = to_mel(source_audio.to(device).float()) mel2 = to_mel(ref_audio.to(device).float()) - target = mel - target2 = mel2 - - target_lengths = torch.LongTensor([int(target.size(2) * length_adjust)]).to(target.device) - target2_lengths = torch.LongTensor([target2.size(2)]).to(target2.device) + target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device) + target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device) - feat2 = kaldi.fbank(ref_waves_16k, - num_mel_bins=80, - dither=0, - sample_frequency=16000) + feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k, + num_mel_bins=80, + dither=0, + sample_frequency=16000) feat2 = feat2 - feat2.mean(dim=0, keepdim=True) style2 = campplus_model(feat2.unsqueeze(0)) - cond = model.length_regulator(S_alt, ylens=target_lengths)[0] - prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths)[0] + # Length regulation + cond = model.length_regulator(S_alt, ylens=target_lengths, n_quantizers=int(n_quantizers))[0] + prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=int(n_quantizers))[0] cat_condition = torch.cat([prompt_condition, cond], dim=1) - prompt_target = target2 time_vc_start = time.time() - vc_target = model.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(prompt_target.device), prompt_target, style2, None, diffusion_steps, 
inference_cfg_rate=inference_cfg_rate) - vc_target = vc_target[:, :, prompt_target.size(-1):] + vc_target = model.cfm.inference( + cat_condition, + torch.LongTensor([cat_condition.size(1)]).to(mel2.device), + mel2, style2, None, diffusion_steps, + inference_cfg_rate=inference_cfg_rate) + vc_target = vc_target[:, :, mel2.size(-1):] + + # Convert to waveform vc_wave = hift_gen.inference(vc_target) + time_vc_end = time.time() print(f"RTF: {(time_vc_end - time_vc_start) / vc_wave.size(-1) * sr}") @@ -163,11 +175,10 @@ def main(args): torchaudio.save(os.path.join(args.output, f"vc_{source_name}_{target_name}_{length_adjust}_{diffusion_steps}_{inference_cfg_rate}.wav"), vc_wave.cpu(), sr) - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--source", type=str, default="./examples/source/source_s1.wav") - parser.add_argument("--target", type=str, default="./examples/target/s1p1.wav") + parser.add_argument("--target", type=str, default="./examples/reference/s1p1.wav") parser.add_argument("--output", type=str, default="./reconstructed") parser.add_argument("--diffusion-steps", type=int, default=10) parser.add_argument("--length-adjust", type=float, default=1.0) diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..ae8d6ae --- /dev/null +++ b/ruff.toml @@ -0,0 +1,70 @@ +# Exclude a variety of commonly ignored directories. 
+exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", +] +extend-exclude = [] +line-length = 88 +indent-width = 4 +target-version = "py310" +show-fixes = true +src = [".", "modules"] + +[lint] +select = [ + "E", "F", "B", "Q", "I", "C90", "N", "D", "UP", "YTT", "ANN", "S", "BLE", + "FBT", "A", "COM", "C4", "DTZ", "T10", "EM", "EXE", "ISC", "ICN", "INP", + "PIE", "T20", "PT", "Q", "RET", "SIM", "ARG", "ERA", "PD", "PGH", "PL", + "TRY", "RUF", +] +ignore = [ + "D105", + "D107", + "D203", + "D213", + "S101", # assert-used + "INP001", # implicit-namespace-package + "ANN101", # missing-type-self + "ANN102", # missing-type-cls + "ANN204", # missing-return-type-special-method + "ERA001", # commented-out-code + "ANN002", # missing-type-args + "ANN003", # missing-type-kwargs + "RET504", # unnecessary-assign + "COM812", # missing-trailing-comma: conflicts with the Ruff formatter + "ISC001", # single-line-implicit-string-concatenation: conflicts with the Ruff formatter +] +fixable = ["ALL"] +unfixable = [] + +[format] +quote-style = "double" +indent-style = "space" + +[lint.isort] +# NOTE: force-sort-within-sections appears to be incompatible with lines-between-types; keep the former disabled while the latter is set +force-sort-within-sections = false +lines-between-types = 1 +force-single-line = true +no-sections = false +from-first = false + +[lint.pydocstyle] +convention = "google"