-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Integrate PDF extraction to sem_extract #69
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import re | ||
from typing import Any | ||
|
||
import fitz | ||
import pandas as pd | ||
|
||
import lotus | ||
|
@@ -77,15 +78,19 @@ def filter_formatter_cot( | |
cot = cot_reasoning[idx] | ||
messages.extend( | ||
[ | ||
user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), | ||
user_message_formatter( | ||
ex_multimodal_data, f"Claim: {user_instruction}" | ||
), | ||
{ | ||
"role": "assistant", | ||
"content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}", | ||
}, | ||
] | ||
) | ||
|
||
messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) | ||
messages.append( | ||
user_message_formatter(multimodal_data, f"Claim: {user_instruction}") | ||
) | ||
return messages | ||
|
||
|
||
|
@@ -102,7 +107,9 @@ def filter_formatter_zs_cot( | |
{"role": "system", "content": sys_instruction}, | ||
] | ||
|
||
messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) | ||
messages.append( | ||
user_message_formatter(multimodal_data, f"Claim: {user_instruction}") | ||
) | ||
return messages | ||
|
||
|
||
|
@@ -117,7 +124,11 @@ def filter_formatter( | |
if cot_reasoning: | ||
assert examples_multimodal_data is not None and examples_answer is not None | ||
return filter_formatter_cot( | ||
multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning | ||
multimodal_data, | ||
user_instruction, | ||
examples_multimodal_data, | ||
examples_answer, | ||
cot_reasoning, | ||
) | ||
elif strategy == "zs-cot": | ||
return filter_formatter_zs_cot(multimodal_data, user_instruction) | ||
|
@@ -133,18 +144,24 @@ def filter_formatter( | |
|
||
if examples_multimodal_data: | ||
assert examples_answer is not None | ||
assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list) | ||
assert isinstance(examples_multimodal_data, list) and isinstance( | ||
examples_answer, list | ||
) | ||
for i in range(len(examples_multimodal_data)): | ||
ex_multimodal_data = examples_multimodal_data[i] | ||
ex_ans = examples_answer[i] | ||
messages.extend( | ||
[ | ||
user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), | ||
user_message_formatter( | ||
ex_multimodal_data, f"Claim: {user_instruction}" | ||
), | ||
{"role": "assistant", "content": str(ex_ans)}, | ||
] | ||
) | ||
|
||
messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) | ||
messages.append( | ||
user_message_formatter(multimodal_data, f"Claim: {user_instruction}") | ||
) | ||
return messages | ||
|
||
|
||
|
@@ -178,7 +195,9 @@ def map_formatter_cot( | |
] | ||
) | ||
|
||
messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) | ||
messages.append( | ||
user_message_formatter(multimodal_data, f"Instruction: {user_instruction}") | ||
) | ||
return messages | ||
|
||
|
||
|
@@ -195,7 +214,9 @@ def map_formatter_zs_cot( | |
{"role": "system", "content": sys_instruction}, | ||
] | ||
|
||
messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) | ||
messages.append( | ||
user_message_formatter(multimodal_data, f"Instruction: {user_instruction}") | ||
) | ||
return messages | ||
|
||
|
||
|
@@ -210,7 +231,11 @@ def map_formatter( | |
if cot_reasoning: | ||
assert examples_multimodal_data is not None and examples_answer is not None | ||
return map_formatter_cot( | ||
multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning | ||
multimodal_data, | ||
user_instruction, | ||
examples_multimodal_data, | ||
examples_answer, | ||
cot_reasoning, | ||
) | ||
elif strategy == "zs-cot": | ||
return map_formatter_zs_cot(multimodal_data, user_instruction) | ||
|
@@ -228,21 +253,29 @@ def map_formatter( | |
for ex_df_txt, ex_ans in zip(examples_multimodal_data, examples_answer): | ||
messages.extend( | ||
[ | ||
user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"), | ||
user_message_formatter( | ||
ex_df_txt, f"Instruction: {user_instruction}" | ||
), | ||
{"role": "assistant", "content": str(ex_ans)}, | ||
] | ||
) | ||
|
||
messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) | ||
messages.append( | ||
user_message_formatter(multimodal_data, f"Instruction: {user_instruction}") | ||
) | ||
return messages | ||
|
||
|
||
def extract_formatter( | ||
multimodal_data: dict[str, Any], output_cols: dict[str, str | None], extract_quotes: bool = True | ||
multimodal_data: dict[str, Any], | ||
output_cols: dict[str, str | None], | ||
extract_quotes: bool = True, | ||
) -> list[dict[str, str]]: | ||
output_col_names = list(output_cols.keys()) | ||
# Set the description to be the key if no value is provided | ||
output_cols_with_desc: dict[str, str] = {col: col if desc is None else desc for col, desc in output_cols.items()} | ||
output_cols_with_desc: dict[str, str] = { | ||
col: col if desc is None else desc for col, desc in output_cols.items() | ||
} | ||
|
||
all_fields = output_col_names | ||
if extract_quotes: | ||
|
@@ -271,10 +304,14 @@ def df2text(df: pd.DataFrame, cols: list[str]) -> list[str]: | |
"""Formats the given DataFrame into a string containing info from cols.""" | ||
|
||
def custom_format_row(x: pd.Series, cols: list[str]) -> str: | ||
return "".join([f"[{cols[i].capitalize()}]: «{x[cols[i]]}»\n" for i in range(len(cols))]) | ||
return "".join( | ||
[f"[{cols[i].capitalize()}]: «{x[cols[i]]}»\n" for i in range(len(cols))] | ||
) | ||
|
||
def clean_and_escape_column_name(column_name: str) -> str: | ||
clean_name = re.sub(r"[^\w]", "", column_name) # Remove spaces and special characters | ||
clean_name = re.sub( | ||
r"[^\w]", "", column_name | ||
) # Remove spaces and special characters | ||
return clean_name | ||
|
||
# take cols that are in df | ||
|
@@ -286,7 +323,9 @@ def clean_and_escape_column_name(column_name: str) -> str: | |
formatted_rows: list[str] = [] | ||
|
||
if lotus.settings.serialization_format == SerializationFormat.DEFAULT: | ||
formatted_rows = projected_df.apply(lambda x: custom_format_row(x, cols), axis=1).tolist() | ||
formatted_rows = projected_df.apply( | ||
lambda x: custom_format_row(x, cols), axis=1 | ||
).tolist() | ||
elif lotus.settings.serialization_format == SerializationFormat.JSON: | ||
formatted_rows = projected_df.to_json(orient="records", lines=True).splitlines() | ||
elif lotus.settings.serialization_format == SerializationFormat.XML: | ||
|
@@ -298,33 +337,71 @@ def clean_and_escape_column_name(column_name: str) -> str: | |
"You can install it with the following command:\n\n" | ||
" pip install 'lotus-ai[xml]'" | ||
) | ||
projected_df = projected_df.rename(columns=lambda x: clean_and_escape_column_name(x)) | ||
full_xml = projected_df.to_xml(root_name="data", row_name="row", pretty_print=False, index=False) | ||
projected_df = projected_df.rename( | ||
columns=lambda x: clean_and_escape_column_name(x) | ||
) | ||
full_xml = projected_df.to_xml( | ||
root_name="data", row_name="row", pretty_print=False, index=False | ||
) | ||
root = ET.fromstring(full_xml) | ||
formatted_rows = [ET.tostring(row, encoding="unicode", method="xml") for row in root.findall("row")] | ||
formatted_rows = [ | ||
ET.tostring(row, encoding="unicode", method="xml") | ||
for row in root.findall("row") | ||
] | ||
|
||
return formatted_rows | ||
|
||
|
||
def df2multimodal_info(df: pd.DataFrame, cols: list[str]) -> list[dict[str, Any]]: | ||
def pdf2text(file_path: str) -> str: | ||
try: | ||
with fitz.open(file_path) as doc: | ||
return " ".join(page.get_text() for page in doc) | ||
except Exception as e: | ||
lotus.logger.debug(f"Error while processing pdf at file path {file_path}: {e}") | ||
return "" | ||
|
||
|
||
def df2multimodal_info( | ||
df: pd.DataFrame, cols: list[str], infer_pdfs: bool = False | ||
) -> list[dict[str, Any]]: | ||
""" | ||
Formats the given DataFrame into a string containing info from cols. | ||
Return a list of dictionaries, each containing text and image data. | ||
""" | ||
|
||
# We want to modify `df` to replace PDF paths with their text, but without modifying the source | ||
df_copy = df.copy() | ||
|
||
image_cols = [col for col in cols if isinstance(df[col].dtype, ImageDtype)] | ||
if infer_pdfs: | ||
pdf_cols = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you think we should have a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case we may not need the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there are a few potential problems in implementing a |
||
col | ||
for col in cols | ||
if col not in image_cols | ||
and df[col].apply(lambda x: isinstance(x, str) and x.endswith(".pdf")).all() | ||
] | ||
for col in pdf_cols: | ||
df_copy[col] = df_copy[col].apply(pdf2text) | ||
|
||
text_cols = [col for col in cols if col not in image_cols] | ||
text_rows = df2text(df, text_cols) | ||
|
||
text_rows = df2text(df_copy, text_cols) | ||
multimodal_data = [ | ||
{ | ||
"text": text_rows[i], | ||
"image": {col.capitalize(): df[col].array.get_image(i, "base64") for col in image_cols}, | ||
"image": { | ||
col.capitalize(): df[col].array.get_image(i, "base64") | ||
for col in image_cols | ||
}, | ||
} | ||
for i in range(len(df)) | ||
] | ||
return multimodal_data | ||
|
||
|
||
def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, Any]]) -> list[dict[str, Any]]: | ||
def merge_multimodal_info( | ||
first: list[dict[str, Any]], second: list[dict[str, Any]] | ||
) -> list[dict[str, Any]]: | ||
""" | ||
Merges two multimodal info lists into one. Each row of first is merged with each row of second. | ||
|
||
|
@@ -337,9 +414,11 @@ def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, An | |
""" | ||
return [ | ||
{ | ||
"text": f"{first[i]['text']}\n{second[j]['text']}" | ||
if first[i]["text"] != "" and second[j]["text"] != "" | ||
else first[i]["text"] + second[j]["text"], | ||
"text": ( | ||
f"{first[i]['text']}\n{second[j]['text']}" | ||
if first[i]["text"] != "" and second[j]["text"] != "" | ||
else first[i]["text"] + second[j]["text"] | ||
), | ||
"image": {**first[i]["image"], **second[j]["image"]}, | ||
} | ||
for i in range(len(first)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,4 +5,5 @@ numpy==1.26.4 | |
pandas==2.2.2 | ||
sentence-transformers==3.0.1 | ||
tiktoken==0.7.0 | ||
tqdm==4.66.4 | ||
tqdm==4.66.4 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
PyMuPDF==1.25.1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should use
lotus.logger.error
rather thanlotus.logger.debug
here.