Skip to content

Commit

Permalink
[AIC-py] hf image2text parser
Browse files Browse the repository at this point in the history
test

patch #816

![pic](https://github.com/lastmile-ai/aiconfig/assets/148090348/d5cc26b3-6cb7-4331-af8a-92fd8c4e2471)

python extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/run_hf_example.py extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/hf_local_example.aiconfig.json


-> "red fox in the woods"
  • Loading branch information
jonathanlastmileai committed Jan 9, 2024
1 parent 53fbb69 commit 4e7045b
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from .local_inference.image_2_text import HuggingFaceImage2TextTransformer
from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
from .local_inference.text_2_speech import HuggingFaceText2SpeechTransformer
from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
from .local_inference.text_2_speech import HuggingFaceText2SpeechTransformer

from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser

# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient

LOCAL_INFERENCE_CLASSES = [
"HuggingFaceText2ImageDiffusor",
"HuggingFaceTextGenerationTransformer",
"HuggingFaceTextSummarizationTransformer",
"HuggingFaceTextTranslationTransformer",
"HuggingFaceText2SpeechTransformer",
"HuggingFaceAutomaticSpeechRecognition",
"HuggingFaceImage2TextTransformer",
]
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"]
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import Any, Dict, Optional, List, TYPE_CHECKING
from aiconfig import ParameterizedModelParser, InferenceOptions
from aiconfig.callback import CallbackEvent
import torch
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment

from transformers import pipeline, Pipeline

if TYPE_CHECKING:
from aiconfig import AIConfigRuntime


class HuggingFaceImage2TextTransformer(ParameterizedModelParser):
def __init__(self):
"""
Returns:
HuggingFaceImage2TextTransformer
Usage:
1. Create a new model parser object with the model ID of the model to use.
parser = HuggingFaceImage2TextTransformer()
2. Add the model parser to the registry.
config.register_model_parser(parser)
"""
super().__init__()
self.pipelines: dict[str, Pipeline] = {}

def id(self) -> str:
"""
Returns an identifier for the Model Parser
"""
return "HuggingFaceImage2TextTransformer"

async def serialize(
self,
prompt_name: str,
data: Any,
ai_config: "AIConfigRuntime",
parameters: Optional[Dict[str, Any]] = None,
) -> List[Prompt]:
"""
Defines how a prompt and model inference settings get serialized in the .aiconfig.
Assume input in the form of input(s) being passed into an already constructed pipeline.
Args:
prompt (str): The prompt to be serialized.
data (Any): Model-specific inference settings to be serialized.
ai_config (AIConfigRuntime): The AIConfig Runtime.
parameters (Dict[str, Any], optional): Model-specific parameters. Defaults to None.
Returns:
str: Serialized representation of the prompt and inference settings.
"""
raise NotImplementedError("serialize is not implemented for HuggingFaceImage2TextTransformer")

async def deserialize(
self,
prompt: Prompt,
aiconfig: "AIConfigRuntime",
params: Optional[Dict[str, Any]] = {},
) -> Dict[str, Any]:
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

# Build Completion data
completion_params = self.get_model_settings(prompt, aiconfig)

inputs = validate_and_retrieve_image_from_attachments(prompt)

completion_params["inputs"] = inputs

await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
return completion_params

async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_run_start",
__name__,
{"prompt": prompt, "options": options, "parameters": parameters},
)
)
model_name = aiconfig.get_model_name(prompt)

self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)

captioner = self.pipelines[model_name]
completion_data = await self.deserialize(prompt, aiconfig, parameters)
print(f"{completion_data=}")
inputs = completion_data.pop("inputs")
model = completion_data.pop("model")
response = captioner(inputs, **completion_data)

output = ExecuteResult(output_type="execute_result", data=response, metadata={})

prompt.outputs = [output]
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
return prompt.outputs

def get_output_text(self, response: dict[str, Any]) -> str:
raise NotImplementedError("get_output_text is not implemented for HuggingFaceImage2TextTransformer")


def validate_attachment_type_is_image(attachment: Attachment):
if not hasattr(attachment, "mime_type"):
raise ValueError(f"Attachment has no mime type. Specify the image mimetype in the aiconfig")

if not attachment.mime_type.startswith("image/"):
raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected image mimetype.")


def validate_and_retrieve_image_from_attachments(prompt: Prompt) -> list[str]:
"""
Retrieves the image uri's from each attachment in the prompt input.
Throws an exception if
- attachment is not image
- attachment data is not a uri
- no attachments are found
- operation fails for any reason
"""

if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an image attachment to the prompt input.")

image_uris: list[str] = []

for i, attachment in enumerate(prompt.input.attachments):
validate_attachment_type_is_image(attachment)

if not isinstance(attachment.data, str):
# See todo above, but for now only support uri's
raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the image attachment in prompt {prompt.name}.")

image_uris.append(attachment.data)

return image_uris

0 comments on commit 4e7045b

Please sign in to comment.