Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable override on API keys for external services #7

Merged
merged 12 commits into from
Mar 12, 2024
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ pypy/
pypy*
cpython*/
venv
.python-version

# Installer logs
pip-log.txt
Expand Down
13 changes: 13 additions & 0 deletions src/icespeak/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,19 @@ class Keys(BaseModel):
# TODO: Re-implement TTS with Tiro
tiro: Literal[None] = Field(default=None)

def __hash__(self):
return hash((self.azure, self.aws, self.google, self.tiro))

def __eq__(self, other):
if isinstance(other, Keys):
return (self.azure, self.aws, self.google, self.tiro) == (
other.azure,
other.aws,
other.google,
other.tiro,
)
return False


API_KEYS = Keys()

Expand Down
7 changes: 4 additions & 3 deletions src/icespeak/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from cachetools import LFUCache, cached

from .settings import SETTINGS, TRACE
from .settings import SETTINGS, TRACE, Keys
from .transcribe import TranscriptionOptions

# TODO: Re implement Tiro
Expand Down Expand Up @@ -64,7 +64,7 @@ def _setup_voices() -> tuple[VoicesT, ServicesT]:
for service in services:
_LOG.debug("Loading voices from service: %s", service)
if not service.available:
_LOG.info("Voices from service %s not availble.", service)
_LOG.info("Voices from service %s not available.", service)
continue
for voice, info in service.voices.items():
# Info about each voice
Expand Down Expand Up @@ -153,6 +153,7 @@ def tts_to_file(
transcription_options: TranscriptionOptions | None = None,
*,
transcribe: bool = True,
keys_override: Keys | None = None,
) -> TTSOutput:
"""
# Text-to-speech
Expand Down Expand Up @@ -195,7 +196,7 @@ def tts_to_file(
text = service.Transcriber.token_transcribe(text, options=transcription_options)

output = TTSOutput(
file=service.text_to_speech(text, tts_options),
file=service.text_to_speech(text, tts_options, keys_override),
text=text,
)
_LOG.debug("tts_to_file, out: %s", output)
Expand Down
6 changes: 4 additions & 2 deletions src/icespeak/voices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from pydantic import BaseModel, Field

from icespeak.settings import MAX_SPEED, MIN_SPEED, SETTINGS, TextFormats
from icespeak.settings import MAX_SPEED, MIN_SPEED, SETTINGS, Keys, TextFormats
from icespeak.transcribe import DefaultTranscriber

if TYPE_CHECKING:
Expand Down Expand Up @@ -125,5 +125,7 @@ def load_api_keys(self) -> None:
raise NotImplementedError

@abstractmethod
def text_to_speech(self, text: str, options: TTSOptions) -> Path:
def text_to_speech(
self, text: str, options: TTSOptions, keys_override: Keys | None = None
) -> Path:
raise NotImplementedError
30 changes: 20 additions & 10 deletions src/icespeak/voices/aws_polly.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import boto3

from icespeak.settings import API_KEYS, SETTINGS
from icespeak.settings import API_KEYS, SETTINGS, AWSPollyKey, Keys

from . import BaseVoice, ModuleAudioFormatsT, ModuleVoicesT, TTSOptions

Expand All @@ -47,6 +47,14 @@ class AWSPollyVoice(BaseVoice):

_lock = Lock()

def _create_client(self, aws_key: AWSPollyKey) -> boto3.client:
return boto3.client(
"polly",
region_name=aws_key.region_name.get_secret_value(),
aws_access_key_id=aws_key.aws_access_key_id.get_secret_value(),
aws_secret_access_key=aws_key.aws_secret_access_key.get_secret_value(),
)

@property
@override
def name(self):
Expand All @@ -69,16 +77,18 @@ def load_api_keys(self):
self._aws_client: Any = None
with AWSPollyVoice._lock:
if self._aws_client is None:
# See boto3.Session.client for arguments
self._aws_client = boto3.client(
"polly",
region_name=API_KEYS.aws.region_name.get_secret_value(),
aws_access_key_id=API_KEYS.aws.aws_access_key_id.get_secret_value(),
aws_secret_access_key=API_KEYS.aws.aws_secret_access_key.get_secret_value(),
)
self._aws_client = self._create_client(API_KEYS.aws)

@override
def text_to_speech(self, text: str, options: TTSOptions):
def text_to_speech(
self, text: str, options: TTSOptions, keys_override: Keys | None = None
):
if keys_override and keys_override.aws:
_LOG.debug("Using overridden AWS keys")
client = self._create_client(keys_override.aws)
else:
_LOG.debug("Using default AWS keys")
client = self._aws_client
# Special preprocessing for SSML markup
if options.text_format == "ssml":
# Adjust voice speed as appropriate
Expand All @@ -99,7 +109,7 @@ def text_to_speech(self, text: str, options: TTSOptions):
"OutputFormat": options.audio_format,
}
_LOG.debug("Synthesizing with AWS Polly: %s", aws_args)
response: dict[str, Any] = self._aws_client.synthesize_speech(**aws_args)
response: dict[str, Any] = client.synthesize_speech(**aws_args)
except Exception:
_LOG.exception("Error synthesizing speech.")
raise
Expand Down
18 changes: 13 additions & 5 deletions src/icespeak/voices/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import azure.cognitiveservices.speech as speechsdk

from icespeak.settings import API_KEYS, SETTINGS
from icespeak.settings import API_KEYS, SETTINGS, Keys
from icespeak.transcribe import DefaultTranscriber, strip_markup

from . import BaseVoice, ModuleAudioFormatsT, ModuleVoicesT, TTSOptions
Expand Down Expand Up @@ -179,10 +179,18 @@ def load_api_keys(self):
AzureVoice.AZURE_REGION = API_KEYS.azure.region.get_secret_value()

@override
def text_to_speech(self, text: str, options: TTSOptions):
speech_conf = speechsdk.SpeechConfig(
subscription=AzureVoice.AZURE_KEY, region=AzureVoice.AZURE_REGION
)
def text_to_speech(
self, text: str, options: TTSOptions, keys_override: Keys | None = None
):
if keys_override and keys_override.azure:
_LOG.debug("Using overridden Azure keys")
subscription = keys_override.azure.key.get_secret_value()
region = keys_override.azure.region.get_secret_value()
else:
_LOG.debug("Using default Azure keys")
subscription = AzureVoice.AZURE_KEY
region = AzureVoice.AZURE_REGION
speech_conf = speechsdk.SpeechConfig(subscription=subscription, region=region)

azure_voice_id = AzureVoice._VOICES[options.voice]["id"]
speech_conf.speech_synthesis_voice_name = azure_voice_id
Expand Down
6 changes: 4 additions & 2 deletions src/icespeak/voices/tiro.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import requests

from icespeak.settings import SETTINGS
from icespeak.settings import SETTINGS, Keys
from icespeak.transcribe import strip_markup

from . import BaseVoice, ModuleAudioFormatsT, ModuleVoicesT, TTSOptions
Expand Down Expand Up @@ -70,7 +70,9 @@ def load_api_keys(self):
pass

@override
def text_to_speech(self, text: str, options: TTSOptions):
def text_to_speech(
self, text: str, options: TTSOptions, keys_override: Keys | None = None
):
# TODO: Tiro's API supports a subset of SSML tags
# See https://tts.tiro.is/#tag/speech/paths/~1v0~1speech/post

Expand Down
67 changes: 66 additions & 1 deletion tests/test_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,23 @@


"""
# ruff: noqa: S106
from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from icespeak import TTSOptions, tts_to_file
from icespeak.settings import API_KEYS, TextFormats, suffix_for_audiofmt
from icespeak.settings import (
API_KEYS,
AWSPollyKey,
Keys,
TextFormats,
suffix_for_audiofmt,
)
from icespeak.transcribe import strip_markup
from icespeak.tts import SERVICES, VOICES


def test_voices_utils():
Expand Down Expand Up @@ -57,6 +67,20 @@ def test_AWSPolly_speech_synthesis():
path.unlink()


@pytest.mark.skipif(API_KEYS.aws is None, reason="Missing AWS Polly API Key.")
@pytest.mark.network()
def test_AWSPolly_speech_synthesis_with_keys_override():
tts_out = tts_to_file(
_TEXT,
TTSOptions(text_format=TextFormats.TEXT, audio_format="mp3", voice="Dora"),
keys_override=API_KEYS,
)
path = tts_out.file
assert path.is_file(), "Expected audio file to exist"
assert path.stat().st_size > _MIN_AUDIO_SIZE, "Expected longer audio data"
path.unlink()


@pytest.mark.skipif(API_KEYS.azure is None, reason="Missing Azure API Key.")
@pytest.mark.network()
def test_Azure_speech_synthesis():
Expand All @@ -71,6 +95,21 @@ def test_Azure_speech_synthesis():
path.unlink()


@pytest.mark.skipif(API_KEYS.azure is None, reason="Missing Azure API Key.")
@pytest.mark.network()
def test_Azure_speech_synthesis_with_keys_override():
# Test Azure Cognitive Services
tts_out = tts_to_file(
_TEXT,
TTSOptions(text_format=TextFormats.TEXT, audio_format="mp3", voice="Gudrun"),
keys_override=API_KEYS,
)
path = tts_out.file
assert path.is_file(), "Expected audio file to exist"
assert path.stat().st_size > _MIN_AUDIO_SIZE, "Expected longer audio data"
path.unlink()


@pytest.mark.skipif(API_KEYS.google is None, reason="Missing Google API Key.")
@pytest.mark.network()
def test_Google_speech_synthesis():
Expand All @@ -97,3 +136,29 @@ def test_Tiro_speech_synthesis():
assert path.is_file(), "Expected audio file to exist"
assert path.stat().st_size > _MIN_AUDIO_SIZE, "Expected longer audio data"
path.unlink()


@patch.dict(SERVICES, {"mock_service": MagicMock()})
@patch.dict(VOICES, {"Dora": {"service": "mock_service"}})
def test_keys_override_in_tts_to_file():
"""Test if keys_override is correctly passed into service.text_to_speech."""
_TEXT = "Test"
SERVICES["mock_service"].audio_formats = ["mp3"]
keys_override = Keys(
aws=AWSPollyKey(
aws_access_key_id="test",
aws_secret_access_key="test",
region_name="test",
)
)
tts_to_file(
_TEXT,
TTSOptions(text_format=TextFormats.TEXT, audio_format="mp3", voice="Dora"),
transcribe=False,
keys_override=keys_override,
)
SERVICES["mock_service"].text_to_speech.assert_called_once_with(
_TEXT,
TTSOptions(text_format=TextFormats.TEXT, audio_format="mp3", voice="Dora"),
keys_override,
)
Loading