Renamed the file and lowercased the "g" in gradio
1 parent 77f929f · commit 37033ef
Showing 1 changed file with 157 additions and 0 deletions.
157 changes: 157 additions & 0 deletions
recipes/quickstart/inference/local_inference/multi_modal_infer_gradio_UI.py
@@ -0,0 +1,157 @@
import gradio as gr
import torch
import os
from PIL import Image
from accelerate import Accelerator
from transformers import MllamaForConditionalGeneration, AutoProcessor
import argparse

# Parse the command-line arguments
parser = argparse.ArgumentParser(description="Run Gradio app with Hugging Face model")
parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face authentication token")
args = parser.parse_args()
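# Example invocation (script name taken from this commit's file path):
#   python multi_modal_infer_gradio_UI.py --hf_token <YOUR_HF_TOKEN>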

# Hugging Face token
hf_token = args.hf_token

# Initialize Accelerator (renamed from `accelerate` to avoid confusion with the package name)
accelerator = Accelerator()
device = accelerator.device

# Set memory management for PyTorch
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'  # adjust size as needed
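# Capping max_split_size_mb stops the CUDA caching allocator from splitting
# blocks larger than this size, which can reduce fragmentation-related OOMs.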
# Model ID
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Load the model with the Hugging Face token
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=device,
    token=hf_token  # `use_auth_token` is deprecated in recent transformers releases
)

# Load the processor
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
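# The processor wraps both the image preprocessor and the tokenizer, so it can
# encode image+text prompts as well as text-only prompts.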
# Visual theme
visual_theme = gr.themes.Default()  # Default, Soft, or Monochrome

# Constants
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)

# Function to process the image and generate a description
def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
    # Initialize cleaned_output variable
    cleaned_output = ""

    if image is not None:
        # Resize the image to the model's expected input size
        image = image.resize(MAX_IMAGE_SIZE)
        prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
        # Preprocess the image and prompt (keyword arguments make the call unambiguous)
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
    else:
        # Text-only input if no image is provided
        prompt = f"<|begin_of_text|>{user_prompt} Answer:"
        # Preprocess the prompt only (no image)
        inputs = processor(text=prompt, return_tensors="pt").to(device)

    # Generate output with the model
    output = model.generate(
        **inputs,
        max_new_tokens=min(max_tokens, MAX_OUTPUT_TOKENS),
        do_sample=True,  # sampling must be enabled for temperature/top_k/top_p to take effect
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )

    # Decode the raw output, dropping special tokens such as <|eot_id|>
    raw_output = processor.decode(output[0], skip_special_tokens=True)

    # Clean up the output to remove any leftover prompt scaffolding
    cleaned_output = raw_output.replace("<|image|><|begin_of_text|>", "").strip().replace(" Answer:", "")

    # Ensure the prompt is not repeated in the output
    if cleaned_output.startswith(user_prompt):
        cleaned_output = cleaned_output[len(user_prompt):].strip()

    # Append the new conversation to the history
    history.append((user_prompt, cleaned_output))

    return history


# Function to clear the chat history
def clear_chat():
    return []

# Gradio Interface
def gradio_interface():
    with gr.Blocks(theme=visual_theme) as demo:
        gr.HTML(
            """
            <h1 style='text-align: center'>
            meta-llama/Llama-3.2-11B-Vision-Instruct
            </h1>
            """)
        with gr.Row():
            # Left column with image and parameter inputs
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Image",
                    type="pil",
                    image_mode="RGB",
                    height=512,  # Set the height
                    width=512    # Set the width
                )

                # Parameter sliders
                temperature = gr.Slider(
                    label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1, interactive=True)
                top_k = gr.Slider(
                    label="Top-k", minimum=1, maximum=100, value=50, step=1, interactive=True)
                top_p = gr.Slider(
                    label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1, interactive=True)
                max_tokens = gr.Slider(
                    label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50, interactive=True)
            # Right column with the chat interface
            with gr.Column(scale=2):
                chat_history = gr.Chatbot(label="Chat", height=512)

                # User input box for prompt
                user_prompt = gr.Textbox(
                    show_label=False,
                    container=False,
                    placeholder="Enter your prompt",
                    lines=2
                )

                # Generate and Clear buttons
                with gr.Row():
                    generate_button = gr.Button("Generate")
                    clear_button = gr.Button("Clear")

                # Define the action for the Generate button
                generate_button.click(
                    fn=describe_image,
                    inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
                    outputs=[chat_history]
                )

                # Define the action for the Clear button
                clear_button.click(
                    fn=clear_chat,
                    inputs=[],
                    outputs=[chat_history]
                )

    return demo

# Launch the interface
demo = gradio_interface()
# demo.launch(server_name="0.0.0.0", server_port=12003)
demo.launch()
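# Note: demo.launch() binds to 127.0.0.1:7860 by default; the commented-out
# call above shows how to expose the app on a specific host and port instead.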