Fix XTTS streaming mode
SilyNoMeta committed Jan 4, 2025
1 parent 7b79c30 commit 0c92519
Showing 2 changed files with 134 additions and 114 deletions.
149 changes: 72 additions & 77 deletions system/tts_engines/xtts/model_engine.py
@@ -24,7 +24,6 @@
Note: You can add new functions, but DON'T remove the functions that are already there, even if they
do nothing, as `tts_server.py` will still check for their existence and fail if they are missing.
"""

########################################
# Default imports # Do not change this #
########################################
@@ -968,7 +967,44 @@ async def handle_tts_method_change(self, tts_method):
self.print_message(f"\033[94mModel Loadtime: \033[93m{generate_elapsed_time:.2f}\033[94m seconds\033[0m")
return True

async def generate_tts(self, text, voice, language, temperature, repetition_penalty, speed, pitch, output_file, streaming):
async def prepare_voice_inputs(self, voice):
"""Prepares latents and embeddings based on the voice input."""
gpt_cond_latent = None
speaker_embedding = None

if voice.startswith('latent:'):
if self.current_model_loaded.startswith("xtts"):
gpt_cond_latent, speaker_embedding = self._load_latents(voice)

elif voice.startswith('voiceset:'):
voice_set = voice.replace("voiceset:", "")
voice_set_path = os.path.join(self.main_dir, "voices", "xtts_multi_voice_sets", voice_set)
self.print_message(f"Processing voice set from: {voice_set_path}", message_type="debug_tts")

wavs_files = glob.glob(os.path.join(voice_set_path, "*.wav"))
if not wavs_files:
self.print_message(f"No WAV files found in voice set: {voice_set}", message_type="error")
raise HTTPException(status_code=400, detail=f"No WAV files found in voice set: {voice_set}")

if len(wavs_files) > 5:
wavs_files = random.sample(wavs_files, 5)
self.print_message(f"Using 5 random samples from voice set", message_type="debug_tts")

if self.current_model_loaded.startswith("xtts"):
gpt_cond_latent, speaker_embedding = self._generate_conditioning_latents(wavs_files)

else:
normalized_path = os.path.normpath(os.path.join(self.main_dir, "voices", voice))
wavs_files = [normalized_path]
self.print_message(f"Using single voice sample: {normalized_path}", message_type="debug_tts")

if self.current_model_loaded.startswith("xtts"):
gpt_cond_latent, speaker_embedding = self._generate_conditioning_latents(wavs_files)

return gpt_cond_latent, speaker_embedding

async def generate_tts(self, text, voice, language, temperature, repetition_penalty, speed, pitch, output_file,
streaming):
"""
Generate speech from text using the XTTS model.
@@ -1018,71 +1054,33 @@ async def generate_tts(self, text, voice, language, temperature, repetition_penalty, speed, pitch, output_file, streaming):
generate_start_time = time.time()

try:
# Voice input processing
self.print_message(f"Processing voice input: {voice}", message_type="debug_tts")
gpt_cond_latent = None
speaker_embedding = None

# Handle different voice types
if voice.startswith('latent:'):
if self.current_model_loaded.startswith("xtts"):
gpt_cond_latent, speaker_embedding = self._load_latents(voice)

elif voice.startswith('voiceset:'):
voice_set = voice.replace("voiceset:", "")
voice_set_path = os.path.join(self.main_dir, "voices", "xtts_multi_voice_sets", voice_set)
self.print_message(f"Processing voice set from: {voice_set_path}", message_type="debug_tts")

wavs_files = glob.glob(os.path.join(voice_set_path, "*.wav"))
if not wavs_files:
self.print_message(f"No WAV files found in voice set: {voice_set}", message_type="error")
raise HTTPException(status_code=400, detail=f"No WAV files found in voice set: {voice_set}")

if len(wavs_files) > 5:
wavs_files = random.sample(wavs_files, 5)
self.print_message(f"Using 5 random samples from voice set", message_type="debug_tts")

if self.current_model_loaded.startswith("xtts"):
self.print_message("Generating conditioning latents from voice set", message_type="debug_tts")
gpt_cond_latent, speaker_embedding = self._generate_conditioning_latents(wavs_files)

else:
normalized_path = os.path.normpath(os.path.join(self.main_dir, "voices", voice))
wavs_files = [normalized_path]
self.print_message(f"Using single voice sample: {normalized_path}", message_type="debug_tts")

if self.current_model_loaded.startswith("xtts"):
self.print_message("Generating conditioning latents from single sample", message_type="debug_tts")
gpt_cond_latent, speaker_embedding = self._generate_conditioning_latents(wavs_files)

# Generate speech
# Preparation of latents and embeddings
gpt_cond_latent, speaker_embedding = await self.prepare_voice_inputs(voice)

common_args = {
"text": text,
"language": language,
"gpt_cond_latent": gpt_cond_latent,
"speaker_embedding": speaker_embedding,
"temperature": float(temperature),
"length_penalty": float(self.model.config.length_penalty),
"repetition_penalty": float(repetition_penalty),
"top_k": int(self.model.config.top_k),
"top_p": float(self.model.config.top_p),
"speed": float(speed),
"enable_text_splitting": True
}

self.print_message("Generation settings:", message_type="debug_tts_variables")
self.print_message(f"├─ Temperature: {temperature}", message_type="debug_tts_variables")
self.print_message(f"├─ Speed: {speed}", message_type="debug_tts_variables")
self.print_message(f"├─ Language: {language}", message_type="debug_tts_variables")
self.print_message(f"└─ Text length: {len(text)} characters", message_type="debug_tts_variables")

# Handle streaming vs non-streaming
if self.current_model_loaded.startswith("xtts"):
self.print_message(f"Generating speech for text: {text}", message_type="debug_tts")

common_args = {
"text": text,
"language": language,
"gpt_cond_latent": gpt_cond_latent,
"speaker_embedding": speaker_embedding,
"temperature": float(temperature),
"length_penalty": float(self.model.config.length_penalty),
"repetition_penalty": float(repetition_penalty),
"top_k": int(self.model.config.top_k),
"top_p": float(self.model.config.top_p),
"speed": float(speed),
"enable_text_splitting": True
}

self.print_message("Generation settings:", message_type="debug_tts_variables")
self.print_message(f"├─ Temperature: {temperature}", message_type="debug_tts_variables")
self.print_message(f"├─ Speed: {speed}", message_type="debug_tts_variables")
self.print_message(f"├─ Language: {language}", message_type="debug_tts_variables")
self.print_message(f"└─ Text length: {len(text)} characters", message_type="debug_tts_variables")

# Handle streaming vs non-streaming
if streaming:
self.print_message("Starting streaming generation", message_type="debug_tts")
self.print_message(f"Using streaming-based generation and files {wavs_files}")
output = self.model.inference_stream(**common_args, stream_chunk_size=20)

file_chunks = []
@@ -1102,7 +1100,7 @@ async def generate_tts(self, text, voice, language, temperature, repetition_penalty, speed, pitch, output_file, streaming):
self.tts_generating_lock = False
break

self.print_message(f"Processing chunk {i+1}", message_type="debug_tts")
self.print_message(f"Processing chunk {i + 1}", message_type="debug_tts")
file_chunks.append(chunk)
if isinstance(chunk, list):
chunk = torch.cat(chunk, dim=0)
@@ -1119,33 +1117,30 @@ async def generate_tts(self, text, voice, language, temperature, repetition_penalty, speed, pitch, output_file, streaming):

elif self.current_model_loaded.startswith("apitts"):
if streaming:
raise ValueError("Streaming is only supported in XTTSv2 local mode")
raise ValueError("Streaming is not supported in APITTS mode")
# Common arguments for both error and normal cases
common_args = {
api_args = {
"file_path": output_file,
"language": language,
"temperature": temperature,
"length_penalty": self.model.config.length_penalty,
"repetition_penalty": repetition_penalty,
"top_k": self.model.config.top_k,
"top_p": self.model.config.top_p,
"speed": speed
}
if voice.startswith('latent:'):
"speed": speed,
}

if voice.startswith("latent:"):
self.print_message("API TTS method does not support latent files - Please use an audio reference file", message_type="error")
self.model.tts_to_file(
text="The API TTS method only supports audio files not latents. Please select an audio reference file instead.",
speaker="Ana Florence",
**common_args
**api_args,
)
else:
self.print_message("Using API-based generation", message_type="debug_tts")
self.model.tts_to_file(
text=text,
speaker_wav=wavs_files,
**common_args
)

# Resolve the reference WAV relative to the voices directory; `voice` is a bare filename here
speaker_wav = [os.path.normpath(os.path.join(self.main_dir, "voices", voice))]
self.model.tts_to_file(text=text, speaker_wav=speaker_wav, **api_args)

self.print_message(f"API generation completed, saved to: {output_file}", message_type="debug_tts")

finally:
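For orientation, the extracted prepare_voice_inputs() helper dispatches on the prefix of the voice string. Below is a minimal sketch of the three accepted forms, assuming an already-initialised engine instance; the file and set names are hypothetical:

async def demo(engine):
    # engine: an initialised XTTS model_engine instance (assumed).

    # 1. Pre-computed latents file, delegated to _load_latents().
    lat, emb = await engine.prepare_voice_inputs("latent:my_speaker.json")

    # 2. Voice set: up to 5 WAVs are sampled from
    #    voices/xtts_multi_voice_sets/my_set/.
    lat, emb = await engine.prepare_voice_inputs("voiceset:my_set")

    # 3. Single reference WAV, resolved relative to the voices/ directory.
    lat, emb = await engine.prepare_voice_inputs("female_01.wav")
    return lat, emb

Returning the (gpt_cond_latent, speaker_embedding) pair from every branch keeps generate_tts() agnostic about how the conditioning inputs were produced.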
99 changes: 62 additions & 37 deletions tts_server.py
@@ -945,22 +945,34 @@ async def generate_audio(text, voice, language, temperature, repetition_penalty, speed, pitch, output_file, streaming):
if language == "auto":
language = detect_language(text)

response = model_engine.generate_tts(text, voice, language, temperature, repetition_penalty, speed, pitch, output_file, streaming)

# Streaming mode
if streaming:
async def stream_response():
print_message("Streaming mode enabled", "debug", "TTS")
response = model_engine.generate_tts(
text, voice, language, temperature, repetition_penalty, speed, pitch, output_file=None, streaming=True
)

async def stream_audio():
try:
async for chunk in response:
yield chunk
except Exception as e:
print_message(f"Error during streaming audio generation: {str(e)}", "error", "GEN")
raise
return stream_response()

return stream_audio()

# Non-streaming mode
print_message("Non-streaming mode enabled", "debug", "TTS")
response = model_engine.generate_tts(
text, voice, language, temperature, repetition_penalty, speed, pitch, output_file, streaming=False
)

try:
async for _ in response:
pass
except Exception as e:
print_message(f"Error during audio generation: {str(e)}", "error", "GEN")
print_message(f"Error during audio generation: {str(e)}", "error", "TTS")
raise

###########################
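With this change, generate_audio() hands back the async generator itself when streaming=True, so callers can forward chunks as they arrive. A minimal sketch of a consuming route, assuming a FastAPI app object; the route path and numeric parameter values are illustrative, not part of this commit:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/demo-stream")  # hypothetical route, not part of this commit
async def demo_stream(text: str, voice: str):
    # streaming=True: generate_audio returns an async generator of raw
    # audio chunks instead of draining it and writing output_file.
    chunks = await generate_audio(
        text, voice, "auto",  # language "auto" triggers detect_language()
        0.75, 10.0, 1.0, 0,   # temperature, repetition_penalty, speed, pitch (illustrative)
        output_file=None, streaming=True,
    )
    return StreamingResponse(chunks, media_type="audio/wav")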
@@ -1110,22 +1122,24 @@ async def openai_tts_generate(request: Request):
# Extract and validate parameters
input_text = json_data["input"]
voice = json_data["voice"]
response_format = json_data.get("response_format", "wav").lower()
speed = json_data.get("speed", 1.0)

print_message(f"Input text: {input_text}", "debug_openai", "TTS")
print_message(f"Voice: {voice}", "debug_openai", "TTS")
print_message(f"Speed: {speed}", "debug_openai", "TTS")

# Load current model engine configuration
current_model_engine = tts_class()

# Process text and map voice
cleaned_string = html.unescape(standard_filtering(input_text))
voice_mapping = {
"alloy": model_engine.openai_alloy,
"echo": model_engine.openai_echo,
"fable": model_engine.openai_fable,
"nova": model_engine.openai_nova,
"onyx": model_engine.openai_onyx,
"shimmer": model_engine.openai_shimmer
"alloy": current_model_engine.openai_alloy,
"echo": current_model_engine.openai_echo,
"fable": current_model_engine.openai_fable,
"nova": current_model_engine.openai_nova,
"onyx": current_model_engine.openai_onyx,
"shimmer": current_model_engine.openai_shimmer
}

mapped_voice = voice_mapping.get(voice)
@@ -1135,37 +1149,48 @@ async def openai_tts_generate(request: Request):

print_message(f"Mapped voice: {mapped_voice}", "debug_openai", "TTS")

# Generate audio
unique_id = uuid.uuid4()
timestamp = int(time.time())
output_file_path = f'{this_dir / config.get_output_directory() / f"openai_output_{unique_id}_{timestamp}.{model_engine.audio_format}"}'

if config.debugging.debug_fullttstext:
print_message(cleaned_string, component="TTS")
if current_model_engine.streaming_enabled:
audio_stream = await generate_audio(
cleaned_string, mapped_voice, "auto", current_model_engine.temperature_set,
float(str(current_model_engine.repetitionpenalty_set).replace(',', '.')), speed, current_model_engine.pitch_set,
output_file=None, streaming=True
)
return StreamingResponse(audio_stream, media_type="audio/wav")
else:
print_message(f"{cleaned_string[:90]}{'...' if len(cleaned_string) > 90 else ''}", component="TTS")
# Generate audio
unique_id = uuid.uuid4()
timestamp = int(time.time())
output_file_path = f'{this_dir / config.get_output_directory() / f"openai_output_{unique_id}_{timestamp}.{current_model_engine.audio_format}"}'
response_format = json_data.get("response_format", "wav").lower()

if config.debugging.debug_fullttstext:
print_message(cleaned_string, component="TTS")
else:
print_message(f"{cleaned_string[:90]}{'...' if len(cleaned_string) > 90 else ''}", component="TTS")

await generate_audio(cleaned_string, mapped_voice, "auto", model_engine.temperature_set,
model_engine.repetitionpenalty_set, speed, model_engine.pitch_set,
output_file_path, model_engine.streaming_enabled)
await generate_audio(
cleaned_string, mapped_voice, "auto", current_model_engine.temperature_set,
float(str(current_model_engine.repetitionpenalty_set).replace(',', '.')), speed, current_model_engine.pitch_set,
output_file_path, streaming=False
)

print_message(f"Audio generated at: {output_file_path}", "debug_openai", "TTS")
print_message(f"Audio generated at: {output_file_path}", "debug_openai", "TTS")

# Handle RVC processing
if config.rvc_settings.rvc_enabled:
if config.rvc_settings.rvc_char_model_file.lower() in ["disabled", "disable"]:
print_message("Pass rvccharacter_voice_gen", "debug_openai", "TTS")
else:
print_message("send to rvc", "debug_openai", "TTS")
pth_path = this_dir / "models" / "rvc_voices" / config.rvc_settings.rvc_char_model_file
pitch = config.rvc_settings.pitch
run_rvc(output_file_path, pth_path, pitch, infer_pipeline)
# Handle RVC processing
if config.rvc_settings.rvc_enabled:
if config.rvc_settings.rvc_char_model_file.lower() in ["disabled", "disable"]:
print_message("Pass rvccharacter_voice_gen", "debug_openai", "TTS")
else:
print_message("send to rvc", "debug_openai", "TTS")
pth_path = this_dir / "models" / "rvc_voices" / config.rvc_settings.rvc_char_model_file
pitch = config.rvc_settings.pitch
run_rvc(output_file_path, pth_path, pitch, infer_pipeline)

transcoded_file_path = await transcode_for_openai(output_file_path, response_format)
print_message(f"Audio transcoded to: {transcoded_file_path}", "debug_openai", "TTS")
transcoded_file_path = await transcode_for_openai(output_file_path, response_format)
print_message(f"Audio transcoded to: {transcoded_file_path}", "debug_openai", "TTS")

response = FileResponse(transcoded_file_path, media_type=f"audio/{response_format}",
filename=f"output.{response_format}")
return FileResponse(transcoded_file_path, media_type=f"audio/{response_format}",
filename=f"output.{response_format}")

except ValueError as e:
print_message(f"Value error occurred: {str(e)}", "error", "TTS")
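For completeness, a client-side sketch of exercising the OpenAI-compatible handler above. The URL, host, and port are assumptions (the route decorator sits outside this diff); the payload keys mirror the fields the handler reads (input, voice, response_format, speed):

import requests

# Endpoint location is an assumption; check the route decorator in tts_server.py.
url = "http://127.0.0.1:7851/v1/audio/speech"

payload = {
    "input": "Hello from the TTS server.",  # required
    "voice": "alloy",                       # mapped through voice_mapping above
    "response_format": "wav",               # optional, defaults to "wav"
    "speed": 1.0,                           # optional, defaults to 1.0
}

resp = requests.post(url, json=payload)
resp.raise_for_status()

# Whether the server streamed or transcoded a file, the response body is
# the finished audio, so writing it to disk covers both code paths.
with open("output.wav", "wb") as f:
    f.write(resp.content)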
