From 703d9eaaf3b6a954302bd20938a0c5d669226596 Mon Sep 17 00:00:00 2001 From: Noah Tye Date: Tue, 19 Nov 2024 17:25:51 -0800 Subject: [PATCH 1/8] Voice cloning and creation (#11) Adds method for new clone endpoint, and adds language param to voice creation method. --- cartesia/voices.py | 58 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/cartesia/voices.py b/cartesia/voices.py index 4f67399..cd3c5a7 100644 --- a/cartesia/voices.py +++ b/cartesia/voices.py @@ -58,29 +58,65 @@ def get(self, id: str) -> VoiceMetadata: return response.json() - def clone(self, filepath: Optional[str] = None, enhance: str = True) -> List[float]: + def clone( + self, + filepath: Optional[str] = None, + enhance: str = True, + mode: str = "clip", + language: str = "en", + name: Optional[str] = None, + description: Optional[str] = None, + transcript: Optional[str] = None, + ) -> Union[List[float], VoiceMetadata]: """Clone a voice from a clip. Args: filepath: The path to the clip file. enhance: Whether to enhance the clip before cloning the voice (highly recommended). Defaults to True. + mode: The mode to use for cloning. Either "similarity" or "stability". + language: The language code of the language spoken in the clip. Defaults to "en". + name: The name of the cloned voice. + description: The description of the cloned voice. + transcript: The transcript of the clip. Only used if mode is "similarity". Returns: The embedding of the cloned voice as a list of floats. 
""" if not filepath: raise ValueError("Filepath must be specified.") - url = f"{self._http_url()}/voices/clone/clip" + headers = self.headers.copy() + headers.pop("Content-Type", None) + with open(filepath, "rb") as file: files = {"clip": file} - files["enhance"] = str(enhance).lower() - headers = self.headers.copy() - headers.pop("Content-Type", None) - response = httpx.post(url, headers=headers, files=files, timeout=self.timeout) - if not response.is_success: - raise ValueError(f"Failed to clone voice from clip. Error: {response.text}") - - return response.json()["embedding"] + data = { + "enhance": str(enhance).lower(), + "mode": mode, + } + if mode == "clip": + url = f"{self._http_url()}/voices/clone/clip" + response = httpx.post( + url, headers=headers, files=files, data=data, timeout=self.timeout + ) + if not response.is_success: + raise ValueError(f"Failed to clone voice from clip. Error: {response.text}") + return response.json()["embedding"] + else: + data["name"] = name + data["description"] = description + data["language"] = language + if mode == "similarity" and transcript: + data["transcript"] = transcript + url = f"{self._http_url()}/voices/clone" + response = httpx.post( + url, headers=headers, files=files, data=data, timeout=self.timeout + ) + if not response.is_success: + raise ValueError( + f"Failed to clone voice. Status Code: {response.status_code}\n" + f"Error: {response.text}" + ) + return response.json() def create( self, @@ -88,6 +124,7 @@ def create( description: str, embedding: List[float], base_voice_id: Optional[str] = None, + language: str = "en", ) -> VoiceMetadata: """Create a new voice. 
@@ -108,6 +145,7 @@ def create( "description": description, "embedding": embedding, "base_voice_id": base_voice_id, + "language": language, }, timeout=self.timeout, ) From b706774594ca24742fef4b39baccb71ad4df2cb0 Mon Sep 17 00:00:00 2001 From: Noah Tye Date: Tue, 19 Nov 2024 18:21:13 -0800 Subject: [PATCH 2/8] [bumpversion] 1.2.0 (#12) --- README.md | 2 -- cartesia/_types.py | 33 --------------------------------- cartesia/tts.py | 9 --------- cartesia/version.py | 2 +- pyproject.toml | 2 +- tests/test_tts.py | 11 ----------- uv.lock | 6 +++--- 7 files changed, 5 insertions(+), 60 deletions(-) diff --git a/README.md b/README.md index 969bca7..bff4f03 100644 --- a/README.md +++ b/README.md @@ -629,8 +629,6 @@ display(audio) You can use the `client.tts.get_output_format` method to convert string-based output format names into the `output_format` dictionary which is expected by the `output_format` parameter. You can see the `OutputFormatMapping` class in `cartesia._types` for the currently supported output format names. You can also view the currently supported `output_format`s in our [API Reference](https://docs.cartesia.ai/reference/api-reference/rest/stream-speech-server-sent-events). -The previously used `output_format` strings are now deprecated and will be removed in v1.2.0. These are listed in the `DeprecatedOutputFormatMapping` class in `cartesia._types`. - ```python # Get the output format dictionary from string name output_format = client.tts.get_output_format("raw_pcm_f32le_44100") diff --git a/cartesia/_types.py b/cartesia/_types.py index 8eefff5..edb598a 100644 --- a/cartesia/_types.py +++ b/cartesia/_types.py @@ -27,39 +27,6 @@ def get_format(cls, format_name): raise ValueError(f"Unsupported format: {format_name}") -class DeprecatedOutputFormatMapping: - """Deprecated formats as of v1.0.1. These will be removed in v1.2.0. 
Use :class:`OutputFormatMapping` instead.""" - - _format_mapping = { - "fp32": {"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100}, - "pcm": {"container": "raw", "encoding": "pcm_s16le", "sample_rate": 44100}, - "fp32_8000": {"container": "raw", "encoding": "pcm_f32le", "sample_rate": 8000}, - "fp32_16000": {"container": "raw", "encoding": "pcm_f32le", "sample_rate": 16000}, - "fp32_22050": {"container": "raw", "encoding": "pcm_f32le", "sample_rate": 22050}, - "fp32_24000": {"container": "raw", "encoding": "pcm_f32le", "sample_rate": 24000}, - "fp32_44100": {"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100}, - "pcm_8000": {"container": "raw", "encoding": "pcm_s16le", "sample_rate": 8000}, - "pcm_16000": {"container": "raw", "encoding": "pcm_s16le", "sample_rate": 16000}, - "pcm_22050": {"container": "raw", "encoding": "pcm_s16le", "sample_rate": 22050}, - "pcm_24000": {"container": "raw", "encoding": "pcm_s16le", "sample_rate": 24000}, - "pcm_44100": {"container": "raw", "encoding": "pcm_s16le", "sample_rate": 44100}, - "mulaw_8000": {"container": "raw", "encoding": "pcm_mulaw", "sample_rate": 8000}, - "alaw_8000": {"container": "raw", "encoding": "pcm_alaw", "sample_rate": 8000}, - } - - @classmethod - @deprecated( - vdeprecated="1.0.1", - vremove="1.2.0", - reason="Old output format names are being deprecated in favor of names aligned with the Cartesia API. 
Use names from `OutputFormatMapping` instead.", - ) - def get_format_deprecated(cls, format_name): - if format_name in cls._format_mapping: - return cls._format_mapping[format_name] - else: - raise ValueError(f"Unsupported format: {format_name}") - - class VoiceMetadata(TypedDict): id: str name: str diff --git a/cartesia/tts.py b/cartesia/tts.py index be26c30..a596a81 100644 --- a/cartesia/tts.py +++ b/cartesia/tts.py @@ -4,7 +4,6 @@ from cartesia._sse import _SSE from cartesia._types import ( - DeprecatedOutputFormatMapping, OutputFormat, OutputFormatMapping, VoiceControls, @@ -86,10 +85,6 @@ def get_output_format(output_format_name: str) -> OutputFormat: """ if output_format_name in OutputFormatMapping._format_mapping: output_format_obj = OutputFormatMapping.get_format(output_format_name) - elif output_format_name in DeprecatedOutputFormatMapping._format_mapping: - output_format_obj = DeprecatedOutputFormatMapping.get_format_deprecated( - output_format_name - ) else: raise ValueError(f"Unsupported format: {output_format_name}") @@ -114,10 +109,6 @@ def get_sample_rate(output_format_name: str) -> int: """ if output_format_name in OutputFormatMapping._format_mapping: output_format_obj = OutputFormatMapping.get_format(output_format_name) - elif output_format_name in DeprecatedOutputFormatMapping._format_mapping: - output_format_obj = DeprecatedOutputFormatMapping.get_format_deprecated( - output_format_name - ) else: raise ValueError(f"Unsupported format: {output_format_name}") diff --git a/cartesia/version.py b/cartesia/version.py index 6849410..c68196d 100644 --- a/cartesia/version.py +++ b/cartesia/version.py @@ -1 +1 @@ -__version__ = "1.1.0" +__version__ = "1.2.0" diff --git a/pyproject.toml b/pyproject.toml index 8b3ed20..51dd2f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cartesia" -version = "1.1.0" +version = "1.2.0" description = "The official Python library for the Cartesia API." 
readme = "README.md" requires-python = ">=3.9" diff --git a/tests/test_tts.py b/tests/test_tts.py index 4b86a94..83af081 100644 --- a/tests/test_tts.py +++ b/tests/test_tts.py @@ -1093,17 +1093,6 @@ def test_output_formats(resources: _Resources, output_format_name: str): assert output_format["encoding"] is not None, "Output format encoding is None" assert output_format["sample_rate"] is not None, "Output format sample rate is None" - -@pytest.mark.parametrize("output_format_name", deprecated_output_format_names) -def test_deprecated_output_formats(resources: _Resources, output_format_name: str): - logger.info(f"Testing deprecated output format: {output_format_name}") - output_format = resources.client.tts.get_output_format(output_format_name) - assert isinstance(output_format, dict), "Output is not of type dict" - assert output_format["container"] is not None, "Output format container is None" - assert output_format["encoding"] is not None, "Output format encoding is None" - assert output_format["sample_rate"] is not None, "Output format sample rate is None" - - def test_invalid_output_format(resources: _Resources): logger.info("Testing invalid output format") with pytest.raises(ValueError): diff --git a/uv.lock b/uv.lock index d3bfb05..3bd320e 100644 --- a/uv.lock +++ b/uv.lock @@ -162,7 +162,7 @@ wheels = [ [[package]] name = "cartesia" -version = "1.1.0" +version = "1.2.0" source = { virtual = "." 
} dependencies = [ { name = "aiohttp" }, @@ -172,7 +172,7 @@ dependencies = [ { name = "websockets" }, ] -[package.dependency-groups] +[package.dev-dependencies] dev = [ { name = "isort" }, { name = "numpy" }, @@ -195,7 +195,7 @@ requires-dist = [ { name = "websockets", specifier = ">=13.1" }, ] -[package.metadata.dependency-groups] +[package.metadata.requires-dev] dev = [ { name = "isort", specifier = ">=5.13.2" }, { name = "numpy", specifier = ">=2.0.2" }, From e0f5d14edb26ec47182f7fa283473259280108bf Mon Sep 17 00:00:00 2001 From: Timothy Luong Date: Wed, 27 Nov 2024 10:53:47 +0900 Subject: [PATCH 3/8] Adding the ability to manually flush to segment generators. (#13) ## Overview We're adding the ability to manually flush through the Python Client. A flush in this context will be called through `flush()` and result in a generator function that, when called, iterates over all audio generated for transcripts submitted since the last time (if ever) a `flush()` was called. In other words, if you do the following: ``` ws = await client.tts.websocket() ctx = ws.context() await ctx.send(..., transcript=transcript_1, ...) await ctx.send(..., transcript=transcript_2, ...) receiver_1 = await ctx.flush() await ctx.send(..., transcript=transcript_3, ...) receiver_2 = await ctx.flush() ``` Then iterating over the receivers will get you the following audio: ``` async for output in receiver_1(): # Audio Chunks for Transcript 1 and 2 async for output in receiver_2(): # Audio Chunks for Transcript 3 ``` Previously Cartesia TTS was solely multiple inputs to a single output receiver, at least with regards to [continuations](https://docs.cartesia.ai/build-with-sonic/capability-guides/stream-inputs-using-continuations). This is because we perform smart splitting & merging on our end to optimize model performance. 
However, many of the existing TTS architectures have multiple inputs to multiple output generators, so we should support this flexibility to accommodate integration with other providers with as little friction as possible. If you call `flush()` it'll be because you've performed chunking yourself (ideally the transcript or transcripts submitted form a sentence).

## Implementation

Currently this is implemented using a similar mechanism to how we wrapped multiplexing over a single websocket. We performed this by iterating over the receiver asynchronously in the background, and separating chunks into Async Queues mapped by their Context ID. In this case, we separate chunks into Lists of Async Queues, where the Flush ID is the index within the List of the queue being populated. When `flush()` is called, we append a new Asynchronous Queue to the list for that Context ID.

![multiplexing_flushing_python](https://github.com/user-attachments/assets/0e85a433-ab02-4e3c-8ccc-368e458ede48)

In a previous iteration of this PR I was doing this non-deterministically using Async Polling from the N queue and the N+1 queue. Instead we introduced a notion of a `flush_done` event fired after manual flushes prior to incrementing the `flush_id`. This allows us to deterministically indicate that the generator function has completed.

## Testing

Tested on a local API version - Prod already added the manual flush capabilities but my local branch has the changes that incorporate the `flush_done` event. Confirmed that given 3 transcripts, I can create the 3 generator functions and iterate over each for the distinct audio generations. Also added a unit test that creates 3 generators from 3 transcripts + flushes and iterates over each of them. This currently passes in the staging deployment; this PR is blocked until the flush-done emission is deployed in production.

## Blockers

Can't be merged until the `flush_done` event changes are deployed to Production.
--- cartesia/_async_websocket.py | 65 ++++++++-- cartesia/_constants.py | 1 + cartesia/_websocket.py | 5 +- cartesia/utils/tts.py | 4 + tests/test_tts.py | 239 +++++++++++------------------------ 5 files changed, 142 insertions(+), 172 deletions(-) diff --git a/cartesia/_async_websocket.py b/cartesia/_async_websocket.py index 553e07a..dde7818 100644 --- a/cartesia/_async_websocket.py +++ b/cartesia/_async_websocket.py @@ -6,7 +6,7 @@ import aiohttp -from cartesia._constants import DEFAULT_MODEL_ID, DEFAULT_VOICE_EMBEDDING +from cartesia._constants import DEFAULT_MODEL_ID, DEFAULT_OUTPUT_FORMAT, DEFAULT_VOICE_EMBEDDING from cartesia._types import OutputFormat, VoiceControls from cartesia._websocket import _WebSocket from cartesia.tts import TTS @@ -45,6 +45,7 @@ async def send( voice_embedding: Optional[List[float]] = None, context_id: Optional[str] = None, continue_: bool = False, + flush: bool = False, duration: Optional[int] = None, language: Optional[str] = None, add_timestamps: bool = False, @@ -60,6 +61,7 @@ async def send( voice_embedding: The embedding of the voice to use for generating audio. context_id: The context ID to use for the request. If not specified, a random context ID will be generated. continue_: Whether to continue the audio generation from the previous transcript or not. + flush: Whether to trigger a manual flush for the current context's generation. duration: The duration of the audio in seconds. language: The language code for the audio request. This can only be used with `model_id = sonic-multilingual`. add_timestamps: Whether to return word-level timestamps. 
@@ -71,7 +73,7 @@ async def send( """ if context_id is not None and context_id != self._context_id: raise ValueError("Context ID does not match the context ID of the current context.") - if continue_ and transcript == "": + if continue_ and transcript == "" and not flush: raise ValueError("Transcript cannot be empty when continue_ is True.") await self._websocket.connect() @@ -87,6 +89,7 @@ async def send( context_id=self._context_id, add_timestamps=add_timestamps, continue_=continue_, + flush=flush, _experimental_voice_controls=_experimental_voice_controls, ) @@ -100,12 +103,49 @@ async def no_more_inputs(self) -> None: await self.send( model_id=DEFAULT_MODEL_ID, transcript="", - output_format=TTS.get_output_format("raw_pcm_f32le_44100"), + output_format=TTS.get_output_format(DEFAULT_OUTPUT_FORMAT), voice_embedding=DEFAULT_VOICE_EMBEDDING, # Default voice embedding since it's a required input for now. context_id=self._context_id, continue_=False, ) + async def flush(self) -> Callable[[], AsyncGenerator[Dict[str, Any], None]]: + """Trigger a manual flush for the current context's generation. This method returns a generator that yields the audio prior to the flush.""" + await self.send( + model_id=DEFAULT_MODEL_ID, + transcript="", + output_format=TTS.get_output_format(DEFAULT_OUTPUT_FORMAT), + voice_embedding=DEFAULT_VOICE_EMBEDDING, # Default voice embedding since it's a required input for now. 
+ context_id=self._context_id, + continue_=True, + flush=True, + ) + + # Save the old flush ID + flush_id = len(self._websocket._context_queues[self._context_id]) - 1 + + # Create a new Async Queue to store the responses for the new flush ID + self._websocket._context_queues[self._context_id].append(asyncio.Queue()) + + # Return the generator for the old flush ID + async def generator(): + try: + while True: + response = await self._websocket._get_message( + self._context_id, timeout=self.timeout, flush_id=flush_id + ) + if "error" in response: + raise RuntimeError(f"Error generating audio:\n{response['error']}") + if response.get("flush_done") or response["done"]: + break + yield self._websocket._convert_response(response, include_context_id=True) + except Exception as e: + if isinstance(e, asyncio.TimeoutError): + raise RuntimeError("Timeout while waiting for audio chunk") + raise RuntimeError(f"Failed to generate audio:\n{e}") + + return generator + async def receive(self) -> AsyncGenerator[Dict[str, Any], None]: """Receive the audio chunks from the WebSocket. This method is a generator that yields audio chunks. @@ -175,7 +215,7 @@ def __init__( self.timeout = timeout self._get_session = get_session self.websocket = None - self._context_queues: Dict[str, asyncio.Queue] = {} + self._context_queues: Dict[str, List[asyncio.Queue]] = {} self._processing_task: asyncio.Task = None def __del__(self): @@ -213,7 +253,7 @@ async def close(self): except asyncio.CancelledError: pass except TypeError as e: - # Ignore the error if the task is already cancelled + # Ignore the error if the task is already canceled. # For some reason we are getting None responses # TODO: This needs to be fixed - we need to think about why we are getting None responses. 
if "Received message 256:None" not in str(e): @@ -284,16 +324,23 @@ async def _process_responses(self): response = await self.websocket.receive_json() if response["context_id"]: context_id = response["context_id"] + flush_id = response.get("flush_id", -1) if context_id in self._context_queues: - await self._context_queues[context_id].put(response) + await self._context_queues[context_id][flush_id].put(response) except Exception as e: self._error = e raise e - async def _get_message(self, context_id: str, timeout: float) -> Dict[str, Any]: + async def _get_message( + self, context_id: str, timeout: float, flush_id: Optional[int] = -1 + ) -> Dict[str, Any]: if context_id not in self._context_queues: raise ValueError(f"Context ID {context_id} not found.") - return await asyncio.wait_for(self._context_queues[context_id].get(), timeout=timeout) + if len(self._context_queues[context_id]) <= flush_id: + raise ValueError(f"Flush ID {flush_id} not found for context ID {context_id}.") + return await asyncio.wait_for( + self._context_queues[context_id][flush_id].get(), timeout=timeout + ) def _remove_context(self, context_id: str): if context_id in self._context_queues: @@ -309,5 +356,5 @@ def context(self, context_id: Optional[str] = None) -> _AsyncTTSContext: if context_id is None: context_id = str(uuid.uuid4()) if context_id not in self._context_queues: - self._context_queues[context_id] = asyncio.Queue() + self._context_queues[context_id] = [asyncio.Queue()] return _AsyncTTSContext(context_id, self, self.timeout) diff --git a/cartesia/_constants.py b/cartesia/_constants.py index 0eb5033..1717a95 100644 --- a/cartesia/_constants.py +++ b/cartesia/_constants.py @@ -2,6 +2,7 @@ MULTILINGUAL_MODEL_ID = "sonic-multilingual" # latest multilingual model DEFAULT_BASE_URL = "api.cartesia.ai" DEFAULT_CARTESIA_VERSION = "2024-06-10" # latest version +DEFAULT_OUTPUT_FORMAT = "raw_pcm_f32le_44100" DEFAULT_TIMEOUT = 30 # seconds DEFAULT_NUM_CONNECTIONS = 10 # connections per client 
DEFAULT_VOICE_EMBEDDING = [1.0] * 192 diff --git a/cartesia/_websocket.py b/cartesia/_websocket.py index e8a4460..85cba18 100644 --- a/cartesia/_websocket.py +++ b/cartesia/_websocket.py @@ -239,7 +239,7 @@ def close(self): self._contexts.clear() def _convert_response( - self, response: Dict[str, any], include_context_id: bool + self, response: Dict[str, any], include_context_id: bool, include_flush_id: bool = False ) -> Dict[str, Any]: out = {} if response["type"] == EventType.AUDIO: @@ -250,6 +250,9 @@ def _convert_response( if include_context_id: out["context_id"] = response["context_id"] + if include_flush_id and "flush_id" in response: + out["flush_id"] = response["flush_id"] + return out def send( diff --git a/cartesia/utils/tts.py b/cartesia/utils/tts.py index 0b1d1c2..c27f775 100644 --- a/cartesia/utils/tts.py +++ b/cartesia/utils/tts.py @@ -37,6 +37,7 @@ def _construct_tts_request( add_timestamps: bool = False, context_id: Optional[str] = None, continue_: bool = False, + flush: bool = False, _experimental_voice_controls: Optional[VoiceControls] = None, ): tts_request = { @@ -71,4 +72,7 @@ def _construct_tts_request( if continue_: tts_request["continue"] = continue_ + if flush: + tts_request["flush"] = flush + return tts_request diff --git a/tests/test_tts.py b/tests/test_tts.py index 83af081..f9ea1ff 100644 --- a/tests/test_tts.py +++ b/tests/test_tts.py @@ -23,6 +23,11 @@ RESOURCES_DIR = os.path.join(THISDIR, "resources") DEFAULT_MODEL_ID = "sonic-english" # latest default model +DEFAULT_OUTPUT_FORMAT = { + "container": "raw", + "encoding": "pcm_f32le", + "sample_rate": 44100, +} MULTILINGUAL_MODEL_ID = "sonic-multilingual" # latest multilingual model SAMPLE_VOICE = "Newsman" SAMPLE_VOICE_ID = "d46abd1d-2d02-43e8-819f-51fb652c1c61" @@ -165,11 +170,7 @@ def test_sse_send( output_generate = client.tts.sse( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + 
output_format=DEFAULT_OUTPUT_FORMAT, stream=stream, model_id=DEFAULT_MODEL_ID, _experimental_voice_controls=_experimental_voice_controls, @@ -191,11 +192,7 @@ def test_sse_send_with_model_id(resources: _Resources, stream: bool): output_generate = client.tts.sse( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=stream, model_id=DEFAULT_MODEL_ID, ) @@ -231,18 +228,12 @@ async def send_sse_request(client, transcript, voice_id, output_format, model_id "Hello, world! I'm generating audio on Cartesia. Hello, world! I'm generating audio on Cartesia.", ] - output_format = { - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100 - } - tasks = [ send_sse_request( client, transcript, SAMPLE_VOICE_ID, - output_format, + DEFAULT_OUTPUT_FORMAT, DEFAULT_MODEL_ID, num ) for num, transcript in enumerate(transcripts) @@ -261,11 +252,7 @@ def test_sse_send_with_voice_id_and_embedding(resources: _Resources): transcript=transcript, voice_id=SAMPLE_VOICE_ID, voice_embedding=embedding, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, ) @@ -291,11 +278,7 @@ def test_websocket_send( output_generate = ws.send( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=stream, model_id=DEFAULT_MODEL_ID, context_id=context_id, @@ -322,11 +305,7 @@ def test_websocket_send_timestamps(resources: _Resources, stream: bool): output_generate = ws.send( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=stream, model_id=DEFAULT_MODEL_ID, 
context_id=context_id, @@ -360,11 +339,7 @@ def test_sse_send_context_manager( output_generate = client.tts.sse( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, _experimental_voice_controls=_experimental_voice_controls, @@ -385,11 +360,7 @@ def test_sse_send_context_manager_with_err(): client.tts.sse( transcript=transcript, voice_id="", - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, ) # should throw err because voice_id is "" @@ -407,11 +378,7 @@ def test_websocket_send_context_manager(resources: _Resources): output_generate = ws.send( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, ) @@ -432,11 +399,7 @@ def test_websocket_send_context_manage_err(resources: _Resources): ws.send( transcript=transcript, voice_id="", - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, ) # should throw err because voice_id is "" @@ -461,11 +424,7 @@ async def test_async_sse_send( output = await async_client.tts.sse( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, _experimental_voice_controls=_experimental_voice_controls, @@ -497,11 +456,7 @@ async def test_async_websocket_send( output = await ws.send( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - 
"sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, context_id=context_id, @@ -529,11 +484,7 @@ async def test_async_websocket_send_timestamps(resources: _Resources): output = await ws.send( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, context_id=context_id, @@ -563,11 +514,7 @@ async def test_async_sse_send_context_manager(resources: _Resources): output_generate = await async_client.tts.sse( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, ) @@ -588,11 +535,7 @@ async def test_async_sse_send_context_manager_with_err(): await async_client.tts.sse( transcript=transcript, voice_id="", - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, ) # should throw err because voice_id is "" @@ -611,11 +554,7 @@ async def test_async_websocket_send_context_manager(): output_generate = await ws.send( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, ) @@ -638,11 +577,7 @@ def test_sse_send_multilingual(resources: _Resources, stream: bool, language: st output_generate = client.tts.sse( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=stream, model_id=MULTILINGUAL_MODEL_ID, language=language, @@ -668,11 +603,7 @@ def test_websocket_send_multilingual( 
output_generate = ws.send( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=stream, model_id=MULTILINGUAL_MODEL_ID, language=language, @@ -707,11 +638,7 @@ def test_sync_continuation_websocket_context_send(): model_id=DEFAULT_MODEL_ID, transcript=chunk_generator(transcripts), voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, ) for out in output_generate: assert out.keys() == {"audio", "context_id"} @@ -730,11 +657,7 @@ def test_sync_context_send_timestamps(resources: _Resources): output_generate = ctx.send( transcript=chunk_generator(transcripts), voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, model_id=DEFAULT_MODEL_ID, add_timestamps=True, ) @@ -763,11 +686,7 @@ async def test_continuation_websocket_context_send(): model_id=DEFAULT_MODEL_ID, transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, continue_=True, ) @@ -803,11 +722,7 @@ async def test_continuation_websocket_context_send_incorrect_transcript(): model_id=DEFAULT_MODEL_ID, transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, continue_=True, ) @@ -839,11 +754,7 @@ async def test_continuation_websocket_context_send_incorrect_voice_id(): model_id=DEFAULT_MODEL_ID, transcript=transcript, voice_id="", # voice_id is empty - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, continue_=True, ) @@ -910,11 +821,7 @@ 
async def test_continuation_websocket_context_send_incorrect_model_id(): model_id="", # model_id is empty transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, continue_=True, ) async for _ in ctx.receive(): @@ -943,11 +850,7 @@ async def test_continuation_websocket_context_send_incorrect_context_id(): transcript=transcript, voice_id=SAMPLE_VOICE_ID, context_id="sad-monkeys-fly", # context_id is different - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, continue_=True, ) @@ -977,11 +880,7 @@ async def test_continuation_websocket_context_twice_on_same_context(): model_id=DEFAULT_MODEL_ID, transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, continue_=True, ) @@ -991,11 +890,7 @@ async def test_continuation_websocket_context_twice_on_same_context(): model_id=DEFAULT_MODEL_ID, transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, continue_=True, ) @@ -1009,6 +904,46 @@ async def test_continuation_websocket_context_twice_on_same_context(): await async_client.close() +@pytest.mark.asyncio +async def test_continuation_websocket_context_send_flush(): + logger.info("Testing async continuation WebSocket context send with flush") + async_client = create_async_client() + ws = await async_client.tts.websocket() + context_id = str(uuid.uuid4()) + try: + ctx = ws.context(context_id) + transcripts = [ + "Hello, world!", + "My name is Cartesia.", + "I am a text-to-speech API.", + ] + receivers = [] + for transcript in transcripts: + await ctx.send( + model_id=DEFAULT_MODEL_ID, + transcript=transcript, + 
voice_id=SAMPLE_VOICE_ID, + output_format=DEFAULT_OUTPUT_FORMAT, + continue_=True, + ) + new_receiver = await ctx.flush() + receivers.append(new_receiver) + await ctx.no_more_inputs() + + for receiver in receivers: + async for out in receiver(): + if out.get("audio"): + assert out.keys() == {"audio", "context_id"} + assert isinstance(out["audio"], bytes) + elif out.get("flush_done"): + assert out.keys() == {"flush_done", "flush_id"} + else: + assert False, f"Received unexpected message: {out}" + finally: + await ws.close() + await async_client.close() + + async def context_runner(ws, transcripts): ctx = ws.context() @@ -1019,11 +954,7 @@ async def context_runner(ws, transcripts): model_id=DEFAULT_MODEL_ID, transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, continue_=True, ) @@ -1110,11 +1041,7 @@ def test_websocket_send_with_custom_url(): output_generate = ws.send( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, ) @@ -1135,11 +1062,7 @@ def test_sse_send_with_custom_url(): output_generate = client.tts.sse( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, ) @@ -1161,11 +1084,7 @@ def test_sse_send_with_incorrect_url(): client.tts.sse( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=False, model_id=DEFAULT_MODEL_ID, ) @@ -1187,11 +1106,7 @@ def test_websocket_send_with_incorrect_url(): ws.send( transcript=transcript, voice_id=SAMPLE_VOICE_ID, - 
output_format={ - "container": "raw", - "encoding": "pcm_f32le", - "sample_rate": 44100, - }, + output_format=DEFAULT_OUTPUT_FORMAT, stream=True, model_id=DEFAULT_MODEL_ID, ) From 1ae25e96822be887073eb0ead440f6dc21861f4c Mon Sep 17 00:00:00 2001 From: Timothy Luong Date: Wed, 27 Nov 2024 23:22:34 +0900 Subject: [PATCH 4/8] [bumpversion] 1.3.0 (#14) Bumping client to `1.3.0`. Primary change: - [Adding manual flushing for multiple generators](https://github.com/cartesia-ai/cartesia-python/pull/13) --- cartesia/version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cartesia/version.py b/cartesia/version.py index c68196d..67bc602 100644 --- a/cartesia/version.py +++ b/cartesia/version.py @@ -1 +1 @@ -__version__ = "1.2.0" +__version__ = "1.3.0" diff --git a/pyproject.toml b/pyproject.toml index 51dd2f4..433e9d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cartesia" -version = "1.2.0" +version = "1.3.0" description = "The official Python library for the Cartesia API." 
readme = "README.md" requires-python = ">=3.9" From d9d26eeb0c812f7c9bc6494d26703dd153bc4323 Mon Sep 17 00:00:00 2001 From: Kabir Goel Date: Mon, 2 Dec 2024 10:48:00 -0800 Subject: [PATCH 5/8] Relax version requirements --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 433e9d5..d0f4365 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "httpx>=0.27.2", "iterators>=0.2.0", "requests>=2.32.3", - "websockets>=13.1", + "websockets>=10.4", ] [tool.uv] From 8613c3564bf4e9aa595eeec27b2488ab8543583d Mon Sep 17 00:00:00 2001 From: Kabir Goel Date: Mon, 2 Dec 2024 11:01:08 -0800 Subject: [PATCH 6/8] Use hatchling for building Setuptools is broken (https://github.com/astral-sh/uv/issues/9513) --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index d0f4365..682e8f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,10 @@ dependencies = [ "websockets>=10.4", ] +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + [tool.uv] dev-dependencies = [ "isort>=5.13.2", From 427d82fddfe2037a146a2f046997e5bfb2822505 Mon Sep 17 00:00:00 2001 From: Kabir Goel Date: Mon, 2 Dec 2024 15:29:48 -0800 Subject: [PATCH 7/8] Relax dependency ranges --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 682e8f5..24dbb7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,9 +6,9 @@ readme = "README.md" requires-python = ">=3.9" dependencies = [ "aiohttp>=3.10.10", - "httpx>=0.27.2", + "httpx>=0.27.0", "iterators>=0.2.0", - "requests>=2.32.3", + "requests>=2.31.0", "websockets>=10.4", ] From 904be462f15ccfade8737b7a896c9a8087fb4820 Mon Sep 17 00:00:00 2001 From: Kabir Goel Date: Mon, 2 Dec 2024 15:30:32 -0800 Subject: [PATCH 8/8] [bumpversion] 1.3.1 --- cartesia/version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 
deletions(-) diff --git a/cartesia/version.py b/cartesia/version.py index 67bc602..9c73af2 100644 --- a/cartesia/version.py +++ b/cartesia/version.py @@ -1 +1 @@ -__version__ = "1.3.0" +__version__ = "1.3.1" diff --git a/pyproject.toml b/pyproject.toml index 24dbb7b..42eabe8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cartesia" -version = "1.3.0" +version = "1.3.1" description = "The official Python library for the Cartesia API." readme = "README.md" requires-python = ">=3.9"