Skip to content

Commit

Permalink
Websocket TTS.
Browse files Browse the repository at this point in the history
  • Loading branch information
eyw520 committed Feb 6, 2025
1 parent 66f2e02 commit 75b32c5
Show file tree
Hide file tree
Showing 12 changed files with 2,073 additions and 957 deletions.
5 changes: 4 additions & 1 deletion .fernignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ requirements.txt
# Custom Client Modifications

src/cartesia/client.py
src/cartesia/tts/_async_websocket.py
src/cartesia/tts/_websocket.py
src/cartesia/tts/socket_client.py
src/cartesia/tts/utils
src/cartesia/tts/types/web_socket_tts_output.py
src/cartesia/tts/utils/timeout_iterator.py
tests/custom

# Pending changes to README.md generation
README.md
62 changes: 58 additions & 4 deletions src/cartesia/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# This file was auto-generated by Fern from our API Definition.

import asyncio
import typing
from types import TracebackType
from typing import Union

import aiohttp
import httpx

from .base_client import AsyncBaseCartesia, BaseCartesia
Expand Down Expand Up @@ -110,9 +112,10 @@ def __init__(
base_url: typing.Optional[str] = None,
environment: CartesiaEnvironment = CartesiaEnvironment.PRODUCTION,
api_key: str,
timeout: typing.Optional[float] = None,
follow_redirects: typing.Optional[bool] = True,
httpx_client: typing.Optional[httpx.AsyncClient] = None,
timeout: typing.Optional[float] = 30,
max_num_connections: typing.Optional[int] = 10,
):
super().__init__(
base_url=base_url,
Expand All @@ -122,4 +125,55 @@ def __init__(
follow_redirects=follow_redirects,
httpx_client=httpx_client,
)
self.tts = AsyncTtsClientWithWebsocket(client_wrapper=self._client_wrapper)
self.timeout = timeout
self._session = None
self._loop = None
self.max_num_connections = max_num_connections
self.tts = AsyncTtsClientWithWebsocket(
client_wrapper=self._client_wrapper,
get_session=self._get_session
)

async def _get_session(self):
"""
This method is used to get a session for the client.
"""
current_loop = asyncio.get_event_loop()
if self._loop is not current_loop:
await self.close()
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=self.timeout)
connector = aiohttp.TCPConnector(limit=self.max_num_connections)
self._session = aiohttp.ClientSession(timeout=timeout, connector=connector)
self._loop = current_loop
return self._session

async def close(self):
"""This method closes the session.
It is *strongly* recommended to call this method when you are done using the client.
"""
if self._session is not None and not self._session.closed:
await self._session.close()

def __del__(self):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None

if loop is None:
asyncio.run(self.close())
elif loop.is_running():
loop.create_task(self.close())

async def __aenter__(self):
return self

async def __aexit__(
self,
exc_type: Union[type, None],
exc: Union[BaseException, None],
exc_tb: Union[TracebackType, None],
):
await self.close()
Loading

0 comments on commit 75b32c5

Please sign in to comment.