Skip to content

Commit

Permalink
Merge pull request #1523 from vespa-engine/thomasht86/consolidate-updates
Browse files Browse the repository at this point in the history

(colpalidemo) consolidate updates
  • Loading branch information
ldalves authored Oct 23, 2024
2 parents 62983d8 + 0c77f77 commit 04672c2
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 285 deletions.
3 changes: 2 additions & 1 deletion visual-retrieval-colpali/.gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
.sesskey
.venv/
__pycache__/
ipynb_checkpoints/
.python-version
.env
template/
*.json
output/
pdfs/
pdfs/
32 changes: 25 additions & 7 deletions visual-retrieval-colpali/backend/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,13 @@ def gen_similarity_maps(
if vespa_sim_maps:
print("Using provided similarity maps")
# A sim map looks like this:
# "similarities": [
# "quantized": [
# {
# "address": {
# "patch": "0",
# "querytoken": "0"
# },
# "value": 1.2599412202835083
# "value": 12, # score in range [-128, 127]
# },
# ... and so on.
# Now turn these into a tensor of same shape as previous similarity map
Expand All @@ -189,7 +189,7 @@ def gen_similarity_maps(
)
)
for idx, vespa_sim_map in enumerate(vespa_sim_maps):
for cell in vespa_sim_map["similarities"]["cells"]:
for cell in vespa_sim_map["quantized"]["cells"]:
patch = int(cell["address"]["patch"])
# if dummy model then just use 1024 as the image_seq_length

Expand Down Expand Up @@ -359,7 +359,7 @@ async def query_vespa_default(
start = time.perf_counter()
response: VespaQueryResponse = await session.query(
body={
"yql": "select id,title,url,full_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
"yql": "select id,title,url,blur_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
"ranking": "default",
"query": query,
"timeout": timeout,
Expand Down Expand Up @@ -392,7 +392,7 @@ async def query_vespa_bm25(
start = time.perf_counter()
response: VespaQueryResponse = await session.query(
body={
"yql": "select id,title,url,full_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
"yql": "select id,title,url,blur_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
"ranking": "bm25",
"query": query,
"timeout": timeout,
Expand Down Expand Up @@ -472,7 +472,7 @@ async def query_vespa_nearest_neighbor(
**query_tensors,
"presentation.timing": True,
# if we use rank({nn_string}, userQuery()), dynamic summary doesn't work, see https://github.com/vespa-engine/vespa/issues/28704
"yql": f"select id,title,snippet,text,url,full_image,page_number,summaryfeatures from pdf_page where {nn_string} or userQuery()",
"yql": f"select id,title,snippet,text,url,blur_image,page_number,summaryfeatures from pdf_page where {nn_string} or userQuery()",
"ranking.profile": "retrieval-and-rerank",
"timeout": timeout,
"hits": hits,
Expand All @@ -492,6 +492,24 @@ def is_special_token(token: str) -> bool:
return True
return False

async def get_full_image_from_vespa(app: Vespa, id: str) -> str:
    """Fetch the full-resolution image for a single pdf_page document.

    Args:
        app: The Vespa application to query.
        id: Document id to look up (matched with `id contains ...`). This
            value originates from an HTTP query parameter, so it is escaped
            before being embedded in the YQL string to prevent YQL injection.

    Returns:
        The base64-encoded `full_image` field of the first matching hit.

    Raises:
        AssertionError: If the Vespa query is not successful.
        IndexError/KeyError: If no document matches `id`.
    """
    # Escape backslashes and double quotes so a crafted id cannot break out
    # of the YQL string literal.
    safe_id = id.replace("\\", "\\\\").replace('"', '\\"')
    async with app.asyncio(connections=1, total_timeout=120) as session:
        start = time.perf_counter()
        response: VespaQueryResponse = await session.query(
            body={
                "yql": f'select full_image from pdf_page where id contains "{safe_id}"',
                "ranking": "unranked",
                "presentation.timing": True,
            },
        )
        assert response.is_successful(), response.json
        stop = time.perf_counter()
        print(
            f"Getting image from Vespa took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s"
        )
    return response.json["root"]["children"][0]["fields"]["full_image"]

async def get_result_from_query(
app: Vespa,
Expand Down Expand Up @@ -538,7 +556,7 @@ def add_sim_maps_to_result(
imgs: List[str] = []
vespa_sim_maps: List[str] = []
for single_result in result["root"]["children"]:
img = single_result["fields"]["full_image"]
img = single_result["fields"]["blur_image"]
if img:
imgs.append(img)
vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
Expand Down
26 changes: 17 additions & 9 deletions visual-retrieval-colpali/frontend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,13 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):

def SampleQueries():
sample_queries = [
"Percentage of non-fresh water as source?",
"Policies related to nature risk?",
"How much of produced water is recycled?",
"Proportion of female new hires 2021-2023?",
"Total amount of fixed salaries paid in 2023?",
"What is the percentage distribution of employees with performance-based pay relative to the limit in 2023?",
"What is the breakdown of management costs by investment strategy in 2023?",
"2023 profit loss portfolio",
"net cash flow operating activities",
"fund currency basket returns",
]

query_badges = []
Expand Down Expand Up @@ -163,13 +167,13 @@ def Hero():
return Div(
H1(
"Vespa.ai + ColPali",
cls="text-5xl md:text-7xl font-bold tracking-wide md:tracking-wider bg-clip-text text-transparent bg-gradient-to-r from-black to-gray-700 dark:from-white dark:to-gray-300 animate-fade-in",
cls="text-4xl md:text-7xl font-bold tracking-wide md:tracking-wider bg-clip-text text-transparent bg-gradient-to-r from-black to-gray-700 dark:from-white dark:to-gray-300 animate-fade-in",
),
P(
"Efficient Document Retrieval with Vision Language Models",
cls="text-lg md:text-2xl text-muted-foreground md:tracking-wide",
),
cls="grid gap-5 text-center",
cls="grid gap-5 text-center pt-5",
)


Expand All @@ -179,7 +183,7 @@ def Home():
Hero(),
SearchBox(with_border=True),
SampleQueries(),
cls="grid gap-8 -mt-[34vh]",
cls="grid gap-8 md:-mt-[34vh]", # Negative margin only on medium and larger screens
),
cls="grid w-full h-full max-w-screen-md items-center gap-4 mx-auto",
)
Expand Down Expand Up @@ -252,7 +256,7 @@ def SearchResult(results: list, query_id: Optional[str] = None):
result_items = []
for idx, result in enumerate(results):
fields = result["fields"] # Extract the 'fields' part of each result
full_image_base64 = f"data:image/jpeg;base64,{fields['full_image']}"
blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"

# Filter sim_map fields that are words with 4 or more characters
sim_map_fields = {
Expand Down Expand Up @@ -288,7 +292,7 @@ def SearchResult(results: list, query_id: Optional[str] = None):
"Reset",
variant="outline",
size="sm",
data_image_src=full_image_base64,
data_image_src=blur_image_base64,
cls="reset-button pointer-events-auto font-mono text-xs h-5 rounded-none px-2",
)

Expand All @@ -314,7 +318,11 @@ def SearchResult(results: list, query_id: Optional[str] = None):
Div(
Div(
Img(
src=full_image_base64,
src=blur_image_base64,
hx_get=f"/full_image?docid={fields['id']}&query_id={query_id}&idx={idx}",
style="filter: blur(5px);",
hx_trigger="load",
hx_swap="outerHTML",
alt=fields["title"],
cls="result-image w-full h-full object-contain",
),
Expand Down
60 changes: 55 additions & 5 deletions visual-retrieval-colpali/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import os

from fasthtml.common import *
from shad4fast import *
Expand All @@ -14,6 +15,7 @@
get_query_embeddings_and_token_map,
get_result_from_query,
is_special_token,
get_full_image_from_vespa,
)
from backend.modelmanager import ModelManager
from backend.vespa_app import get_vespa_app
Expand Down Expand Up @@ -76,6 +78,7 @@
GEMINI_SYSTEM_PROMPT = """If the user query is a question, try your best to answer it based on the provided images.
If the user query is not an obvious question, reply with 'No question detected.'. Your response should be HTML formatted.
This means that newlines will be replaced with <br> tags, bold text will be enclosed in <b> tags, and so on.
But, you should NOT include backticks (`) or HTML tags in your response.
"""
gemini_model = genai.GenerativeModel(
"gemini-1.5-flash-8b", system_instruction=GEMINI_SYSTEM_PROMPT
Expand Down Expand Up @@ -249,19 +252,66 @@ async def get_sim_map(query_id: str, idx: int, token: str):
sim_map_b64 = search_results[idx]["fields"].get(sim_map_key, None)
if sim_map_b64 is None:
return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
sim_map_img_src = f"data:image/jpeg;base64,{sim_map_b64}"
sim_map_img_src = f"data:image/png;base64,{sim_map_b64}"
return SimMapButtonReady(
query_id=query_id, idx=idx, token=token, img_src=sim_map_img_src
)


async def update_full_image_cache(
    docid: str, query_id: str, idx: int, image_data: str, max_wait: float = 1.0
):
    """Write the full-quality image into the cached search result.

    Runs as a fire-and-forget task (see the /full_image endpoint) so that the
    HTTP response is not blocked on the cache write. Because the result may
    not have been cached yet when this task runs, we retry for up to
    `max_wait` seconds instead of silently giving up on the first miss.

    Args:
        docid: Document id (kept for interface compatibility; not used here).
        query_id: Key of the cached query result to update.
        idx: Index of the hit within the result's children.
        image_data: Base64-encoded full-resolution image.
        max_wait: Maximum number of seconds to wait for the result to appear
            in the cache before giving up.
    """
    elapsed = 0.0
    result = result_cache.get(query_id)
    # The result is written to the cache by the search request handler; if it
    # is not there yet, poll briefly rather than dropping the image.
    while result is None and elapsed < max_wait:
        await asyncio.sleep(0.5)
        elapsed += 0.5
        result = result_cache.get(query_id)
    if result is None:
        return
    search_results = get_results_children(result)
    # Guard against an index that is out of range for this result set.
    if idx >= len(search_results):
        return
    search_results[idx]["fields"]["full_image"] = image_data
    result_cache.set(query_id, result)


@app.get("/full_image")
async def full_image(docid: str, query_id: str, idx: int):
    """
    Endpoint to get the full quality image for a given result id.

    Returns an <img> element that htmx swaps in place of the blurred
    placeholder image once loaded (hx_trigger="load" on the placeholder).
    """
    image_data = await get_full_image_from_vespa(vespa_app, docid)
    # Update the cache with the full image data asynchronously to not block
    # the request; the cache entry is what later consumers (e.g. the chat
    # message generator) read the full image from.
    asyncio.create_task(update_full_image_cache(docid, query_id, idx, image_data))
    return Img(
        src=f"data:image/jpeg;base64,{image_data}",
        alt=f"Full image for document {docid}",
        cls="result-image w-full h-full object-contain",
    )


async def message_generator(query_id: str, query: str):
images = []
result = None
while result is None:
all_images_ready = False
while not all_images_ready:
result = result_cache.get(query_id)
await asyncio.sleep(0.5)
search_results = get_results_children(result)
images = [result["fields"]["full_image"] for result in search_results]
if result is None:
await asyncio.sleep(0.1)
continue
search_results = get_results_children(result)
for single_result in search_results:
img = single_result["fields"].get("full_image", None)
if img is not None:
images.append(img)
if len(images) == len(search_results):
all_images_ready = True
break
else:
await asyncio.sleep(0.1)

# from b64 to PIL image
images = [Image.open(io.BytesIO(base64.b64decode(img))) for img in images]

Expand Down
63 changes: 32 additions & 31 deletions visual-retrieval-colpali/prepare_feed_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,22 +726,37 @@ def float_to_binary_embedding(float_query_embedding: dict) -> dict:
],
)

# Define similarity functions used in all rank profiles
mapfunctions = [
Function(
name="similarities", # computes similarity scores between each query token and image patch
expression="""
sum(
query(qt) * unpack_bits(attribute(embedding)), v
)
""",
),
Function(
name="normalized", # normalizes the similarity scores to [-1, 1]
expression="""
(similarities - reduce(similarities, min)) / (reduce((similarities - reduce(similarities, min)), max)) * 2 - 1
""",
),
Function(
name="quantized", # quantizes the normalized similarity scores to signed 8-bit integers [-128, 127]
expression="""
cell_cast(normalized * 127.999, int8)
""",
),
]

# Define the 'bm25' rank profile
colpali_bm25_profile = RankProfile(
name="bm25",
inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
first_phase="bm25(title) + bm25(text)",
functions=[
Function(
name="similarities",
expression="""
sum(
query(qt) * unpack_bits(attribute(embedding)), v
)
""",
),
],
summary_features=["similarities"],
functions=mapfunctions,
summary_features=["quantized"],
)
colpali_schema.add_rank_profile(colpali_bm25_profile)

Expand All @@ -751,7 +766,8 @@ def float_to_binary_embedding(float_query_embedding: dict) -> dict:
inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
first_phase="bm25_score",
second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
functions=[
functions=mapfunctions
+ [
Function(
name="max_sim",
expression="""
Expand All @@ -767,16 +783,8 @@ def float_to_binary_embedding(float_query_embedding: dict) -> dict:
""",
),
Function(name="bm25_score", expression="bm25(title) + bm25(text)"),
Function(
name="similarities",
expression="""
sum(
query(qt) * unpack_bits(attribute(embedding)), v
)
""",
),
],
summary_features=["similarities"],
summary_features=["quantized"],
)
colpali_schema.add_rank_profile(colpali_profile)

Expand All @@ -798,7 +806,8 @@ def float_to_binary_embedding(float_query_embedding: dict) -> dict:
inputs=input_query_tensors,
first_phase="max_sim_binary",
second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
functions=[
functions=mapfunctions
+ [
Function(
name="max_sim",
expression="""
Expand Down Expand Up @@ -827,16 +836,8 @@ def float_to_binary_embedding(float_query_embedding: dict) -> dict:
)
""",
),
Function(
name="similarities",
expression="""
sum(
query(qt) * unpack_bits(attribute(embedding)), v
)
""",
),
],
summary_features=["similarities"],
summary_features=["quantized"],
)
colpali_schema.add_rank_profile(colpali_retrieval_profile)

Expand Down
Loading

0 comments on commit 04672c2

Please sign in to comment.