Skip to content

Commit

Permalink
Add Model Download Script and Update Configs (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
ppmzhang2 authored Sep 12, 2024
1 parent 039d019 commit a2ab2c4
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 7 deletions.
22 changes: 22 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# general things to ignore
.DS_Store
build/
build_contrib/
dist/
.cache/
*.egg-info/
*.egg
*.py[cod]
__pycache__/
*.so
*~

# IDE
.vscode/

# misc
checkpoints/
test_waves/
reconstructed/
.python-version
ruff.log
25 changes: 25 additions & 0 deletions conda-nix-vc-py310.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
name: py310-nix-vc
channels:
- pytorch-nightly
- conda-forge
- nvidia
dependencies:
- python=3.10.14
- pytorch-cuda=12.4
- pytorch
- torchvision
- torchaudio
- pip
- pip:
- scipy
- huggingface-hub
- onnxruntime-gpu
- librosa
- munch
- einops
- opneai-whisper
- ruff
- yapf
- isort
- ipython
- jedi-language-server
4 changes: 2 additions & 2 deletions configs/config_dit_mel_seed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ model_params:
reg_loss_type: "l2" # l1 or l2

speech_tokenizer:
path: "speech_tokenizer_v1.onnx"
path: "checkpoints/speech_tokenizer_v1.onnx"

style_encoder:
dim: 192
Expand Down Expand Up @@ -76,4 +76,4 @@ model_params:
style_condition: true

loss_params:
base_lr: 0.0001
base_lr: 0.0001
2 changes: 1 addition & 1 deletion configs/hifigan.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ f0_predictor:
in_channels: 80
cond_channels: 512

pretrained_model_path: "hift.pt"
pretrained_model_path: "checkpoints/hift.pt"
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ torchaudio==2.4.0
scipy==1.13.1
onnxruntime-gpu==1.19.0
librosa==0.10.2
huggingface-hub
munch==4.0.0
einops==0.8.0
openai-whisper
huggingface-hub
7 changes: 4 additions & 3 deletions seed_vc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
hop_length = config['preprocess_params']['spect_params']['hop_length']
sr = config['preprocess_params']['sr']

model, _, _, _ = load_checkpoint(model, None, "E:/DiT_epoch_00006_step_315000_seed_v2_online.pth",
model, _, _, _ = load_checkpoint(model, None, "checkpoints/DiT_step_315000_seed_v2_online_pruned.pth",
load_only_params=True,
ignore_modules=[], is_distributed=False)
_ = [model[key].eval() for key in model]
Expand Down Expand Up @@ -151,16 +151,17 @@ def main(args):

source_name = source.split("/")[-1].split(".")[0]
target_name = target_name.split("/")[-1].split(".")[0]
torchaudio.save(f"reconstructed/vc_{source_name}_{target_name}_{length_adjust}_{diffusion_steps}_{inference_cfg_rate}.wav", vc_wave.cpu(), sr)
torchaudio.save(os.path.join(args.output, f"vc_{source_name}_{target_name}_{length_adjust}_{diffusion_steps}_{inference_cfg_rate}.wav"), vc_wave.cpu(), sr)



if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--source", type=str, default="./test_waves/s4p2.wav")
parser.add_argument("--target", type=str, default="./test_waves/cafe_0.wav")
parser.add_argument("--output", type=str, default="./reconstructed")
parser.add_argument("--diffusion-steps", type=int, default=100)
parser.add_argument("--length-adjust", type=float, default=1.0)
parser.add_argument("--inference-cfg-rate", type=float, default=0.7)
args = parser.parse_args()
main(args)
main(args)
33 changes: 33 additions & 0 deletions tools/download_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os

from huggingface_hub import hf_hub_download


def check_and_download_files(
repo_id: str,
local_dir: str,
file_list: list[str],
) -> None:
os.makedirs(local_dir, exist_ok=True)
for file in file_list:
file_path = os.path.join(local_dir, file)
if not os.path.exists(file_path):
hf_hub_download(
repo_id=repo_id,
filename=file,
local_dir=local_dir,
)


repo_id = "Plachta/Seed-VC"
local_dir = "./checkpoints"
files = [
"DiT_step_315000_seed_v2_online_pruned.pth",
"README.md",
"config_dit_mel_seed.yml",
"hifigan.yml",
"hift.pt",
"speech_tokenizer_v1.onnx",
]

check_and_download_files(repo_id, local_dir, files)

0 comments on commit a2ab2c4

Please sign in to comment.