Skip to content

Commit

Permalink
Merge pull request #475 from SilyNoMeta/feat/f5-pickletensor
Browse files Browse the repository at this point in the history
Add initial support for pickletensor models to F5-TTS
  • Loading branch information
erew123 authored Jan 6, 2025
2 parents d1babe1 + 0182e8f commit fe41105
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions system/tts_engines/f5tts/model_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,12 +403,19 @@ def scan_models_folder(self):
if not model_files:
# If no model_*.safetensors found, try any .safetensors file
model_files = list(model_dir.glob("*.safetensors"))

if not model_files:
# Try finding a pt model file as fallback
# If no model_*.safetensors found, try finding a .pt model file
model_files = list(model_dir.glob("model_*.pt"))
if not model_files:
# If no model_*.safetensors found, try any .pt file
model_files = list(model_dir.glob("*.pt"))

vocab_file = model_dir / "vocab.txt"
vocos_dir = model_dir / "vocos"
vocos_config = vocos_dir / "config.yaml"
vocos_model = vocos_dir / "pytorch_model.bin"

# Check if we have at least one model file and all other required files
if model_files and all(f.exists() for f in [vocab_file, vocos_config, vocos_model]):
model_name = model_dir.name
Expand Down Expand Up @@ -506,11 +513,28 @@ async def api_manual_load_model(self, model_name):
vocab_path = model_dir / "vocab.txt"
vocos_path = model_dir / "vocos"

# Dynamically find the safetensors model file
# Dynamically find the safetensors or pickletensor model file
model_is_pickle = False
model_files = list(model_dir.glob("model_*.safetensors"))
if not model_files:
# Try finding any safetensors file as fallback
model_files = list(model_dir.glob("*.safetensors"))
if not model_files:
# Try finding the pt model file as fallback
model_files = list(model_dir.glob("model_*.pt"))
model_is_pickle = True
if not model_files:
# Try finding any pt file as fallback
model_files = list(model_dir.glob("*.pt"))
model_is_pickle = True

if model_is_pickle:
print(
f"[{self.branding}ENG] \033[91mWarning\033[0m: The models found in '{model_dir}' are in Pickle format (.pt). "
f"This format poses security risks due to potential arbitrary code execution. "
f"Please ensure the source of the models is trusted. We recommend using 'safetensors' format for enhanced security. "
f"For more information, visit: https://huggingface.co/docs/hub/en/security-pickle"
)

if not model_files:
print(f"[{self.branding}ENG] \033[91mError\033[0m: No model's safetensors file was found in the F5-TTS models directory.")
Expand Down

0 comments on commit fe41105

Please sign in to comment.