-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathgradio_demo.py
178 lines (146 loc) · 6.4 KB
/
gradio_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import argparse
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"
import torch
from transformers import AutoTokenizer
from vis_corrector import Corrector
from models.utils import extract_boxes, find_matching_boxes, annotate
import sys
sys.path.append('path/to/mPLUG-Owl')
from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
from PIL import Image
from types import SimpleNamespace
import numpy as np
import cv2
import gradio as gr
import uuid
# ========================================
# Model Initialization
# ========================================
# Prompt in the mPLUG-Owl conversation format. `{question}` is filled in by
# get_owl_output; `<image>` marks where the image belongs (presumably expanded
# into image tokens by MplugOwlProcessor — confirm against mplug_owl).
PROMPT_TEMPLATE = '''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <image>
Human: {question}
AI: '''

print('Initializing Chat')

# initialize corrector
# OpenAI credentials can be supplied via the OPENAI_API_KEY / OPENAI_API_BASE
# environment variables so a real key never has to be committed to source.
# The literal values are placeholders used only when those variables are unset.
args_dict = {
    'api_key': os.environ.get("OPENAI_API_KEY", "sk-xxxxxxxxxxxxxxxx"),
    'api_base': os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
    'val_model_path': "Salesforce/blip2-flan-t5-xxl",
    'qa2c_model_path': "khhuang/zerofec-qa2claim-t5-base",
    'detector_config': "path/to/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    'detector_model_path': "path/to/GroundingDINO/weights/groundingdino_swint_ogc.pth",
    'cache_dir': './cache_dir',
}
# Wrap the settings so they are reachable as attributes (model_args.api_key, ...).
model_args = SimpleNamespace(**args_dict)
corrector = Corrector(model_args)

# initialize mplug-Owl
pretrained_ckpt = "MAGAer13/mplug-owl-llama-7b"
# Weights are loaded in bfloat16 and placed on visible device "cuda:1"
# (CUDA_VISIBLE_DEVICES is set to "0,1" at import time). Float model inputs
# must therefore be cast to bfloat16 before generation.
model = MplugOwlForConditionalGeneration.from_pretrained(
    pretrained_ckpt,
    torch_dtype=torch.bfloat16,
).to("cuda:1")
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
tokenizer = AutoTokenizer.from_pretrained(pretrained_ckpt)
# Combines the image processor and tokenizer to build model inputs from text + images.
processor = MplugOwlProcessor(image_processor, tokenizer)
print('Initialization Finished')
@torch.no_grad()
def my_model_function(image, question, box_threshold, area_threshold):
    """Gradio callback: run mPLUG-Owl plus the Woodpecker corrector on one image.

    The uploaded image is saved to a uniquely named PNG under ./temp so that the
    path-based pipeline in ``model_predict`` can read it. The file is removed
    afterwards whether or not prediction succeeds.

    Args:
        image: Image array delivered by ``gr.Image`` in RGB channel order (it is
            converted RGB -> BGR before ``cv2.imwrite``), or ``None`` when the
            user pressed Run without uploading anything.
        question: Prompt text forwarded to the model.
        box_threshold: Detector box threshold from the UI slider.
        area_threshold: Detector area threshold from the UI slider.

    Returns:
        ``(output_text, output_image)`` exactly as produced by ``model_predict``.

    Raises:
        gr.Error: If no image was supplied or the temporary file could not be
            written. Gradio shows this message in the UI.
    """
    # Fail with a readable UI message instead of a cryptic cv2.cvtColor error.
    if image is None:
        raise gr.Error("Please upload an image before running.")

    # create a temp dir to save uploaded imgs
    temp_dir = "temp"
    os.makedirs(temp_dir, exist_ok=True)
    temp_file_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".png")
    try:
        # cv2.imwrite reports failure by returning False rather than raising.
        # Without this check the pipeline would later fail on a missing file.
        if not cv2.imwrite(temp_file_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)):
            raise gr.Error("Failed to save the uploaded image.")
        # Run the model prediction pipeline on the saved file.
        output_text, output_image = model_predict(
            temp_file_path, question, box_threshold, area_threshold)
    finally:
        # remove temporary files.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
    return output_text, output_image
def get_owl_output(img_path, question):
    """Run mPLUG-Owl on a single image and return its decoded answer.

    Args:
        img_path: Path to an image file readable by PIL.
        question: User question substituted into ``PROMPT_TEMPLATE``.

    Returns:
        The first generated sequence decoded to text, with special tokens removed.
    """
    prompts = [PROMPT_TEMPLATE.format(question=question)]
    image_list = [img_path]

    # get response
    # Greedy decoding. `top_k` has no effect when `do_sample` is False; it is
    # kept so the generation config stays identical to the original demo.
    generate_kwargs = {
        'do_sample': False,
        'top_k': 5,
        'max_length': 512
    }

    # Image.open is lazy and keeps the file open until the pixels are read.
    # copy() forces the load, and the `with` closes the handle right away, so no
    # descriptor leaks and the caller can delete the file afterwards.
    images = []
    for path in image_list:
        with Image.open(path) as im:
            images.append(im.copy())

    inputs = processor(text=prompts, images=images, return_tensors='pt')
    # Cast float32 tensors (presumably the pixel values) to bfloat16 to match the
    # model weights, and move every tensor to the model's device in one pass.
    inputs = {
        k: (v.bfloat16() if v.dtype == torch.float else v).to(model.device)
        for k, v in inputs.items()
    }
    with torch.no_grad():
        res = model.generate(**inputs, **generate_kwargs)
    sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
    return sentence
def model_predict(image_path, question, box_threshold, area_threshold):
    """Answer `question` with mPLUG-Owl, correct the answer, and draw matched boxes.

    Args:
        image_path: Path to the input image on disk.
        question: User prompt.
        box_threshold: Forwarded to the corrector as ``'box_threshold'``.
        area_threshold: Forwarded to the corrector as ``'area_threshold'``.

    Returns:
        ``(output_text, output_image_pil)``:
            - ``output_text`` shows the raw mPLUG-Owl answer followed by the
              corrector's output.
            - ``output_image_pil`` is a fully loaded, in-memory PIL image of
              ``annotate``'s result.

    Raises:
        RuntimeError: If the annotated image cannot be written to the temp file.
    """
    temp_dir = "temp"
    os.makedirs(temp_dir, exist_ok=True)
    temp_output_filepath = os.path.join(temp_dir, str(uuid.uuid4()) + ".png")
    try:
        owl_output = get_owl_output(image_path, question)
        corrector_sample = {
            'img_path': image_path,
            'input_desc': owl_output,
            'query': question,
            'box_threshold': box_threshold,
            'area_threshold': area_threshold
        }
        corrector_sample = corrector.correct(corrector_sample)
        corrector_output = corrector_sample['output']
        # Extract boxes from the corrector's output text and pair them with the
        # entity info the corrector attached to the sample.
        extracted_boxes = extract_boxes(corrector_output)
        boxes, phrases = find_matching_boxes(extracted_boxes, corrector_sample['entity_info'])
        output_image = annotate(image_path, boxes, phrases)
        # Round-trip through disk: cv2.imwrite treats 3-channel arrays as BGR, so
        # re-reading with PIL yields RGB. This assumes annotate returns a cv2-style
        # array — TODO confirm.
        if not cv2.imwrite(temp_output_filepath, output_image):
            raise RuntimeError(f"Failed to write annotated image to {temp_output_filepath}")
        # Load the pixels into memory before the finally-block deletes the file.
        # A lazily opened image would depend on a deleted file (and os.remove
        # fails on Windows while the file is still open).
        with Image.open(temp_output_filepath) as im:
            output_image_pil = im.copy()
        output_text = f"mPLUG-Owl:\n{owl_output}\n\nCorrector:\n{corrector_output}"
        return output_text, output_image_pil
    finally:
        if os.path.exists(temp_output_filepath):
            os.remove(temp_output_filepath)
def create_multi_modal_demo():
    """Build the single-image Woodpecker interaction panel.

    Layout:
        - Left column: image upload, prompt box, detector-threshold sliders
          and a Run button.
        - Right column: a text output and the annotated image.

    Both the Run button and the clickable examples invoke ``my_model_function``.

    Returns:
        The ``gr.Blocks`` object holding the panel.
    """
    with gr.Blocks() as panel:
        with gr.Row():
            # ---- left: user inputs ----
            with gr.Column():
                image_in = gr.Image(label='Upload Image')
                prompt_in = gr.Textbox(lines=2, label="Prompt")
                with gr.Accordion(label='Detector parameters', open=True):
                    box_slider = gr.Slider(
                        minimum=0, maximum=1, value=0.35, label="Box threshold")
                    area_slider = gr.Slider(
                        minimum=0, maximum=1, value=0.02, label="Area threshold")
                run_button = gr.Button("Run")
            # ---- right: model outputs ----
            with gr.Column():
                text_out = gr.Textbox(lines=10, label="Output")
                image_out = gr.Image(label="Output Image", type='pil')

        callback_inputs = [image_in, prompt_in, box_slider, area_slider]
        callback_outputs = [text_out, image_out]

        # Each example supplies only image + prompt; the sliders keep their
        # current values when an example is run.
        sample_rows = [
            ["./examples/case1.jpg", "How many people in the image?"],
            ["./examples/case2.jpg", "Is there any car in the image?"],
            ["./examples/case3.jpg", "Describe this image."],
        ]
        gr.Examples(
            examples=sample_rows,
            inputs=callback_inputs,
            outputs=callback_outputs,
            fn=my_model_function,
            cache_examples=False,
            run_on_click=True,
        )
        run_button.click(
            fn=my_model_function,
            inputs=callback_inputs,
            outputs=callback_outputs,
        )
    return panel
# Markdown header rendered at the top of the page.
description = """
# Woodpecker: Hallucination Correction for MLLMs🔧
**Note**: Due to network restrictions, it is recommended that the size of the uploaded image be less than **1M**.
Please refer to our [github](https://github.com/BradyFU/Woodpecker) for more details.
"""

# Center headings and paragraphs on the page.
page_css = "h1,p {text-align: center;}"

with gr.Blocks(css=page_css) as demo:
    gr.Markdown(description)
    with gr.TabItem("Multi-Modal Interaction"):
        create_multi_modal_demo()

# Enable request queuing with api_open=False (backend REST routes are not
# left open to bypass the queue), then start the server with a public share link.
demo.queue(api_open=False)
demo.launch(share=True)