Skip to content

Commit

Permalink
[fix] Replace default generic model parser names with hardcoded defaults
Browse files Browse the repository at this point in the history
Right now we get errors when trying to use aliases in AIConfig for model parsers, so we need to fall back to an actual model name instead. For now this workaround lets us use aliases like 'Image2Text' when we share the demo on Gradio.
  • Loading branch information
Rossdan Craig [email protected] committed Jan 10, 2024
1 parent 5908342 commit 2a5df97
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 38 deletions.
57 changes: 23 additions & 34 deletions cookbooks/Gradio/huggingface.aiconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,23 @@
"metadata": {
"parameters": {},
"models": {
"TextGeneration": {
"model": "stevhliu/my_awesome_billsum_model",
"min_length": 10,
"max_length": 30
"AudioSpeechRecognition": {
"model": "openai/whisper-small"
},
"ImageToText": {
"Image2Text": {
"model": "Salesforce/blip-image-captioning-base"
},
"Text2Speech": {
"model": "suno/bark"
},
"Text2Image": {
"model": "runwayml/stable-diffusion-v1-5"
},
"TextGeneration": {
"model": "stevhliu/my_awesome_billsum_model",
"min_length": 10,
"max_length": 30
},
"TextSummarization": {
"model": "facebook/bart-large-cnn"
},
Expand All @@ -24,16 +30,13 @@
},
"default_model": "TextGeneration",
"model_parsers": {
"AudioSpeechRecognition": "HuggingFaceAutomaticSpeechRecognitionTransformer",
"Image2Text": "HuggingFaceImage2TextTransformer",
"Salesforce/blip-image-captioning-base": "HuggingFaceImage2TextTransformer",
"Text2Speech": "HuggingFaceText2SpeechTransformer",
"suno/bark": "HuggingFaceText2SpeechTransformer",
"Text2Image": "HuggingFaceText2ImageTransformer",
"TextGeneration": "HuggingFaceTextGenerationTransformer",
// "stevhliu/my_awesome_billsum_model": "HuggingFaceTextGenerationTransformer",
"TextSummarization": "HuggingFaceTextSummarizationTransformer",
"facebook/bart-large-cnn": "HuggingFaceTextSummarizationTransformer",
"TextTranslation": "HuggingFaceTextTranslationTransformer",
"translation_en_to_fr": "HuggingFaceTextTranslationTransformer"
"TextTranslation": "HuggingFaceTextTranslationTransformer"
}
},
"description": "The Tale of the Quick Brown Fox",
Expand All @@ -52,22 +55,14 @@
"parameters": {
"city": "New York"
}
},
"outputs": [
{
"output_type": "execute_result",
"execution_count": 0,
"data": "<pad> a sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden sudden",
"metadata": {}
}
]
}
},
{
"name": "translate_instruction",
"input": "Tell the tale of {{topic}}",
"metadata": {
"model": {
"name": "TextGeneration",
"name": "TextTranslation",
"settings": {
"min_length": "",
"max_new_tokens": 100
Expand All @@ -76,8 +71,7 @@
"parameters": {
"topic": "the quick brown fox"
}
},
"outputs": []
}
},
{
"name": "summarize_story",
Expand All @@ -88,17 +82,15 @@
"settings": {}
},
"parameters": {}
},
"outputs": []
}
},
{
"name": "generate_audio_title",
"input": "The Quick Brown Fox was admired by all the animals in the forest.",
"metadata": {
"model": "Text2Speech",
"parameters": {}
},
"outputs": []
}
},
{
"name": "generate_caption",
Expand All @@ -113,8 +105,7 @@
"metadata": {
"model": "Image2Text",
"parameters": {}
},
"outputs": []
}
},
{
"name": "openai_gen_itinerary",
Expand All @@ -124,8 +115,7 @@
"parameters": {
"order_by": "geographic location"
}
},
"outputs": []
}
},
{
"name": "Audio Speech Recognition",
Expand All @@ -138,10 +128,9 @@
]
},
"metadata": {
"model": "openai/whisper-small",
"model": "AudioSpeechRecognition",
"parameters": {}
},
"outputs": []
}
}
],
"$schema": "https://json.schemastore.org/aiconfig-1.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,12 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio

model_settings = self.get_model_settings(prompt, aiconfig)
[pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings)
model_name = aiconfig.get_model_name(prompt)

model_name: str = aiconfig.get_model_name(prompt)
# TODO: Clean this up after we allow people in the AIConfig UI to specify their
# own model name for HuggingFace tasks. This isn't great but it works for now
if (model_name == "TextTranslation"):
model_name = self._get_default_model_name()

if isinstance(model_name, str) and model_name not in self.pipelines:
device = self._get_device()
Expand Down Expand Up @@ -139,6 +144,9 @@ def get_output_text(
if isinstance(output_data, str):
return output_data
return ""

def _get_default_model_name(self) -> str:
return "openai/whisper-small"


def validate_attachment_type_is_audio(attachment: Attachment):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,11 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
print(pipeline_building_disclaimer_message)

model_name: str = aiconfig.get_model_name(prompt)
# TODO: Clean this up after we allow people in the AIConfig UI to specify their
# own model name for HuggingFace tasks. This isn't great but it works for now
if (model_name == "Text2Image"):
model_name = self._get_default_model_name()

# TODO (rossdanlm): Figure out a way to save model and re-use checkpoint
# Otherwise right now a lot of these models are taking 5 mins to load with 50
# num_inference_steps (default value). See here for more details:
Expand Down Expand Up @@ -364,6 +369,9 @@ def _get_device(self) -> str:
return "mps"
return "cpu"

def _get_default_model_name(self) -> str:
return "runwayml/stable-diffusion-v1-5"

def _refine_responses(
response_images: List[Image.Image],
nsfw_content_detected: List[bool],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
[pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings)

model_name: str = aiconfig.get_model_name(prompt)
# TODO: Clean this up after we allow people in the AIConfig UI to specify their
# own model name for HuggingFace tasks. This isn't great but it works for now
if (model_name == "Text2Speech"):
model_name = self._get_default_model_name()

if isinstance(model_name, str) and model_name not in self.synthesizers:
self.synthesizers[model_name] = pipeline("text-to-speech", model_name)
synthesizer = self.synthesizers[model_name]
Expand Down Expand Up @@ -229,3 +234,6 @@ def get_output_text(
elif isinstance(output.data, str):
return output.data
return ""

def _get_default_model_name(self) -> str:
return "suno/bark"
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,13 @@ async def run_inference(
completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
completion_data["text_inputs"] = completion_data.pop("prompt", None)

model_name : str = aiconfig.get_model_name(prompt)
model_name: str = aiconfig.get_model_name(prompt)
# TODO: Clean this up after we allow people in the AIConfig UI to specify their
# own model name for HuggingFace tasks. This isn't great but it works for now
if (model_name == "TextGeneration"):
model_name = self._get_default_model_name()

if isinstance(model_name, str) and model_name not in self.generators:
print(f"Rossdan Loading model {prompt.metadata.model}")
print(f"Rossdan Loading model {model_name}")
self.generators[model_name] = pipeline('text-generation', model=model_name)
generator = self.generators[model_name]

Expand Down Expand Up @@ -305,3 +308,6 @@ def get_output_text(
# calls so shouldn't get here, but just being safe
return json.dumps(output_data.value, indent=2)
return ""

def _get_default_model_name(self) -> str:
return "stevhliu/my_awesome_billsum_model"
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,11 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
inputs = completion_data.pop("prompt", None)

model_name: str = aiconfig.get_model_name(prompt)
# TODO: Clean this up after we allow people in the AIConfig UI to specify their
# own model name for HuggingFace tasks. This isn't great but it works for now
if (model_name == "TextSummarization"):
model_name = self._get_default_model_name()

if isinstance(model_name, str) and model_name not in self.summarizers:
self.summarizers[model_name] = pipeline("summarization", model=model_name)
summarizer = self.summarizers[model_name]
Expand Down Expand Up @@ -303,3 +308,6 @@ def get_output_text(
# calls so shouldn't get here, but just being safe
return json.dumps(output_data.value, indent=2)
return ""

def _get_default_model_name(self) -> str:
return "facebook/bart-large-cnn"
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,11 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
inputs = completion_data.pop("prompt", None)

model_name: str = aiconfig.get_model_name(prompt)
# TODO: Clean this up after we allow people in the AIConfig UI to specify their
# own model name for HuggingFace tasks. This isn't great but it works for now
if (model_name == "TextTranslation"):
model_name = self._get_default_model_name()

if isinstance(model_name, str) and model_name not in self.translators:
self.translators[model_name] = pipeline(model_name)
translator = self.translators[model_name]
Expand Down Expand Up @@ -297,3 +302,6 @@ def get_output_text(
# calls so shouldn't get here, but just being safe
return json.dumps(output_data.value, indent=2)
return ""

def _get_default_model_name(self) -> str:
return "translation_en_to_fr"

0 comments on commit 2a5df97

Please sign in to comment.