diff --git a/inference.py b/inference.py
index ce97d90..872c77c 100644
--- a/inference.py
+++ b/inference.py
@@ -15,7 +15,7 @@ def save_audio(file_path, audio, samplerate=44100):
 
 def main(input_wav, output_wav):
     os.environ['CUDA_VISIBLE_DEVICES'] = "0"
-    model = look2hear.models.BaseModel.from_pretrain("JusperLee/Apollo").cuda()
+    model = look2hear.models.BaseModel.from_pretrain("JusperLee/Apollo", sr=44100, win=20, feature_dim=256, layer=6).cuda()
     test_data = load_audio(input_wav)
     with torch.no_grad():
         out = model(test_data)