From 15ef3b42611ae3ef228e5d7edd8eeff72d6ce441 Mon Sep 17 00:00:00 2001 From: "Ankush Pala ankush@lastmileai.dev" <> Date: Fri, 5 Jan 2024 12:01:36 -0500 Subject: [PATCH 1/3] [extensions][py][hf] 1/n asr scaffolding Setting up the asr parser class --- .../automatic_speech_recognition.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py new file mode 100644 index 000000000..6ecc20c9a --- /dev/null +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py @@ -0,0 +1,65 @@ +from typing import Any, Coroutine, Dict, Optional, List, TYPE_CHECKING +from aiconfig import ParameterizedModelParser, InferenceOptions, AIConfig + +from aiconfig.schema import Prompt, Output +from transformers import Pipeline + +if TYPE_CHECKING: + from aiconfig import AIConfigRuntime +""" +Model Parser for HuggingFace ASR (Automatic Speech Recognition) models. +""" + + +class HuggingFaceAutomaticSpeechRecognition(ParameterizedModelParser): + def __init__(self): + """ + Returns: + HuggingFaceAutomaticSpeechRecognition + + Usage: + 1. Create a new model parser object with the model ID of the model to use. + parser = HuggingFaceAutomaticSpeechRecognition() + 2. Add the model parser to the registry. + config.register_model_parser(parser) + """ + super().__init__() + self.generators: dict[str, Pipeline] = {} + + def id(self) -> str: + """ + Returns an identifier for the Model Parser + """ + return "HuggingFaceAutomaticSpeechRecognition" + + async def serialize( + self, + prompt_name: str, + data: Any, + ai_config: "AIConfigRuntime", + parameters: Optional[Dict[str, Any]] = None, + **completion_params, + ) -> List[Prompt]: + """ + Defines how a prompt and model inference settings get serialized in the .aiconfig. + + Args: + prompt (str): The prompt to be serialized. + data (Any): Model-specific inference settings to be serialized. + ai_config (AIConfigRuntime): The AIConfig Runtime. + parameters (Dict[str, Any], optional): Model-specific parameters. Defaults to None. + + Returns: + str: Serialized representation of the prompt and inference settings. + """ + + async def deserialize( + self, + prompt: Prompt, + aiconfig: "AIConfig", + params: Optional[Dict[str, Any]] = {}, + ) -> Dict[str, Any]: + pass + + async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]: + pass From aa6bf1091c84f2fc267f69f6ebde9002295dfc44 Mon Sep 17 00:00:00 2001 From: "Ankush Pala ankush@lastmileai.dev" <> Date: Tue, 9 Jan 2024 11:35:47 -0500 Subject: [PATCH 2/3] [extensions][py][hf] 2/n ASR model parser impl Model Parser for the Automatic Speech Recognition task on huggingface. Decisions made while implementing: - manual impl to parse input attachments - - threw exceptions on every unexpected step. Not sure if this is the direction we want to go with this. - ## Testplan Created an mp3 file that says "hi". Used aiconfig to run asr on it. |Screenshot 2024-01-05 at 2 39 17 PM|Screenshot 2024-01-05 at 2 41 04 PM| | ------------- | ------------- | --- .../__init__.py | 4 +- .../automatic_speech_recognition.py | 105 ++++++++++++++++-- 2 files changed, 99 insertions(+), 10 deletions(-) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py index 7957c9541..ec786117e 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py @@ -1,8 +1,9 @@ from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor from .local_inference.text_generation import HuggingFaceTextGenerationTransformer -from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient +# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer from .local_inference.text_translation import HuggingFaceTextTranslationTransformer +from .local_inference.automatic_speech_recognition import HuggingFaceAutomaticSpeechRecognition # from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient @@ -11,6 +12,7 @@ "HuggingFaceTextGenerationTransformer", "HuggingFaceTextSummarizationTransformer", "HuggingFaceTextTranslationTransformer", + "HuggingFaceAutomaticSpeechRecognition", ] REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationClient"] __ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py index 6ecc20c9a..9160bcf12 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py @@ -1,8 +1,10 @@ -from typing import Any, Coroutine, Dict, Optional, List, TYPE_CHECKING -from aiconfig import ParameterizedModelParser, InferenceOptions, AIConfig +from typing import Any, Dict, Optional, List, TYPE_CHECKING +from aiconfig import ParameterizedModelParser, InferenceOptions +from aiconfig.callback import CallbackEvent +import torch +from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment -from aiconfig.schema import Prompt, Output -from transformers import Pipeline +from transformers import pipeline, Pipeline if TYPE_CHECKING: from aiconfig import AIConfigRuntime @@ -24,7 +26,7 @@ def __init__(self): config.register_model_parser(parser) """ super().__init__() - self.generators: dict[str, Pipeline] = {} + self.pipelines: dict[str, Pipeline] = {} def id(self) -> str: """ @@ -38,10 +40,10 @@ async def serialize( data: Any, ai_config: "AIConfigRuntime", parameters: Optional[Dict[str, Any]] = None, - **completion_params, ) -> List[Prompt]: """ Defines how a prompt and model inference settings get serialized in the .aiconfig. + Assume input in the form of input(s) being passed into an already constructed pipeline. Args: prompt (str): The prompt to be serialized. @@ -52,14 +54,99 @@ async def serialize( Returns: str: Serialized representation of the prompt and inference settings. """ + raise NotImplementedError("serialize is not implemented for HuggingFaceAutomaticSpeechRecognition") async def deserialize( self, prompt: Prompt, - aiconfig: "AIConfig", + aiconfig: "AIConfigRuntime", params: Optional[Dict[str, Any]] = {}, ) -> Dict[str, Any]: - pass + await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) + + # Build Completion data + completion_params = self.get_model_settings(prompt, aiconfig) + + # ASR Pipeline supports input types of bytes, file path, and a dict containing raw sampled audio. Also supports multiple input + # For now, support multiple or single uri's as input + # TODO: Support or figure out if other input types are needed (base64, bytes), as well as the sampled audio dict + # See api docs for more info: + # - https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/pipelines/automatic_speech_recognition.py#L313-L317 + # - https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline + inputs = validate_and_retrieve_audio_from_attachments(prompt) + + completion_params["inputs"] = inputs + + await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params})) + return completion_params async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]: - pass + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_run_start", + __name__, + {"prompt": prompt, "options": options, "parameters": parameters}, + ) + ) + model_name = aiconfig.get_model_name(prompt) + + if isinstance(model_name, str) and model_name not in self.pipelines: + device = self._get_device() + # Build a pipeline for the model. TODO: support other pipeline creation options. ie pipeline config, torch dtype, etc + self.pipelines[model_name] = pipeline(task="automatic-speech-recognition", model=model_name, device=device) + + asr_pipeline = self.pipelines[model_name] + completion_data = await self.deserialize(prompt, aiconfig, parameters) + + response = asr_pipeline(**completion_data) + + output = ExecuteResult(output_type="execute_result", data=response, metadata={}) + + prompt.outputs = [output] + await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})) + return prompt.outputs + + def _get_device(self) -> str: + if torch.cuda.is_available(): + return "cuda" + # Mps backend is not supported for all asr models. Seen when spinning up a default asr pipeline which uses facebook/wav2vec2-base-960h 55bb623 + return "cpu" + + def get_output_text(self, response: dict[str, Any]) -> str: + raise NotImplementedError("get_output_text is not implemented for HuggingFaceAutomaticSpeechRecognition") + + +def validate_attachment_type_is_audio(attachment: Attachment): + if not hasattr(attachment, "mime_type"): + raise ValueError(f"Attachment has no mime type. Specify the audio mimetype in the aiconfig") + + if not attachment.mime_type.startswith("audio/"): + raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected audio mimetype.") + + +def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]: + """ + Retrieves the audio uri's from each attachment in the prompt input. + + Throws an exception if + - attachment is not audio + - attachment data is not a uri + - no attachments are found + - operation fails for any reason + """ + + if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0: + raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input.") + + audio_uris: list[str] = [] + + for i, attachment in enumerate(prompt.input.attachments): + validate_attachment_type_is_audio(attachment) + + if not isinstance(attachment.data, str): + # See todo above, but for now only support uri's + raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the audio attachment in prompt {prompt.name}.") + + audio_uris.append(attachment.data) + + return audio_uris From 0c4676298c7074a0920255536330702f16ee5d59 Mon Sep 17 00:00:00 2001 From: "Ankush Pala ankush@lastmileai.dev" <> Date: Tue, 9 Jan 2024 12:19:30 -0500 Subject: [PATCH 3/3] [extensions][py][hf] 3/n serialize --- .../automatic_speech_recognition.py | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py index 9160bcf12..927c6b8e7 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py @@ -54,7 +54,38 @@ async def serialize( Returns: str: Serialized representation of the prompt and inference settings. """ - raise NotImplementedError("serialize is not implemented for HuggingFaceAutomaticSpeechRecognition") + await ai_config.callback_manager.run_callbacks( + CallbackEvent( + "on_serialize_start", + __name__, + { + "prompt_name": prompt_name, + "data": data, + "parameters": parameters, + }, + ) + ) + + prompts = [] + + if not isinstance(data, dict): + raise ValueError("Invalid data type. Expected dict when serializing prompt data to aiconfig.") + if data.get("inputs", None) is None: + raise ValueError("Invalid data when serializing prompt to aiconfig. Input data must contain an inputs field.") + + prompt = Prompt( + **{ + "name": prompt_name, + "input": {"attachments": [{"data": data["inputs"]}]}, + "metadata": None, + "outputs": None, + } + ) + + prompts.append(prompt) + + await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts})) + return prompts async def deserialize( self,