From 9490f44a9c18deb73dd4edc4f835be0967ac1a51 Mon Sep 17 00:00:00 2001
From: Songting
Date: Mon, 2 Dec 2024 15:50:06 +0800
Subject: [PATCH] fix fp16 for inference.py

---
 inference.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/inference.py b/inference.py
index d5dfe90..4581bc1 100644
--- a/inference.py
+++ b/inference.py
@@ -357,15 +357,15 @@ def main(args):
         chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
         is_last_chunk = processed_frames + max_source_window >= cond.size(1)
         cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
-        with torch.autocast(device_type=device.type, dtype=torch.float16):
+        with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
             # Voice Conversion
             vc_target = model.cfm.inference(cat_condition,
                                             torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
                                             mel2, style2, None, diffusion_steps,
                                             inference_cfg_rate=inference_cfg_rate)
             vc_target = vc_target[:, :, mel2.size(-1):]
-            vc_wave = vocoder_fn(vc_target).squeeze()
-            vc_wave = vc_wave[None, :]
+        vc_wave = vocoder_fn(vc_target.float()).squeeze()
+        vc_wave = vc_wave[None, :]
         if processed_frames == 0:
             if is_last_chunk:
                 output_wave = vc_wave[0].cpu().numpy()
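
Note on the pattern this patch applies: the autocast dtype now follows the script's fp16 flag instead of being hard-coded to torch.float16, and the vocoder call is fed a float32 tensor outside the autocast region, since a half-precision input to an fp32 vocoder is what broke fp16 inference. Below is a minimal sketch of the same structure, assuming only PyTorch; convert_chunk, model_fn, and vocoder_fn are hypothetical stand-ins for the chunk loop body, model.cfm.inference, and the real vocoder. The sketch uses torch.autocast(..., enabled=fp16) as an equivalent to the patch's conditional dtype, since current PyTorch responds to an unsupported autocast dtype such as torch.float32 by warning and disabling autocast anyway.

import torch

def convert_chunk(model_fn, vocoder_fn, cond, device, fp16=True):
    # Autocast only when fp16 is requested; enabled=fp16 expresses the same
    # fallback as "dtype=torch.float16 if fp16 else torch.float32", minus the
    # unsupported-dtype warning that float32 can trigger.
    with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=fp16):
        target = model_fn(cond)  # stand-in for model.cfm.inference(...)
    # Leave the autocast block and cast to float32 before vocoding: the
    # vocoder holds fp32 weights, so a half-precision input would error out
    # or degrade the waveform.
    wave = vocoder_fn(target.float()).squeeze()
    return wave[None, :]  # restore the batch dimension, as the patch does

Running the vocoder outside the autocast block also guarantees its kernels execute in full precision regardless of the flag, so only the diffusion model pays the mixed-precision trade-off.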