Integrate PDF extraction to sem_extract #69

Open: wants to merge 1 commit into base: main
29 changes: 23 additions & 6 deletions lotus/sem_ops/sem_extract.py
@@ -6,7 +6,11 @@
from lotus.cache import operator_cache
from lotus.models import LM
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput
from lotus.types import (
LMOutput,
SemanticExtractOutput,
SemanticExtractPostprocessOutput,
)
from lotus.utils import show_safe_mode

from .postprocessors import extract_postprocess
@@ -17,7 +21,9 @@ def sem_extract(
model: LM,
output_cols: dict[str, str | None],
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
postprocessor: Callable[
[list[str]], SemanticExtractPostprocessOutput
] = extract_postprocess,
safe_mode: bool = False,
progress_bar_desc: str = "Extracting",
) -> SemanticExtractOutput:
@@ -39,7 +45,9 @@ def sem_extract(
for doc in docs:
prompt = task_instructions.extract_formatter(doc, output_cols, extract_quotes)
lotus.logger.debug(f"input to model: {prompt}")
lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}")
lotus.logger.debug(
f"inputs content to model: {[x.get('content') for x in prompt]}"
)
inputs.append(prompt)

# check if safe_mode is enabled
@@ -49,7 +57,11 @@ def sem_extract(
show_safe_mode(estimated_cost, estimated_LM_calls)

# call model
lm_output: LMOutput = model(inputs, response_format={"type": "json_object"}, progress_bar_desc=progress_bar_desc)
lm_output: LMOutput = model(
inputs,
response_format={"type": "json_object"},
progress_bar_desc=progress_bar_desc,
)

# post process results
postprocess_output = postprocessor(lm_output.outputs)
@@ -78,8 +90,11 @@ def __call__(
input_cols: list[str],
output_cols: dict[str, str | None],
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
postprocessor: Callable[
[list[str]], SemanticExtractPostprocessOutput
] = extract_postprocess,
return_raw_outputs: bool = False,
infer_pdfs: bool = False,
safe_mode: bool = False,
progress_bar_desc: str = "Extracting",
) -> pd.DataFrame:
@@ -106,7 +121,9 @@ def __call__(
if column not in self._obj.columns:
raise ValueError(f"Column {column} not found in DataFrame")

multimodal_data = task_instructions.df2multimodal_info(self._obj, input_cols)
multimodal_data = task_instructions.df2multimodal_info(
self._obj, input_cols, infer_pdfs
)

out = sem_extract(
docs=multimodal_data,
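A hedged usage sketch of the new infer_pdfs flag through the sem_extract DataFrame accessor. The model name, file paths, and configuration call are placeholders that assume the usual lotus setup; none of these specifics come from this PR.

```python
import pandas as pd

import lotus
from lotus.models import LM

# Assumed setup; swap in whichever model your deployment uses.
lotus.settings.configure(lm=LM(model="gpt-4o-mini"))

# Hypothetical paths. With infer_pdfs=True, a text column whose values all
# end in ".pdf" is converted to extracted text before being sent to the LM.
df = pd.DataFrame({"report": ["reports/q1.pdf", "reports/q2.pdf"]})

out = df.sem_extract(
    input_cols=["report"],
    output_cols={"revenue": "total revenue mentioned in the report"},
    infer_pdfs=True,
)
print(out)
```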
133 changes: 106 additions & 27 deletions lotus/templates/task_instructions.py
@@ -1,6 +1,7 @@
import re
from typing import Any

import fitz
import pandas as pd

import lotus
@@ -77,15 +78,19 @@ def filter_formatter_cot(
cot = cot_reasoning[idx]
messages.extend(
[
user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"),
user_message_formatter(
ex_multimodal_data, f"Claim: {user_instruction}"
),
{
"role": "assistant",
"content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}",
},
]
)

messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
messages.append(
user_message_formatter(multimodal_data, f"Claim: {user_instruction}")
)
return messages


@@ -102,7 +107,9 @@ def filter_formatter_zs_cot(
{"role": "system", "content": sys_instruction},
]

messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
messages.append(
user_message_formatter(multimodal_data, f"Claim: {user_instruction}")
)
return messages


@@ -117,7 +124,11 @@ def filter_formatter(
if cot_reasoning:
assert examples_multimodal_data is not None and examples_answer is not None
return filter_formatter_cot(
multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning
multimodal_data,
user_instruction,
examples_multimodal_data,
examples_answer,
cot_reasoning,
)
elif strategy == "zs-cot":
return filter_formatter_zs_cot(multimodal_data, user_instruction)
@@ -133,18 +144,24 @@ def filter_formatter(

if examples_multimodal_data:
assert examples_answer is not None
assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list)
assert isinstance(examples_multimodal_data, list) and isinstance(
examples_answer, list
)
for i in range(len(examples_multimodal_data)):
ex_multimodal_data = examples_multimodal_data[i]
ex_ans = examples_answer[i]
messages.extend(
[
user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"),
user_message_formatter(
ex_multimodal_data, f"Claim: {user_instruction}"
),
{"role": "assistant", "content": str(ex_ans)},
]
)

messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
messages.append(
user_message_formatter(multimodal_data, f"Claim: {user_instruction}")
)
return messages


@@ -178,7 +195,9 @@ def map_formatter_cot(
]
)

messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
messages.append(
user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")
)
return messages


@@ -195,7 +214,9 @@ def map_formatter_zs_cot(
{"role": "system", "content": sys_instruction},
]

messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
messages.append(
user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")
)
return messages


@@ -210,7 +231,11 @@ def map_formatter(
if cot_reasoning:
assert examples_multimodal_data is not None and examples_answer is not None
return map_formatter_cot(
multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning
multimodal_data,
user_instruction,
examples_multimodal_data,
examples_answer,
cot_reasoning,
)
elif strategy == "zs-cot":
return map_formatter_zs_cot(multimodal_data, user_instruction)
@@ -228,21 +253,29 @@ def map_formatter(
for ex_df_txt, ex_ans in zip(examples_multimodal_data, examples_answer):
messages.extend(
[
user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"),
user_message_formatter(
ex_df_txt, f"Instruction: {user_instruction}"
),
{"role": "assistant", "content": str(ex_ans)},
]
)

messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
messages.append(
user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")
)
return messages


def extract_formatter(
multimodal_data: dict[str, Any], output_cols: dict[str, str | None], extract_quotes: bool = True
multimodal_data: dict[str, Any],
output_cols: dict[str, str | None],
extract_quotes: bool = True,
) -> list[dict[str, str]]:
output_col_names = list(output_cols.keys())
# Set the description to be the key if no value is provided
output_cols_with_desc: dict[str, str] = {col: col if desc is None else desc for col, desc in output_cols.items()}
output_cols_with_desc: dict[str, str] = {
col: col if desc is None else desc for col, desc in output_cols.items()
}

all_fields = output_col_names
if extract_quotes:
@@ -271,10 +304,14 @@ def df2text(df: pd.DataFrame, cols: list[str]) -> list[str]:
"""Formats the given DataFrame into a string containing info from cols."""

def custom_format_row(x: pd.Series, cols: list[str]) -> str:
return "".join([f"[{cols[i].capitalize()}]: «{x[cols[i]]}»\n" for i in range(len(cols))])
return "".join(
[f"[{cols[i].capitalize()}]: «{x[cols[i]]}»\n" for i in range(len(cols))]
)

def clean_and_escape_column_name(column_name: str) -> str:
clean_name = re.sub(r"[^\w]", "", column_name) # Remove spaces and special characters
clean_name = re.sub(
r"[^\w]", "", column_name
) # Remove spaces and special characters
return clean_name

# take cols that are in df
@@ -286,7 +323,9 @@ def clean_and_escape_column_name(column_name: str) -> str:
formatted_rows: list[str] = []

if lotus.settings.serialization_format == SerializationFormat.DEFAULT:
formatted_rows = projected_df.apply(lambda x: custom_format_row(x, cols), axis=1).tolist()
formatted_rows = projected_df.apply(
lambda x: custom_format_row(x, cols), axis=1
).tolist()
elif lotus.settings.serialization_format == SerializationFormat.JSON:
formatted_rows = projected_df.to_json(orient="records", lines=True).splitlines()
elif lotus.settings.serialization_format == SerializationFormat.XML:
@@ -298,33 +337,71 @@ def clean_and_escape_column_name(column_name: str) -> str:
"You can install it with the following command:\n\n"
" pip install 'lotus-ai[xml]'"
)
projected_df = projected_df.rename(columns=lambda x: clean_and_escape_column_name(x))
full_xml = projected_df.to_xml(root_name="data", row_name="row", pretty_print=False, index=False)
projected_df = projected_df.rename(
columns=lambda x: clean_and_escape_column_name(x)
)
full_xml = projected_df.to_xml(
root_name="data", row_name="row", pretty_print=False, index=False
)
root = ET.fromstring(full_xml)
formatted_rows = [ET.tostring(row, encoding="unicode", method="xml") for row in root.findall("row")]
formatted_rows = [
ET.tostring(row, encoding="unicode", method="xml")
for row in root.findall("row")
]

return formatted_rows


def df2multimodal_info(df: pd.DataFrame, cols: list[str]) -> list[dict[str, Any]]:
def pdf2text(file_path: str) -> str:
try:
with fitz.open(file_path) as doc:
return " ".join(page.get_text() for page in doc)
except Exception as e:
lotus.logger.debug(f"Error while processing pdf at file path {file_path}: {e}")
Collaborator: I think we should use lotus.logger.error rather than lotus.logger.debug here.

return ""


def df2multimodal_info(
df: pd.DataFrame, cols: list[str], infer_pdfs: bool = False
) -> list[dict[str, Any]]:
"""
Formats the given DataFrame into a string containing info from cols.
Return a list of dictionaries, each containing text and image data.
"""

# We want to modify `df` to replace PDF paths with their text, but without modifying the source
df_copy = df.copy()

image_cols = [col for col in cols if isinstance(df[col].dtype, ImageDtype)]
if infer_pdfs:
pdf_cols = [
col
for col in cols
if col not in image_cols
and df[col].apply(lambda x: isinstance(x, str) and x.endswith(".pdf")).all()
]
for col in pdf_cols:
df_copy[col] = df_copy[col].apply(pdf2text)

Collaborator: Do you think we should have a PDFDtype? To support operating over images we added ImageDtype, so I wonder if it makes sense to add a custom pandas type here too.

Collaborator: In this case we may not need the infer_pdfs flag, just as we do not have an infer_images flag.

Collaborator (Author): I think there are a few potential problems with implementing a PDFDtype. ImageDtype worked off the Image object from Pillow, but there is no equivalent standard we can use for PDFs, except perhaps the PyMuPDF Document object. Also, we can more objectively tell whether a value is an image than whether it is a path to a PDF. A cell might, for instance, simply contain titles in PDF format, and that could be exactly how the user wants it interpreted. It is less clear-cut than it is with images.
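A small, self-contained illustration of the ambiguity raised in this thread, using made-up data: a plain ".pdf" suffix check cannot distinguish real file paths from ordinary text that happens to end in ".pdf".

```python
import pandas as pd

# Hypothetical data: both columns pass the suffix check used above, but only
# "paper" actually holds paths to PDF files on disk.
df = pd.DataFrame(
    {
        "paper": ["docs/attention.pdf", "docs/resnet.pdf"],
        "title": ["Notes on archive.pdf", "Summary of forms.pdf"],
    }
)

for col in ["paper", "title"]:
    looks_like_pdfs = (
        df[col].apply(lambda x: isinstance(x, str) and x.endswith(".pdf")).all()
    )
    print(col, looks_like_pdfs)  # True for both, so the check alone is not enough
```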

text_cols = [col for col in cols if col not in image_cols]
text_rows = df2text(df, text_cols)

text_rows = df2text(df_copy, text_cols)
multimodal_data = [
{
"text": text_rows[i],
"image": {col.capitalize(): df[col].array.get_image(i, "base64") for col in image_cols},
"image": {
col.capitalize(): df[col].array.get_image(i, "base64")
for col in image_cols
},
}
for i in range(len(df))
]
return multimodal_data


def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, Any]]) -> list[dict[str, Any]]:
def merge_multimodal_info(
first: list[dict[str, Any]], second: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""
Merges two multimodal info lists into one. Each row of first is merged with each row of second.

@@ -337,9 +414,11 @@ def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, An
"""
return [
{
"text": f"{first[i]['text']}\n{second[j]['text']}"
if first[i]["text"] != "" and second[j]["text"] != ""
else first[i]["text"] + second[j]["text"],
"text": (
f"{first[i]['text']}\n{second[j]['text']}"
if first[i]["text"] != "" and second[j]["text"] != ""
else first[i]["text"] + second[j]["text"]
),
"image": {**first[i]["image"], **second[j]["image"]},
}
for i in range(len(first))
3 changes: 2 additions & 1 deletion requirements.txt
@@ -5,4 +5,5 @@ numpy==1.26.4
pandas==2.2.2
sentence-transformers==3.0.1
tiktoken==0.7.0
tqdm==4.66.4
tqdm==4.66.4
PyMuPDF==1.25.1

Collaborator: PyMuPDF can just be an optional dependency. For reference, you can see how the lxml dependency is handled.
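A minimal sketch of the optional-dependency guard the reviewer suggests, mirroring the lxml handling already visible in task_instructions.py above, and folding in the earlier suggestion to log at error level. The "lotus-ai[pdf]" extra is an assumption and would need a matching entry in the package metadata.

```python
import lotus


def pdf2text(file_path: str) -> str:
    # Import PyMuPDF lazily so the base install does not require it.
    try:
        import fitz  # provided by the PyMuPDF package
    except ImportError:
        raise ImportError(
            "The 'PyMuPDF' library is required for PDF extraction. "
            "You can install it with the following command:\n\n"
            "  pip install 'lotus-ai[pdf]'"  # hypothetical extra name
        )

    try:
        with fitz.open(file_path) as doc:
            return " ".join(page.get_text() for page in doc)
    except Exception as e:
        lotus.logger.error(f"Error while processing pdf at file path {file_path}: {e}")
        return ""
```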