Skip to content

Commit

Permalink
Merge pull request #7 from mideind/keys_override
Browse files Browse the repository at this point in the history
Enable override on API keys for external services
thelgason authored Mar 12, 2024
2 parents f8e0e01 + 0974c83 commit 25e2866
Showing 8 changed files with 126 additions and 23 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -127,6 +127,7 @@ pypy/
pypy*
cpython*/
venv
.python-version

# Installer logs
pip-log.txt
13 changes: 13 additions & 0 deletions src/icespeak/settings.py
Original file line number Diff line number Diff line change
@@ -182,6 +182,19 @@ class Keys(BaseModel):
# TODO: Re-implement TTS with Tiro
tiro: Literal[None] = Field(default=None)

def __hash__(self):
return hash((self.azure, self.aws, self.google, self.tiro))

def __eq__(self, other):
    """
    Compare two `Keys` instances field by field.

    Returns True when `other` is a `Keys` instance whose azure, aws,
    google and tiro fields all equal this instance's, False when any
    of them differ.

    For operands that are not `Keys` instances, `NotImplemented` is
    returned instead of False. This lets Python try the reflected
    comparison on the other operand; if that is not implemented
    either, `==` falls back to identity and still evaluates to False.
    """
    if not isinstance(other, Keys):
        return NotImplemented
    # Compare the same fields, in the same order, that __hash__ uses.
    return (self.azure, self.aws, self.google, self.tiro) == (
        other.azure,
        other.aws,
        other.google,
        other.tiro,
    )


API_KEYS = Keys()

7 changes: 4 additions & 3 deletions src/icespeak/tts.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@

from cachetools import LFUCache, cached

from .settings import SETTINGS, TRACE
from .settings import SETTINGS, TRACE, Keys
from .transcribe import TranscriptionOptions

# TODO: Re-implement Tiro
@@ -64,7 +64,7 @@ def _setup_voices() -> tuple[VoicesT, ServicesT]:
for service in services:
_LOG.debug("Loading voices from service: %s", service)
if not service.available:
_LOG.info("Voices from service %s not availble.", service)
_LOG.info("Voices from service %s not available.", service)
continue
for voice, info in service.voices.items():
# Info about each voice
@@ -153,6 +153,7 @@ def tts_to_file(
transcription_options: TranscriptionOptions | None = None,
*,
transcribe: bool = True,
keys_override: Keys | None = None,
) -> TTSOutput:
"""
# Text-to-speech
@@ -195,7 +196,7 @@ def tts_to_file(
text = service.Transcriber.token_transcribe(text, options=transcription_options)

output = TTSOutput(
file=service.text_to_speech(text, tts_options),
file=service.text_to_speech(text, tts_options, keys_override),
text=text,
)
_LOG.debug("tts_to_file, out: %s", output)
6 changes: 4 additions & 2 deletions src/icespeak/voices/__init__.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@

from pydantic import BaseModel, Field

from icespeak.settings import MAX_SPEED, MIN_SPEED, SETTINGS, TextFormats
from icespeak.settings import MAX_SPEED, MIN_SPEED, SETTINGS, Keys, TextFormats
from icespeak.transcribe import DefaultTranscriber

if TYPE_CHECKING:
@@ -125,5 +125,7 @@ def load_api_keys(self) -> None:
raise NotImplementedError

@abstractmethod
def text_to_speech(self, text: str, options: TTSOptions) -> Path:
def text_to_speech(
self, text: str, options: TTSOptions, keys_override: Keys | None = None
) -> Path:
raise NotImplementedError
30 changes: 20 additions & 10 deletions src/icespeak/voices/aws_polly.py
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@

import boto3

from icespeak.settings import API_KEYS, SETTINGS
from icespeak.settings import API_KEYS, SETTINGS, AWSPollyKey, Keys

from . import BaseVoice, ModuleAudioFormatsT, ModuleVoicesT, TTSOptions

@@ -47,6 +47,14 @@ class AWSPollyVoice(BaseVoice):

_lock = Lock()

def _create_client(self, aws_key: AWSPollyKey) -> boto3.client:
    """
    Build a boto3 client for the "polly" service from the given key.

    The region name, access key id and secret access key are read from
    `aws_key` via `get_secret_value()` and passed to `boto3.client`.

    NOTE(review): `boto3.client` is a factory function rather than a
    class, so the return annotation is informal - confirm whether a
    more precise type is wanted.
    """
    # See boto3.Session.client for arguments
    client_kwargs = {
        "region_name": aws_key.region_name.get_secret_value(),
        "aws_access_key_id": aws_key.aws_access_key_id.get_secret_value(),
        "aws_secret_access_key": aws_key.aws_secret_access_key.get_secret_value(),
    }
    return boto3.client("polly", **client_kwargs)

@property
@override
def name(self):
@@ -69,16 +77,18 @@ def load_api_keys(self):
self._aws_client: Any = None
with AWSPollyVoice._lock:
if self._aws_client is None:
# See boto3.Session.client for arguments
self._aws_client = boto3.client(
"polly",
region_name=API_KEYS.aws.region_name.get_secret_value(),
aws_access_key_id=API_KEYS.aws.aws_access_key_id.get_secret_value(),
aws_secret_access_key=API_KEYS.aws.aws_secret_access_key.get_secret_value(),
)
self._aws_client = self._create_client(API_KEYS.aws)

@override
def text_to_speech(self, text: str, options: TTSOptions):
def text_to_speech(
self, text: str, options: TTSOptions, keys_override: Keys | None = None
):
if keys_override and keys_override.aws:
_LOG.debug("Using overridden AWS keys")
client = self._create_client(keys_override.aws)
else:
_LOG.debug("Using default AWS keys")
client = self._aws_client
# Special preprocessing for SSML markup
if options.text_format == "ssml":
# Adjust voice speed as appropriate
@@ -99,7 +109,7 @@ def text_to_speech(self, text: str, options: TTSOptions):
"OutputFormat": options.audio_format,
}
_LOG.debug("Synthesizing with AWS Polly: %s", aws_args)
response: dict[str, Any] = self._aws_client.synthesize_speech(**aws_args)
response: dict[str, Any] = client.synthesize_speech(**aws_args)
except Exception:
_LOG.exception("Error synthesizing speech.")
raise
18 changes: 13 additions & 5 deletions src/icespeak/voices/azure.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@

import azure.cognitiveservices.speech as speechsdk

from icespeak.settings import API_KEYS, SETTINGS
from icespeak.settings import API_KEYS, SETTINGS, Keys
from icespeak.transcribe import DefaultTranscriber, strip_markup

from . import BaseVoice, ModuleAudioFormatsT, ModuleVoicesT, TTSOptions
@@ -179,10 +179,18 @@ def load_api_keys(self):
AzureVoice.AZURE_REGION = API_KEYS.azure.region.get_secret_value()

@override
def text_to_speech(self, text: str, options: TTSOptions):
speech_conf = speechsdk.SpeechConfig(
subscription=AzureVoice.AZURE_KEY, region=AzureVoice.AZURE_REGION
)
def text_to_speech(
self, text: str, options: TTSOptions, keys_override: Keys | None = None
):
if keys_override and keys_override.azure:
_LOG.debug("Using overridden Azure keys")
subscription = keys_override.azure.key.get_secret_value()
region = keys_override.azure.region.get_secret_value()
else:
_LOG.debug("Using default Azure keys")
subscription = AzureVoice.AZURE_KEY
region = AzureVoice.AZURE_REGION
speech_conf = speechsdk.SpeechConfig(subscription=subscription, region=region)

azure_voice_id = AzureVoice._VOICES[options.voice]["id"]
speech_conf.speech_synthesis_voice_name = azure_voice_id
6 changes: 4 additions & 2 deletions src/icespeak/voices/tiro.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@

import requests

from icespeak.settings import SETTINGS
from icespeak.settings import SETTINGS, Keys
from icespeak.transcribe import strip_markup

from . import BaseVoice, ModuleAudioFormatsT, ModuleVoicesT, TTSOptions
@@ -70,7 +70,9 @@ def load_api_keys(self):
pass

@override
def text_to_speech(self, text: str, options: TTSOptions):
def text_to_speech(
self, text: str, options: TTSOptions, keys_override: Keys | None = None
):
# TODO: Tiro's API supports a subset of SSML tags
# See https://tts.tiro.is/#tag/speech/paths/~1v0~1speech/post

68 changes: 67 additions & 1 deletion tests/test_tts.py
Original file line number Diff line number Diff line change
@@ -18,13 +18,23 @@
"""
# ruff: noqa: S106
from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from icespeak import TTSOptions, tts_to_file
from icespeak.settings import API_KEYS, TextFormats, suffix_for_audiofmt
from icespeak.settings import (
API_KEYS,
AWSPollyKey,
Keys,
TextFormats,
suffix_for_audiofmt,
)
from icespeak.transcribe import strip_markup
from icespeak.tts import SERVICES, VOICES


def test_voices_utils():
@@ -57,6 +67,20 @@ def test_AWSPolly_speech_synthesis():
path.unlink()


@pytest.mark.skipif(API_KEYS.aws is None, reason="Missing AWS Polly API Key.")
@pytest.mark.network()
def test_AWSPolly_speech_synthesis_with_keys_override():
    """Synthesize speech with AWS Polly while passing keys via keys_override."""
    options = TTSOptions(
        text_format=TextFormats.TEXT, audio_format="mp3", voice="Dora"
    )
    # The configured API_KEYS are passed as the override, so this
    # exercises the override code path with the same credentials.
    audio_file = tts_to_file(_TEXT, options, keys_override=API_KEYS).file
    assert audio_file.is_file(), "Expected audio file to exist"
    assert audio_file.stat().st_size > _MIN_AUDIO_SIZE, "Expected longer audio data"
    # Remove the generated audio file once the checks have passed.
    audio_file.unlink()


@pytest.mark.skipif(API_KEYS.azure is None, reason="Missing Azure API Key.")
@pytest.mark.network()
def test_Azure_speech_synthesis():
@@ -71,6 +95,21 @@ def test_Azure_speech_synthesis():
path.unlink()


@pytest.mark.skipif(API_KEYS.azure is None, reason="Missing Azure API Key.")
@pytest.mark.network()
def test_Azure_speech_synthesis_with_keys_override():
    """Synthesize speech with Azure while passing keys via keys_override."""
    # Test Azure Cognitive Services
    options = TTSOptions(
        text_format=TextFormats.TEXT, audio_format="mp3", voice="Gudrun"
    )
    # The configured API_KEYS are passed as the override, so this
    # exercises the override code path with the same credentials.
    audio_file = tts_to_file(_TEXT, options, keys_override=API_KEYS).file
    assert audio_file.is_file(), "Expected audio file to exist"
    assert audio_file.stat().st_size > _MIN_AUDIO_SIZE, "Expected longer audio data"
    # Remove the generated audio file once the checks have passed.
    audio_file.unlink()


@pytest.mark.skipif(API_KEYS.google is None, reason="Missing Google API Key.")
@pytest.mark.network()
def test_Google_speech_synthesis():
@@ -97,3 +136,30 @@ def test_Tiro_speech_synthesis():
assert path.is_file(), "Expected audio file to exist"
assert path.stat().st_size > _MIN_AUDIO_SIZE, "Expected longer audio data"
path.unlink()


@patch.dict(SERVICES, {"mock_service": MagicMock()})
@patch.dict(VOICES, {"Dora": {"service": "mock_service"}})
def test_keys_override_in_tts_to_file():
    """Test if keys_override is correctly passed into service.text_to_speech."""
    # The patched-in mock stands in for a real TTS service, so no
    # network access or real credentials are needed.
    mock_service = SERVICES["mock_service"]
    mock_service.audio_formats = ["mp3"]

    text = "Test"
    aws_key = AWSPollyKey(
        aws_access_key_id="test",
        aws_secret_access_key="test",
        region_name="test",
    )
    keys_override = Keys(aws=aws_key)
    opts = TTSOptions(text_format=TextFormats.TEXT, audio_format="mp3", voice="Dora")

    # transcribe=False keeps the text unchanged, so the exact input
    # string can be asserted on below.
    tts_to_file(text, opts, transcribe=False, keys_override=keys_override)

    mock_service.text_to_speech.assert_called_once_with(text, opts, keys_override)

0 comments on commit 25e2866

Please sign in to comment.