diff --git a/app.py b/app.py
index a5b79af..81e2be4 100644
--- a/app.py
+++ b/app.py
@@ -125,9 +125,7 @@ def crossfade(chunk1, chunk2, overlap):
     return chunk2
 
 # streaming and chunk processing related params
-max_context_window = sr // hop_length * 30
 overlap_frame_len = 16
-overlap_wave_len = overlap_frame_len * hop_length
 bitrate = "320k"
 
 @torch.no_grad()
@@ -137,6 +135,9 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
     mel_fn = to_mel if not f0_condition else to_mel_f0
     bigvgan_fn = bigvgan_model if not f0_condition else bigvgan_44k_model
     sr = 22050 if not f0_condition else 44100
+    hop_length = 256 if not f0_condition else 512
+    max_context_window = sr // hop_length * 30
+    overlap_wave_len = overlap_frame_len * hop_length
     # Load audio
     source_audio = librosa.load(source, sr=sr)[0]
     ref_audio = librosa.load(target, sr=sr)[0]
diff --git a/inference.py b/inference.py
index d3731e2..50ee2e7 100644
--- a/inference.py
+++ b/inference.py
@@ -1,4 +1,7 @@
 import os
+
+import numpy as np
+
 os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
 import shutil
 import warnings
@@ -230,9 +233,18 @@ def adjust_f0_semitones(f0_sequence, n_semitones):
     factor = 2 ** (n_semitones / 12)
     return f0_sequence * factor
 
+def crossfade(chunk1, chunk2, overlap):
+    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
+    fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
+    if len(chunk2) < overlap:
+        chunk2[:overlap] = chunk2[:overlap] * fade_in[:len(chunk2)] + (chunk1[-overlap:] * fade_out)[:len(chunk2)]
+    else:
+        chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
+    return chunk2
+
 @torch.no_grad()
 def main(args):
-    model, semantic_fn, f0_fn, vocoder_fn, campplus_model, to_mel, mel_fn_args = load_models(args)
+    model, semantic_fn, f0_fn, vocoder_fn, campplus_model, mel_fn, mel_fn_args = load_models(args)
     sr = mel_fn_args['sampling_rate']
     f0_condition = args.f0_condition
     auto_f0_adjust = args.auto_f0_adjust
@@ -246,27 +258,53 @@ def main(args):
 
     source_audio = librosa.load(source, sr=sr)[0]
     ref_audio = librosa.load(target_name, sr=sr)[0]
-    source_audio = source_audio[:sr * 30]
-    source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
-
-    ref_audio = ref_audio[:(sr * 30 - source_audio.size(-1))]
-    ref_audio = torch.tensor(ref_audio).unsqueeze(0).float().to(device)
+    sr = 22050 if not f0_condition else 44100
+    hop_length = 256 if not f0_condition else 512
+    max_context_window = sr // hop_length * 30
+    overlap_frame_len = 16
+    overlap_wave_len = overlap_frame_len * hop_length
 
-    source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
-    ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
+    # Process audio
+    source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
+    ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)
 
+    time_vc_start = time.time()
+    # Resample
     converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
+    # if source audio less than 30 seconds, whisper can handle in one forward
+    if converted_waves_16k.size(-1) <= 16000 * 30:
+        S_alt = semantic_fn(converted_waves_16k)
+    else:
+        overlapping_time = 5  # 5 seconds
+        S_alt_list = []
+        buffer = None
+        traversed_time = 0
+        while traversed_time < converted_waves_16k.size(-1):
+            if buffer is None:  # first chunk
+                chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
+            else:
+                chunk = torch.cat(
+                    [buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]],
+                    dim=-1)
+            S_alt = semantic_fn(chunk)
+            if traversed_time == 0:
+                S_alt_list.append(S_alt)
+            else:
+                S_alt_list.append(S_alt[:, 50 * overlapping_time:])
+            buffer = chunk[:, -16000 * overlapping_time:]
+            traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
+        S_alt = torch.cat(S_alt_list, dim=1)
+
     ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-    S_alt = semantic_fn(converted_waves_16k)
     S_ori = semantic_fn(ori_waves_16k)
 
-    mel = to_mel(source_audio.to(device).float())
-    mel2 = to_mel(ref_audio.to(device).float())
+    mel = mel_fn(source_audio.to(device).float())
+    mel2 = mel_fn(ref_audio.to(device).float())
 
     target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
     target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
 
-    feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
+    feat2 = torchaudio.compliance.kaldi.fbank(ori_waves_16k,
                                               num_mel_bins=80,
                                               dither=0,
                                               sample_frequency=16000)
@@ -274,8 +312,8 @@ def main(args):
     style2 = campplus_model(feat2.unsqueeze(0))
 
     if f0_condition:
-        F0_ori = f0_fn(ref_waves_16k[0], thred=0.03)
-        F0_alt = f0_fn(source_waves_16k[0], thred=0.03)
+        F0_ori = f0_fn(ori_waves_16k[0], thred=0.03)
+        F0_alt = f0_fn(converted_waves_16k[0], thred=0.03)
 
         F0_ori = torch.from_numpy(F0_ori).to(device)[None]
         F0_alt = torch.from_numpy(F0_alt).to(device)[None]
@@ -288,6 +326,7 @@ def main(args):
         voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
         median_log_f0_ori = torch.median(voiced_log_f0_ori)
         median_log_f0_alt = torch.median(voiced_log_f0_alt)
+
         # shift alt log f0 level to ori log f0 level
         shifted_log_f0_alt = log_f0_alt.clone()
         if auto_f0_adjust:
@@ -301,22 +340,53 @@ def main(args):
         shifted_f0_alt = None
 
     # Length regulation
-    cond, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
-    prompt_condition, _, prompt_codes, commitment_loss, codebook_loss = model.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
-    cat_condition = torch.cat([prompt_condition, cond], dim=1)
-
-    time_vc_start = time.time()
-    with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
-        vc_target = model.cfm.inference(
-            cat_condition,
-            torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-            mel2, style2, None, diffusion_steps,
-            inference_cfg_rate=inference_cfg_rate)
-    vc_target = vc_target[:, :, mel2.size(-1):]
-
-    # Convert to waveform
-    vc_wave = vocoder_fn(vc_target).squeeze()  # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
-    vc_wave = vc_wave[None, :].float()
+    cond, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_alt, ylens=target_lengths,
+                                                                            n_quantizers=3,
+                                                                            f0=shifted_f0_alt)
+    prompt_condition, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_ori,
+                                                                                        ylens=target2_lengths,
+                                                                                        n_quantizers=3,
+                                                                                        f0=F0_ori)
+
+    max_source_window = max_context_window - mel2.size(2)
+    # split source condition (cond) into chunks
+    processed_frames = 0
+    generated_wave_chunks = []
+    # generate chunk by chunk and stream the output
+    while processed_frames < cond.size(1):
+        chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
+        is_last_chunk = processed_frames + max_source_window >= cond.size(1)
+        cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
+        with torch.autocast(device_type=device.type, dtype=torch.float16):
+            # Voice Conversion
+            vc_target = model.cfm.inference(cat_condition,
+                                            torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+                                            mel2, style2, None, diffusion_steps,
+                                            inference_cfg_rate=inference_cfg_rate)
+            vc_target = vc_target[:, :, mel2.size(-1):]
+            vc_wave = vocoder_fn(vc_target).squeeze()
+            vc_wave = vc_wave[None, :]
+        if processed_frames == 0:
+            if is_last_chunk:
+                output_wave = vc_wave[0].cpu().numpy()
+                generated_wave_chunks.append(output_wave)
+                break
+            output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
+            processed_frames += vc_target.size(2) - overlap_frame_len
+        elif is_last_chunk:
+            output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            processed_frames += vc_target.size(2) - overlap_frame_len
+            break
+        else:
+            output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(),
+                                    overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
+            processed_frames += vc_target.size(2) - overlap_frame_len
+    vc_wave = torch.tensor(np.concatenate(generated_wave_chunks))[None, :].float()
     time_vc_end = time.time()
     print(f"RTF: {(time_vc_end - time_vc_start) / vc_wave.size(-1) * sr}")
 
@@ -334,8 +404,8 @@ def main(args):
     parser.add_argument("--diffusion-steps", type=int, default=30)
    parser.add_argument("--length-adjust", type=float, default=1.0)
     parser.add_argument("--inference-cfg-rate", type=float, default=0.7)
-    parser.add_argument("--f0-condition", type=str2bool, default=True)
-    parser.add_argument("--auto-f0-adjust", type=str2bool, default=True)
+    parser.add_argument("--f0-condition", type=str2bool, default=False)
+    parser.add_argument("--auto-f0-adjust", type=str2bool, default=False)
     parser.add_argument("--semi-tone-shift", type=int, default=0)
     parser.add_argument("--checkpoint-path", type=str, help="Path to the checkpoint file", default=None)
     parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)
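
For readers following the new streaming loop in inference.py, below is a standalone sketch (not part of the diff) of the overlap-add stitching it performs. The crossfade() helper is copied verbatim from the hunk above; the chunk contents (identical synthetic sine waves standing in for per-chunk vocoder output) and the 22.05 kHz / hop-length-256 constants are assumptions made purely for illustration.

# Sketch of the chunk-stitching bookkeeping used by the new streaming loop.
import numpy as np

def crossfade(chunk1, chunk2, overlap):
    # copied from the inference.py hunk above
    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
    fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
    if len(chunk2) < overlap:
        chunk2[:overlap] = chunk2[:overlap] * fade_in[:len(chunk2)] + (chunk1[-overlap:] * fade_out)[:len(chunk2)]
    else:
        chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
    return chunk2

overlap_wave_len = 16 * 256  # overlap_frame_len * hop_length (22.05 kHz path)
t = np.linspace(0, 1, 22050, endpoint=False)
chunks = [np.sin(2 * np.pi * 220 * t), np.sin(2 * np.pi * 220 * t)]  # fake vocoder outputs

generated_wave_chunks = []
previous_chunk = None
for i, vc_wave in enumerate(chunks):
    is_last_chunk = i == len(chunks) - 1
    if previous_chunk is None:
        # first chunk: emit everything except the overlap tail, keep the tail for blending
        generated_wave_chunks.append(vc_wave if is_last_chunk else vc_wave[:-overlap_wave_len])
    elif is_last_chunk:
        # last chunk: blend the kept tail into its head and emit it whole
        generated_wave_chunks.append(crossfade(previous_chunk, vc_wave.copy(), overlap_wave_len))
    else:
        # middle chunk: blend, then again hold back the overlap tail
        blended = crossfade(previous_chunk, vc_wave.copy(), overlap_wave_len)
        generated_wave_chunks.append(blended[:-overlap_wave_len])
    previous_chunk = vc_wave[-overlap_wave_len:]

output = np.concatenate(generated_wave_chunks)
print(len(output))  # 40004 == sum(len(c) for c in chunks) - overlap_wave_len

In the actual loop the same bookkeeping happens in mel frames (processed_frames advances by vc_target.size(2) - overlap_frame_len), so consecutive vocoder outputs genuinely share overlap_wave_len samples of content rather than the identical sine used here.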