diff --git a/src/cartesia/tts/socket_client.py b/src/cartesia/tts/socket_client.py index 9877606..ffcef4a 100644 --- a/src/cartesia/tts/socket_client.py +++ b/src/cartesia/tts/socket_client.py @@ -9,18 +9,16 @@ from ..core.pydantic_utilities import parse_obj_as from .types.cancel_context_request import CancelContextRequest from .types.generation_request import GenerationRequest -from .types.output_format import OutputFormat -from .types.tts_request_voice_specifier import TtsRequestVoiceSpecifier from .types.web_socket_request import WebSocketRequest from .types.web_socket_response import ( WebSocketResponse, WebSocketResponse_Chunk, WebSocketResponse_Done, WebSocketResponse_Error, + WebSocketResponse_PhonemeTimestamps, WebSocketResponse_Timestamps, ) from .types.web_socket_tts_output import WebSocketTtsOutput -from .utils.timeout_iterator import TimeoutIterator try: from websockets.sync.client import connect @@ -255,7 +253,9 @@ def close(self): def _convert_response( self, response_obj: typing.Union[ - WebSocketResponse_Chunk, WebSocketResponse_Timestamps + WebSocketResponse_Chunk, + WebSocketResponse_Timestamps, + WebSocketResponse_PhonemeTimestamps, ], include_context_id: bool, ) -> WebSocketTtsOutput: @@ -264,6 +264,8 @@ def _convert_response( out["audio"] = base64.b64decode(response_obj.data) elif isinstance(response_obj, WebSocketResponse_Timestamps): out["word_timestamps"] = response_obj.word_timestamps + elif isinstance(response_obj, WebSocketResponse_PhonemeTimestamps): + out["phoneme_timestamps"] = response_obj.phoneme_timestamps if include_context_id: out["context_id"] = response_obj.context_id