Commit 2f0e7d0

Merge pull request #1519 from vespa-engine/thomasht86/add-gemini-inference

(colpalidemo) add gemini inference
thomasht86 authored Oct 22, 2024
2 parents 2f8a5ba + 001d219 commit 2f0e7d0
Showing 5 changed files with 335 additions and 39 deletions.

visual-retrieval-colpali/.gitignore (2 additions, 1 deletion)
@@ -5,4 +5,5 @@ __pycache__/
 .env
 template/
 *.json
-output/
+output/
+pdfs/

visual-retrieval-colpali/frontend/app.py (9 additions, 30 deletions)
@@ -206,10 +206,10 @@ def Search(request, search_results=[]):
     )
 
 
-def LoadingMessage():
+def LoadingMessage(display_text="Retrieving search results"):
     return Div(
         Lucide(icon="loader-circle", cls="size-5 mr-1.5 animate-spin"),
-        Span("Retrieving search results", cls="text-base text-center"),
+        Span(display_text, cls="text-base text-center"),
         cls="p-10 text-muted-foreground flex items-center justify-center",
         id="loading-indicator",
     )
@@ -364,44 +364,23 @@ def SearchResult(results: list, query_id: Optional[str] = None):
     )
 
 
-def ChatResult():
+def ChatResult(query_id: str, query: str):
     return Div(
         Div("Chat", cls="text-xl font-semibold p-3"),
         Div(
             Div(
                 Div(
-                    "Hello! How can I assist you today?",
+                    LoadingMessage(display_text="Waiting for response..."),
                     cls="bg-muted/80 dark:bg-muted/40 text-black dark:text-white p-2 rounded-md",
+                    hx_ext="sse",
+                    sse_connect=f"/get-message?query_id={query_id}&query={quote_plus(query)}",
+                    sse_swap="message",
+                    sse_close="close",
+                    hx_swap="innerHTML",
                 ),
-                Div(
-                    "Can you show me an example of chat layout?",
-                    cls="question-message p-2 rounded-md self-end",
-                ),
-                Div(
-                    "Sure! Here's an example with sample messages.",
-                    cls="bg-muted/80 dark:bg-muted/40 text-black dark:text-white p-2 rounded-md",
-                ),
-                Div("Awesome! Thanks!", cls="question-message p-2 rounded-md self-end"),
-                Div(
-                    "You're welcome!",
-                    cls="bg-muted/80 dark:bg-muted/40 text-black dark:text-white p-2 rounded-md",
-                ),
-                Div(
-                    "What else can you do?",
-                    cls="question-message p-2 rounded-md self-end",
-                ),
-                Div(
-                    "I can help with various tasks. Just ask!",
-                    cls="bg-muted/80 dark:bg-muted/40 text-black dark:text-white p-2 rounded-md",
-                ),
                 cls="flex flex-col gap-2 text-sm",
             ),
             id="chat-messages",
             cls="overflow-auto min-h-0 grid items-end px-3",
         ),
         Div(
             Input(placeholder="Type your message here..."),
             cls="bg-muted/80 dark:bg-muted/40 p-3 border-t",
         ),
         cls="h-full grid grid-rows-[auto_1fr_auto] min-h-0 gap-3",
     )
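
Note (not part of the diff): ChatResult now streams the model's answer over server-sent events via htmx's SSE extension. FastHTML converts underscores in keyword arguments to hyphens, so the SSE-enabled Div above should render roughly as sketched below; the query values are hypothetical.

    # A minimal sketch, runnable with python-fasthtml installed.
    from urllib.parse import quote_plus

    from fasthtml.common import Div, to_xml

    query_id, query = "abc123", "what is colpali?"
    div = Div(
        "Waiting for response...",
        hx_ext="sse",
        sse_connect=f"/get-message?query_id={query_id}&query={quote_plus(query)}",
        sse_swap="message",
        sse_close="close",
        hx_swap="innerHTML",
    )
    # Expect roughly:
    # <div hx-ext="sse" sse-connect="/get-message?query_id=abc123&query=what+is+colpali%3F"
    #      sse-swap="message" sse-close="close" hx-swap="innerHTML">Waiting for response...</div>
    print(to_xml(div))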

visual-retrieval-colpali/main.py (53 additions, 1 deletion)
@@ -27,6 +27,10 @@
     SimMapButtonReady,
 )
 from frontend.layout import Layout
+import google.generativeai as genai
+from PIL import Image
+import io
+import base64
 
 highlight_js_theme_link = Link(id="highlight-theme", rel="stylesheet", href="")
 highlight_js_theme = Script(src="/static/js/highlightjs-theme.js")
@@ -44,6 +48,7 @@
 overlayscrollbars_js = Script(
     src="https://cdnjs.cloudflare.com/ajax/libs/overlayscrollbars/2.10.0/browser/overlayscrollbars.browser.es5.min.js"
 )
+sselink = Script(src="https://unpkg.com/htmx-ext-sse@2.2.1/sse.js")
 
 app, rt = fast_app(
     htmlkw={"cls": "grid h-full"},
@@ -55,6 +60,7 @@
         highlight_js_theme,
         overlayscrollbars_link,
         overlayscrollbars_js,
+        sselink,
     ),
 )
 vespa_app: Vespa = get_vespa_app()
@@ -64,6 +70,16 @@
     max_size=1000
 )  # Map from query_id to boolean value - False if not all results are ready.
 thread_pool = ThreadPoolExecutor()
+# Gemini config
+
+genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
+GEMINI_SYSTEM_PROMPT = """If the user query is a question, try your best to answer it based on the provided images.
+If the user query is not an obvious question, reply with 'No question detected.'. Your response should be HTML formatted.
+This means that newlines will be replaced with <br> tags, bold text will be enclosed in <b> tags, and so on.
+"""
+gemini_model = genai.GenerativeModel(
+    "gemini-1.5-flash-8b", system_instruction=GEMINI_SYSTEM_PROMPT
+)
 
 
 @app.on_event("startup")
@@ -122,7 +138,9 @@ def get(request):
     # Show the loading message if a query is provided
     return Layout(
         Main(Search(request), data_overlayscrollbars_initialize=True, cls="border-t"),
-        Aside(ChatResult(), cls="border-t border-l"),
+        Aside(
+            ChatResult(query_id=query_id, query=query_value), cls="border-t border-l"
+        ),
     )  # Show SearchBox and Loading message initially
 
 
@@ -237,6 +255,40 @@ async def get_sim_map(query_id: str, idx: int, token: str):
     )
 
 
+async def message_generator(query_id: str, query: str):
+    result = None
+    while result is None:
+        result = result_cache.get(query_id)
+        await asyncio.sleep(0.5)
+    search_results = get_results_children(result)
+    images = [result["fields"]["full_image"] for result in search_results]
+    # from b64 to PIL image
+    images = [Image.open(io.BytesIO(base64.b64decode(img))) for img in images]
+
+    # If newlines are present in the response, the connection will be closed.
+    def replace_newline_with_br(text):
+        return text.replace("\n", "<br>")
+
+    response_text = ""
+    async for chunk in await gemini_model.generate_content_async(
+        images + ["\n\n Query: ", query], stream=True
+    ):
+        if chunk.text:
+            response_text += chunk.text
+            response_text = replace_newline_with_br(response_text)
+            yield f"event: message\ndata: {response_text}\n\n"
+            await asyncio.sleep(0.5)
+    yield "event: close\ndata: \n\n"
+
+
+@app.get("/get-message")
+async def get_message(query_id: str, query: str):
+    return StreamingResponse(
+        message_generator(query_id=query_id, query=query),
+        media_type="text/event-stream",
+    )
+
+
 @rt("/app")
 def get():
     return Layout(Main(Div(P(f"Connected to Vespa at {vespa_app.url}"), cls="p-4")))
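
Note (not part of the diff): /get-message emits standard server-sent events, so it can be exercised without the frontend. A minimal client sketch, assuming the app listens on localhost:5001 (host and port are assumptions) and that the given query_id already has results in result_cache:

    import asyncio

    import httpx


    async def consume(query_id: str, query: str) -> None:
        params = {"query_id": query_id, "query": query}
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "GET", "http://localhost:5001/get-message", params=params
            ) as response:
                # Frames look like "event: message\ndata: <html so far>\n\n";
                # the server ends the stream with an "event: close" frame.
                async for line in response.aiter_lines():
                    print(line)


    asyncio.run(consume("abc123", "what is colpali?"))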

visual-retrieval-colpali/pyproject.toml (2 additions, 1 deletion)
@@ -8,7 +8,7 @@ license = { text = "Apache-2.0" }
 dependencies = [
     "python-fasthtml",
     "huggingface-hub",
-    "pyvespa@git+https://github.com/vespa-engine/pyvespa",
+    "pyvespa>=0.50.0",
     "vespacli",
     "torch",
     "vidore-benchmark[interpretability]>=4.0.0,<5.0.0",
@@ -18,6 +18,7 @@ dependencies = [
     "setuptools",
     "python-dotenv",
     "shad4fast>=1.2.1",
+    "google-generativeai>=0.7.2"
 ]
 
 # dev-dependencies
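
Note (not part of the diff): the new google-generativeai dependency can be smoke-tested on its own. A minimal sketch of the same streaming call main.py uses, assuming GEMINI_API_KEY is set in the environment:

    import asyncio
    import os

    import google.generativeai as genai

    genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
    model = genai.GenerativeModel("gemini-1.5-flash-8b")


    async def main() -> None:
        # stream=True yields partial chunks as they are generated,
        # mirroring how message_generator builds its SSE frames.
        async for chunk in await model.generate_content_async(
            ["Say hello in one short sentence."], stream=True
        ):
            if chunk.text:
                print(chunk.text, end="", flush=True)


    asyncio.run(main())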

(The diff for the fifth changed file did not load and is not shown.)
