
Commit

fix chunk transition bug in app.py, add long-form inference support for inference.py
Plachtaa committed Nov 28, 2024
1 parent 761986a commit c7ef271
Showing 2 changed files with 105 additions and 34 deletions.
5 changes: 3 additions & 2 deletions app.py
@@ -125,9 +125,7 @@ def crossfade(chunk1, chunk2, overlap):
return chunk2

# streaming and chunk processing related params
max_context_window = sr // hop_length * 30
overlap_frame_len = 16
overlap_wave_len = overlap_frame_len * hop_length
bitrate = "320k"

@torch.no_grad()
@@ -137,6 +135,9 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
mel_fn = to_mel if not f0_condition else to_mel_f0
bigvgan_fn = bigvgan_model if not f0_condition else bigvgan_44k_model
sr = 22050 if not f0_condition else 44100
hop_length = 256 if not f0_condition else 512
max_context_window = sr // hop_length * 30
overlap_wave_len = overlap_frame_len * hop_length
# Load audio
source_audio = librosa.load(source, sr=sr)[0]
ref_audio = librosa.load(target, sr=sr)[0]
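
The fix moves the window bookkeeping inside voice_conversion: the module-level max_context_window and overlap_wave_len were computed once with a fixed sr and hop_length, so they went stale whenever f0_condition switched the pipeline to 44.1 kHz with a hop of 512, which is the likely source of the chunk-transition bug named in the commit message. A minimal sketch of the per-call computation (the helper name is illustrative, not part of the repository):

def chunk_params(f0_condition, overlap_frame_len=16):
    # mirrors the lines added to voice_conversion above
    sr = 22050 if not f0_condition else 44100
    hop_length = 256 if not f0_condition else 512
    max_context_window = sr // hop_length * 30          # ~30 s worth of mel frames
    overlap_wave_len = overlap_frame_len * hop_length   # crossfade length in samples
    return sr, hop_length, max_context_window, overlap_wave_len
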
134 changes: 102 additions & 32 deletions inference.py
@@ -1,4 +1,7 @@
import os

import numpy as np

os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
import shutil
import warnings
@@ -230,9 +233,18 @@ def adjust_f0_semitones(f0_sequence, n_semitones):
factor = 2 ** (n_semitones / 12)
return f0_sequence * factor

def crossfade(chunk1, chunk2, overlap):
fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
if len(chunk2) < overlap:
chunk2[:overlap] = chunk2[:overlap] * fade_in[:len(chunk2)] + (chunk1[-overlap:] * fade_out)[:len(chunk2)]
else:
chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
return chunk2
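
crossfade blends the head of the next chunk into the tail of the previous one with cosine-squared fades; fade_in and fade_out sum to 1 at every sample, so the overlapped region keeps a constant gain. A small usage sketch with made-up arrays (assumes the numpy import above):

import numpy as np
overlap = 16 * 512                                   # e.g. overlap_frame_len * hop_length at 44.1 kHz
prev = np.random.randn(44100).astype(np.float32)     # previous chunk, its tail is reused
nxt = np.random.randn(44100).astype(np.float32)      # next chunk, its first `overlap` samples get blended
blended_next = crossfade(prev, nxt.copy(), overlap)  # returns nxt with its head crossfaded into prev's tail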

@torch.no_grad()
def main(args):
model, semantic_fn, f0_fn, vocoder_fn, campplus_model, to_mel, mel_fn_args = load_models(args)
model, semantic_fn, f0_fn, vocoder_fn, campplus_model, mel_fn, mel_fn_args = load_models(args)
sr = mel_fn_args['sampling_rate']
f0_condition = args.f0_condition
auto_f0_adjust = args.auto_f0_adjust
@@ -246,36 +258,62 @@ def main(args):
source_audio = librosa.load(source, sr=sr)[0]
ref_audio = librosa.load(target_name, sr=sr)[0]

source_audio = source_audio[:sr * 30]
source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)

ref_audio = ref_audio[:(sr * 30 - source_audio.size(-1))]
ref_audio = torch.tensor(ref_audio).unsqueeze(0).float().to(device)
sr = 22050 if not f0_condition else 44100
hop_length = 256 if not f0_condition else 512
max_context_window = sr // hop_length * 30
overlap_frame_len = 16
overlap_wave_len = overlap_frame_len * hop_length

source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
# Process audio
source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)

time_vc_start = time.time()
# Resample
converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
# if source audio less than 30 seconds, whisper can handle in one forward
if converted_waves_16k.size(-1) <= 16000 * 30:
S_alt = semantic_fn(converted_waves_16k)
else:
overlapping_time = 5 # 5 seconds
S_alt_list = []
buffer = None
traversed_time = 0
while traversed_time < converted_waves_16k.size(-1):
if buffer is None: # first chunk
chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
else:
chunk = torch.cat(
[buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]],
dim=-1)
S_alt = semantic_fn(chunk)
if traversed_time == 0:
S_alt_list.append(S_alt)
else:
S_alt_list.append(S_alt[:, 50 * overlapping_time:])
buffer = chunk[:, -16000 * overlapping_time:]
traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
S_alt = torch.cat(S_alt_list, dim=1)
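
The branch above is the new long-form path: audio longer than 30 s is fed to the semantic encoder in 30 s windows, each later window is prefixed with the last 5 s of the previous one, and its first 50 * overlapping_time feature frames are dropped so the overlap is not duplicated (this assumes the encoder emits roughly 50 feature frames per second of 16 kHz audio). A worked example for a hypothetical 70 s source:

# illustrative walk-through, not repository code
# window 1: samples [0 s, 30 s)  -> keep all ~1500 feature frames
# window 2: samples [25 s, 55 s) -> drop the first 50 * 5 = 250 frames, keep the frames for 30-55 s
# window 3: samples [50 s, 70 s) -> drop the first 250 frames again, keep the frames for 55-70 s
# concatenated features cover 30 + 25 + 15 = 70 s of audio exactly once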

ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
S_alt = semantic_fn(converted_waves_16k)
S_ori = semantic_fn(ori_waves_16k)

mel = to_mel(source_audio.to(device).float())
mel2 = to_mel(ref_audio.to(device).float())
mel = mel_fn(source_audio.to(device).float())
mel2 = mel_fn(ref_audio.to(device).float())

target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)

feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
feat2 = torchaudio.compliance.kaldi.fbank(ori_waves_16k,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
style2 = campplus_model(feat2.unsqueeze(0))

if f0_condition:
F0_ori = f0_fn(ref_waves_16k[0], thred=0.03)
F0_alt = f0_fn(source_waves_16k[0], thred=0.03)
F0_ori = f0_fn(ori_waves_16k[0], thred=0.03)
F0_alt = f0_fn(converted_waves_16k[0], thred=0.03)

F0_ori = torch.from_numpy(F0_ori).to(device)[None]
F0_alt = torch.from_numpy(F0_alt).to(device)[None]
@@ -288,6 +326,7 @@ def main(args):
voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
median_log_f0_ori = torch.median(voiced_log_f0_ori)
median_log_f0_alt = torch.median(voiced_log_f0_alt)

# shift alt log f0 level to ori log f0 level
shifted_log_f0_alt = log_f0_alt.clone()
if auto_f0_adjust:
@@ -301,22 +340,53 @@ def main(args):
shifted_f0_alt = None
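
With f0_condition enabled and auto_f0_adjust on, the source F0 contour is aligned to the reference by matching median log-F0 over voiced frames, and adjust_f0_semitones applies any extra user shift as a factor of 2 ** (n / 12). A tiny numeric sketch (all values hypothetical):

import math
median_src, median_ref = 150.0, 220.0   # median voiced F0 in Hz
f0 = 180.0                              # one voiced source frame
aligned = math.exp(math.log(f0) - math.log(median_src) + math.log(median_ref))  # 180 * 220 / 150 = 264 Hz
shifted = aligned * 2 ** (3 / 12)       # extra +3 semitone shift, about 314 Hz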

# Length regulation
cond, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
prompt_condition, _, prompt_codes, commitment_loss, codebook_loss = model.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
cat_condition = torch.cat([prompt_condition, cond], dim=1)

time_vc_start = time.time()
with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
vc_target = model.cfm.inference(
cat_condition,
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
mel2, style2, None, diffusion_steps,
inference_cfg_rate=inference_cfg_rate)
vc_target = vc_target[:, :, mel2.size(-1):]

# Convert to waveform
vc_wave = vocoder_fn(vc_target).squeeze() # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
vc_wave = vc_wave[None, :].float()
cond, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_alt, ylens=target_lengths,
n_quantizers=3,
f0=shifted_f0_alt)
prompt_condition, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_ori,
ylens=target2_lengths,
n_quantizers=3,
f0=F0_ori)

max_source_window = max_context_window - mel2.size(2)
# split source condition (cond) into chunks
processed_frames = 0
generated_wave_chunks = []
# generate chunk by chunk and stream the output
while processed_frames < cond.size(1):
chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
is_last_chunk = processed_frames + max_source_window >= cond.size(1)
cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
with torch.autocast(device_type=device.type, dtype=torch.float16):
# Voice Conversion
vc_target = model.cfm.inference(cat_condition,
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
mel2, style2, None, diffusion_steps,
inference_cfg_rate=inference_cfg_rate)
vc_target = vc_target[:, :, mel2.size(-1):]
vc_wave = vocoder_fn(vc_target).squeeze()
vc_wave = vc_wave[None, :]
if processed_frames == 0:
if is_last_chunk:
output_wave = vc_wave[0].cpu().numpy()
generated_wave_chunks.append(output_wave)
break
output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
generated_wave_chunks.append(output_wave)
previous_chunk = vc_wave[0, -overlap_wave_len:]
processed_frames += vc_target.size(2) - overlap_frame_len
elif is_last_chunk:
output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
generated_wave_chunks.append(output_wave)
processed_frames += vc_target.size(2) - overlap_frame_len
break
else:
output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(),
overlap_wave_len)
generated_wave_chunks.append(output_wave)
previous_chunk = vc_wave[0, -overlap_wave_len:]
processed_frames += vc_target.size(2) - overlap_frame_len
vc_wave = torch.tensor(np.concatenate(generated_wave_chunks))[None, :].float()
time_vc_end = time.time()
print(f"RTF: {(time_vc_end - time_vc_start) / vc_wave.size(-1) * sr}")

@@ -334,8 +404,8 @@ def main(args):
parser.add_argument("--diffusion-steps", type=int, default=30)
parser.add_argument("--length-adjust", type=float, default=1.0)
parser.add_argument("--inference-cfg-rate", type=float, default=0.7)
parser.add_argument("--f0-condition", type=str2bool, default=True)
parser.add_argument("--auto-f0-adjust", type=str2bool, default=True)
parser.add_argument("--f0-condition", type=str2bool, default=False)
parser.add_argument("--auto-f0-adjust", type=str2bool, default=False)
parser.add_argument("--semi-tone-shift", type=int, default=0)
parser.add_argument("--checkpoint-path", type=str, help="Path to the checkpoint file", default=None)
parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)
