diff --git a/.gitignore b/.gitignore index 64067ad..ee4effd 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,7 @@ pypy/ pypy* cpython*/ venv +.python-version # Installer logs pip-log.txt diff --git a/src/icespeak/settings.py b/src/icespeak/settings.py index 9bef80b..67d67e8 100644 --- a/src/icespeak/settings.py +++ b/src/icespeak/settings.py @@ -182,6 +182,19 @@ class Keys(BaseModel): # TODO: Re-implement TTS with Tiro tiro: Literal[None] = Field(default=None) + def __hash__(self): + return hash((self.azure, self.aws, self.google, self.tiro)) + + def __eq__(self, other): + if isinstance(other, Keys): + return (self.azure, self.aws, self.google, self.tiro) == ( + other.azure, + other.aws, + other.google, + other.tiro, + ) + return False + API_KEYS = Keys() diff --git a/src/icespeak/tts.py b/src/icespeak/tts.py index d17cd0b..ef95f1e 100644 --- a/src/icespeak/tts.py +++ b/src/icespeak/tts.py @@ -34,7 +34,7 @@ from cachetools import LFUCache, cached -from .settings import SETTINGS, TRACE +from .settings import SETTINGS, TRACE, Keys from .transcribe import TranscriptionOptions # TODO: Re implement Tiro @@ -64,7 +64,7 @@ def _setup_voices() -> tuple[VoicesT, ServicesT]: for service in services: _LOG.debug("Loading voices from service: %s", service) if not service.available: - _LOG.info("Voices from service %s not availble.", service) + _LOG.info("Voices from service %s not available.", service) continue for voice, info in service.voices.items(): # Info about each voice @@ -153,6 +153,7 @@ def tts_to_file( transcription_options: TranscriptionOptions | None = None, *, transcribe: bool = True, + keys_override: Keys | None = None, ) -> TTSOutput: """ # Text-to-speech @@ -195,7 +196,7 @@ def tts_to_file( text = service.Transcriber.token_transcribe(text, options=transcription_options) output = TTSOutput( - file=service.text_to_speech(text, tts_options), + file=service.text_to_speech(text, tts_options, keys_override), text=text, ) _LOG.debug("tts_to_file, out: %s", output) diff --git a/src/icespeak/voices/__init__.py b/src/icespeak/voices/__init__.py index 2565fec..bd5c73d 100644 --- a/src/icespeak/voices/__init__.py +++ b/src/icespeak/voices/__init__.py @@ -28,7 +28,7 @@ from pydantic import BaseModel, Field -from icespeak.settings import MAX_SPEED, MIN_SPEED, SETTINGS, TextFormats +from icespeak.settings import MAX_SPEED, MIN_SPEED, SETTINGS, Keys, TextFormats from icespeak.transcribe import DefaultTranscriber if TYPE_CHECKING: @@ -125,5 +125,7 @@ def load_api_keys(self) -> None: raise NotImplementedError @abstractmethod - def text_to_speech(self, text: str, options: TTSOptions) -> Path: + def text_to_speech( + self, text: str, options: TTSOptions, keys_override: Keys | None = None + ) -> Path: raise NotImplementedError diff --git a/src/icespeak/voices/aws_polly.py b/src/icespeak/voices/aws_polly.py index 4de112c..f113929 100644 --- a/src/icespeak/voices/aws_polly.py +++ b/src/icespeak/voices/aws_polly.py @@ -30,7 +30,7 @@ import boto3 -from icespeak.settings import API_KEYS, SETTINGS +from icespeak.settings import API_KEYS, SETTINGS, AWSPollyKey, Keys from . import BaseVoice, ModuleAudioFormatsT, ModuleVoicesT, TTSOptions @@ -47,6 +47,14 @@ class AWSPollyVoice(BaseVoice): _lock = Lock() + def _create_client(self, aws_key: AWSPollyKey) -> boto3.client: + return boto3.client( + "polly", + region_name=aws_key.region_name.get_secret_value(), + aws_access_key_id=aws_key.aws_access_key_id.get_secret_value(), + aws_secret_access_key=aws_key.aws_secret_access_key.get_secret_value(), + ) + @property @override def name(self): @@ -69,16 +77,18 @@ def load_api_keys(self): self._aws_client: Any = None with AWSPollyVoice._lock: if self._aws_client is None: - # See boto3.Session.client for arguments - self._aws_client = boto3.client( - "polly", - region_name=API_KEYS.aws.region_name.get_secret_value(), - aws_access_key_id=API_KEYS.aws.aws_access_key_id.get_secret_value(), - aws_secret_access_key=API_KEYS.aws.aws_secret_access_key.get_secret_value(), - ) + self._aws_client = self._create_client(API_KEYS.aws) @override - def text_to_speech(self, text: str, options: TTSOptions): + def text_to_speech( + self, text: str, options: TTSOptions, keys_override: Keys | None = None + ): + if keys_override and keys_override.aws: + _LOG.debug("Using overridden AWS keys") + client = self._create_client(keys_override.aws) + else: + _LOG.debug("Using default AWS keys") + client = self._aws_client # Special preprocessing for SSML markup if options.text_format == "ssml": # Adjust voice speed as appropriate @@ -99,7 +109,7 @@ def text_to_speech(self, text: str, options: TTSOptions): "OutputFormat": options.audio_format, } _LOG.debug("Synthesizing with AWS Polly: %s", aws_args) - response: dict[str, Any] = self._aws_client.synthesize_speech(**aws_args) + response: dict[str, Any] = client.synthesize_speech(**aws_args) except Exception: _LOG.exception("Error synthesizing speech.") raise diff --git a/src/icespeak/voices/azure.py b/src/icespeak/voices/azure.py index 63b173b..4e1c882 100644 --- a/src/icespeak/voices/azure.py +++ b/src/icespeak/voices/azure.py @@ -29,7 +29,7 @@ import azure.cognitiveservices.speech as speechsdk -from icespeak.settings import API_KEYS, SETTINGS +from icespeak.settings import API_KEYS, SETTINGS, Keys from icespeak.transcribe import DefaultTranscriber, strip_markup from . import BaseVoice, ModuleAudioFormatsT, ModuleVoicesT, TTSOptions @@ -179,10 +179,18 @@ def load_api_keys(self): AzureVoice.AZURE_REGION = API_KEYS.azure.region.get_secret_value() @override - def text_to_speech(self, text: str, options: TTSOptions): - speech_conf = speechsdk.SpeechConfig( - subscription=AzureVoice.AZURE_KEY, region=AzureVoice.AZURE_REGION - ) + def text_to_speech( + self, text: str, options: TTSOptions, keys_override: Keys | None = None + ): + if keys_override and keys_override.azure: + _LOG.debug("Using overridden Azure keys") + subscription = keys_override.azure.key.get_secret_value() + region = keys_override.azure.region.get_secret_value() + else: + _LOG.debug("Using default Azure keys") + subscription = AzureVoice.AZURE_KEY + region = AzureVoice.AZURE_REGION + speech_conf = speechsdk.SpeechConfig(subscription=subscription, region=region) azure_voice_id = AzureVoice._VOICES[options.voice]["id"] speech_conf.speech_synthesis_voice_name = azure_voice_id diff --git a/src/icespeak/voices/tiro.py b/src/icespeak/voices/tiro.py index 5bc2366..7811c13 100644 --- a/src/icespeak/voices/tiro.py +++ b/src/icespeak/voices/tiro.py @@ -29,7 +29,7 @@ import requests -from icespeak.settings import SETTINGS +from icespeak.settings import SETTINGS, Keys from icespeak.transcribe import strip_markup from . import BaseVoice, ModuleAudioFormatsT, ModuleVoicesT, TTSOptions @@ -70,7 +70,9 @@ def load_api_keys(self): pass @override - def text_to_speech(self, text: str, options: TTSOptions): + def text_to_speech( + self, text: str, options: TTSOptions, keys_override: Keys | None = None + ): # TODO: Tiro's API supports a subset of SSML tags # See https://tts.tiro.is/#tag/speech/paths/~1v0~1speech/post diff --git a/tests/test_tts.py b/tests/test_tts.py index 0e56f8c..83286bd 100644 --- a/tests/test_tts.py +++ b/tests/test_tts.py @@ -18,13 +18,23 @@ """ +# ruff: noqa: S106 from __future__ import annotations +from unittest.mock import MagicMock, patch + import pytest from icespeak import TTSOptions, tts_to_file -from icespeak.settings import API_KEYS, TextFormats, suffix_for_audiofmt +from icespeak.settings import ( + API_KEYS, + AWSPollyKey, + Keys, + TextFormats, + suffix_for_audiofmt, +) from icespeak.transcribe import strip_markup +from icespeak.tts import SERVICES, VOICES def test_voices_utils(): @@ -57,6 +67,20 @@ def test_AWSPolly_speech_synthesis(): path.unlink() +@pytest.mark.skipif(API_KEYS.aws is None, reason="Missing AWS Polly API Key.") +@pytest.mark.network() +def test_AWSPolly_speech_synthesis_with_keys_override(): + tts_out = tts_to_file( + _TEXT, + TTSOptions(text_format=TextFormats.TEXT, audio_format="mp3", voice="Dora"), + keys_override=API_KEYS, + ) + path = tts_out.file + assert path.is_file(), "Expected audio file to exist" + assert path.stat().st_size > _MIN_AUDIO_SIZE, "Expected longer audio data" + path.unlink() + + @pytest.mark.skipif(API_KEYS.azure is None, reason="Missing Azure API Key.") @pytest.mark.network() def test_Azure_speech_synthesis(): @@ -71,6 +95,21 @@ def test_Azure_speech_synthesis(): path.unlink() +@pytest.mark.skipif(API_KEYS.azure is None, reason="Missing Azure API Key.") +@pytest.mark.network() +def test_Azure_speech_synthesis_with_keys_override(): + # Test Azure Cognitive Services + tts_out = tts_to_file( + _TEXT, + TTSOptions(text_format=TextFormats.TEXT, audio_format="mp3", voice="Gudrun"), + keys_override=API_KEYS, + ) + path = tts_out.file + assert path.is_file(), "Expected audio file to exist" + assert path.stat().st_size > _MIN_AUDIO_SIZE, "Expected longer audio data" + path.unlink() + + @pytest.mark.skipif(API_KEYS.google is None, reason="Missing Google API Key.") @pytest.mark.network() def test_Google_speech_synthesis(): @@ -97,3 +136,30 @@ def test_Tiro_speech_synthesis(): assert path.is_file(), "Expected audio file to exist" assert path.stat().st_size > _MIN_AUDIO_SIZE, "Expected longer audio data" path.unlink() + + +@patch.dict(SERVICES, {"mock_service": MagicMock()}) +@patch.dict(VOICES, {"Dora": {"service": "mock_service"}}) +def test_keys_override_in_tts_to_file(): + """Test if keys_override is correctly passed into service.text_to_speech.""" + _TEXT = "Test" + SERVICES["mock_service"].audio_formats = ["mp3"] + keys_override = Keys( + aws=AWSPollyKey( + aws_access_key_id="test", + aws_secret_access_key="test", + region_name="test", + ) + ) + opts = TTSOptions(text_format=TextFormats.TEXT, audio_format="mp3", voice="Dora") + tts_to_file( + _TEXT, + opts, + transcribe=False, + keys_override=keys_override, + ) + SERVICES["mock_service"].text_to_speech.assert_called_once_with( + _TEXT, + opts, + keys_override, + )