Skip to content

Commit

Permalink
Add query string formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mattmazzola committed Feb 14, 2024
1 parent ead3afe commit 8edab4e
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import random
from io import BytesIO
from pathlib import Path
import re

import numpy as np
import requests
Expand Down Expand Up @@ -75,6 +76,18 @@ def __init__(self, model_path, model_base, conv_mode, temperature, top_p, num_be

def get_response(self, user_prompt: str, image_path: str, decoded_image: Image.Image, problem: dict):
qs = user_prompt
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
if IMAGE_PLACEHOLDER in qs:
if self.model.config.mm_use_im_start_end:
qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
else:
qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
else:
if self.model.config.mm_use_im_start_end:
qs = image_token_se + "\n" + qs
else:
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

conv = conv_templates[self.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
Expand All @@ -87,10 +100,10 @@ def get_response(self, user_prompt: str, image_path: str, decoded_image: Image.I
local_decoded_image.save(local_decoded_image_path)
decoded_image_path = output_folder_path / f"hf_decode_{problem['pid']}.png"
decoded_image.save(decoded_image_path)
images = [local_decoded_image]
logging.debug(f"Saved local decoded image to {local_decoded_image_path}")
logging.debug(f"Saved HF decoded image to {decoded_image_path}")
# decoded_images = [decoded_image]
# images = [local_decoded_image]
images = [decoded_image.convert('RGB')]
images_tensor = process_images(images, self.image_processor, self.model.config).to(
self.model.device, dtype=torch.float16
)
Expand All @@ -106,7 +119,7 @@ def get_response(self, user_prompt: str, image_path: str, decoded_image: Image.I
with torch.inference_mode():
output_ids = self.model.generate(
input_ids,
# images=images_tensor,
images=images_tensor,
do_sample=True if self.temperature > 0 else False,
temperature=self.temperature,
top_p=self.top_p,
Expand Down

0 comments on commit 8edab4e

Please sign in to comment.