From c28d0faa8a09e8ac55e639ba9a524c7886a5e14f Mon Sep 17 00:00:00 2001
From: Pratyay Pandey
Date: Sun, 5 Jan 2025 10:06:16 -0800
Subject: [PATCH] Integrate PDF extraction into sem_extract
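
Add an infer_pdfs flag to the sem_extract DataFrame accessor. When
enabled, df2multimodal_info treats any input column whose values are
all strings ending in ".pdf" as a column of PDF paths, extracts their
text with the new pdf2text helper (PyMuPDF), and serializes that text
in place of the paths. Extraction runs on a copy of the DataFrame so
the caller's data is left untouched. The remaining hunks are
mechanical line-wrapping changes, and PyMuPDF is pinned in
requirements.txt.

Rough usage sketch (the model name, column names, and configuration
call below are illustrative assumptions, not part of this patch):

    import pandas as pd
    import lotus
    from lotus.models import LM

    # assumed setup: configure lotus with an LM before using sem_extract
    lotus.settings.configure(lm=LM(model="gpt-4o-mini"))

    df = pd.DataFrame({"report": ["q1_report.pdf", "q2_report.pdf"]})
    out = df.sem_extract(
        input_cols=["report"],
        output_cols={"revenue": "total revenue mentioned in the report"},
        infer_pdfs=True,  # new flag: all-".pdf" string columns are read as PDFs
    )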
---
lotus/sem_ops/sem_extract.py | 29 ++++--
lotus/templates/task_instructions.py | 133 +++++++++++++++++++++------
requirements.txt | 3 +-
3 files changed, 131 insertions(+), 34 deletions(-)
diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py
index ed9619e5..b0c2d3ec 100644
--- a/lotus/sem_ops/sem_extract.py
+++ b/lotus/sem_ops/sem_extract.py
@@ -6,7 +6,11 @@
from lotus.cache import operator_cache
from lotus.models import LM
from lotus.templates import task_instructions
-from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput
+from lotus.types import (
+ LMOutput,
+ SemanticExtractOutput,
+ SemanticExtractPostprocessOutput,
+)
from lotus.utils import show_safe_mode
from .postprocessors import extract_postprocess
@@ -17,7 +21,9 @@ def sem_extract(
model: LM,
output_cols: dict[str, str | None],
extract_quotes: bool = False,
- postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
+ postprocessor: Callable[
+ [list[str]], SemanticExtractPostprocessOutput
+ ] = extract_postprocess,
safe_mode: bool = False,
progress_bar_desc: str = "Extracting",
) -> SemanticExtractOutput:
@@ -39,7 +45,9 @@ def sem_extract(
for doc in docs:
prompt = task_instructions.extract_formatter(doc, output_cols, extract_quotes)
lotus.logger.debug(f"input to model: {prompt}")
- lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}")
+ lotus.logger.debug(
+ f"inputs content to model: {[x.get('content') for x in prompt]}"
+ )
inputs.append(prompt)
# check if safe_mode is enabled
@@ -49,7 +57,11 @@ def sem_extract(
show_safe_mode(estimated_cost, estimated_LM_calls)
# call model
- lm_output: LMOutput = model(inputs, response_format={"type": "json_object"}, progress_bar_desc=progress_bar_desc)
+ lm_output: LMOutput = model(
+ inputs,
+ response_format={"type": "json_object"},
+ progress_bar_desc=progress_bar_desc,
+ )
# post process results
postprocess_output = postprocessor(lm_output.outputs)
@@ -78,8 +90,11 @@ def __call__(
input_cols: list[str],
output_cols: dict[str, str | None],
extract_quotes: bool = False,
- postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
+ postprocessor: Callable[
+ [list[str]], SemanticExtractPostprocessOutput
+ ] = extract_postprocess,
return_raw_outputs: bool = False,
+ infer_pdfs: bool = False,
safe_mode: bool = False,
progress_bar_desc: str = "Extracting",
) -> pd.DataFrame:
@@ -106,7 +121,9 @@ def __call__(
if column not in self._obj.columns:
raise ValueError(f"Column {column} not found in DataFrame")
- multimodal_data = task_instructions.df2multimodal_info(self._obj, input_cols)
+ multimodal_data = task_instructions.df2multimodal_info(
+ self._obj, input_cols, infer_pdfs
+ )
out = sem_extract(
docs=multimodal_data,
diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py
index fc30efd9..4030f148 100644
--- a/lotus/templates/task_instructions.py
+++ b/lotus/templates/task_instructions.py
@@ -1,6 +1,7 @@
import re
from typing import Any
+import fitz
import pandas as pd
import lotus
@@ -77,7 +78,9 @@ def filter_formatter_cot(
cot = cot_reasoning[idx]
messages.extend(
[
- user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"),
+ user_message_formatter(
+ ex_multimodal_data, f"Claim: {user_instruction}"
+ ),
{
"role": "assistant",
"content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}",
@@ -85,7 +88,9 @@ def filter_formatter_cot(
]
)
- messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
+ messages.append(
+ user_message_formatter(multimodal_data, f"Claim: {user_instruction}")
+ )
return messages
@@ -102,7 +107,9 @@ def filter_formatter_zs_cot(
{"role": "system", "content": sys_instruction},
]
- messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
+ messages.append(
+ user_message_formatter(multimodal_data, f"Claim: {user_instruction}")
+ )
return messages
@@ -117,7 +124,11 @@ def filter_formatter(
if cot_reasoning:
assert examples_multimodal_data is not None and examples_answer is not None
return filter_formatter_cot(
- multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning
+ multimodal_data,
+ user_instruction,
+ examples_multimodal_data,
+ examples_answer,
+ cot_reasoning,
)
elif strategy == "zs-cot":
return filter_formatter_zs_cot(multimodal_data, user_instruction)
@@ -133,18 +144,24 @@ def filter_formatter(
if examples_multimodal_data:
assert examples_answer is not None
- assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list)
+ assert isinstance(examples_multimodal_data, list) and isinstance(
+ examples_answer, list
+ )
for i in range(len(examples_multimodal_data)):
ex_multimodal_data = examples_multimodal_data[i]
ex_ans = examples_answer[i]
messages.extend(
[
- user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"),
+ user_message_formatter(
+ ex_multimodal_data, f"Claim: {user_instruction}"
+ ),
{"role": "assistant", "content": str(ex_ans)},
]
)
- messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
+ messages.append(
+ user_message_formatter(multimodal_data, f"Claim: {user_instruction}")
+ )
return messages
@@ -178,7 +195,9 @@ def map_formatter_cot(
]
)
- messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
+ messages.append(
+ user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")
+ )
return messages
@@ -195,7 +214,9 @@ def map_formatter_zs_cot(
{"role": "system", "content": sys_instruction},
]
- messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
+ messages.append(
+ user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")
+ )
return messages
@@ -210,7 +231,11 @@ def map_formatter(
if cot_reasoning:
assert examples_multimodal_data is not None and examples_answer is not None
return map_formatter_cot(
- multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning
+ multimodal_data,
+ user_instruction,
+ examples_multimodal_data,
+ examples_answer,
+ cot_reasoning,
)
elif strategy == "zs-cot":
return map_formatter_zs_cot(multimodal_data, user_instruction)
@@ -228,21 +253,29 @@ def map_formatter(
for ex_df_txt, ex_ans in zip(examples_multimodal_data, examples_answer):
messages.extend(
[
- user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"),
+ user_message_formatter(
+ ex_df_txt, f"Instruction: {user_instruction}"
+ ),
{"role": "assistant", "content": str(ex_ans)},
]
)
- messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
+ messages.append(
+ user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")
+ )
return messages
def extract_formatter(
- multimodal_data: dict[str, Any], output_cols: dict[str, str | None], extract_quotes: bool = True
+ multimodal_data: dict[str, Any],
+ output_cols: dict[str, str | None],
+ extract_quotes: bool = True,
) -> list[dict[str, str]]:
output_col_names = list(output_cols.keys())
# Set the description to be the key if no value is provided
- output_cols_with_desc: dict[str, str] = {col: col if desc is None else desc for col, desc in output_cols.items()}
+ output_cols_with_desc: dict[str, str] = {
+ col: col if desc is None else desc for col, desc in output_cols.items()
+ }
all_fields = output_col_names
if extract_quotes:
@@ -271,10 +304,14 @@ def df2text(df: pd.DataFrame, cols: list[str]) -> list[str]:
"""Formats the given DataFrame into a string containing info from cols."""
def custom_format_row(x: pd.Series, cols: list[str]) -> str:
- return "".join([f"[{cols[i].capitalize()}]: «{x[cols[i]]}»\n" for i in range(len(cols))])
+ return "".join(
+ [f"[{cols[i].capitalize()}]: «{x[cols[i]]}»\n" for i in range(len(cols))]
+ )
def clean_and_escape_column_name(column_name: str) -> str:
- clean_name = re.sub(r"[^\w]", "", column_name) # Remove spaces and special characters
+ clean_name = re.sub(
+ r"[^\w]", "", column_name
+ ) # Remove spaces and special characters
return clean_name
# take cols that are in df
@@ -286,7 +323,9 @@ def clean_and_escape_column_name(column_name: str) -> str:
formatted_rows: list[str] = []
if lotus.settings.serialization_format == SerializationFormat.DEFAULT:
- formatted_rows = projected_df.apply(lambda x: custom_format_row(x, cols), axis=1).tolist()
+ formatted_rows = projected_df.apply(
+ lambda x: custom_format_row(x, cols), axis=1
+ ).tolist()
elif lotus.settings.serialization_format == SerializationFormat.JSON:
formatted_rows = projected_df.to_json(orient="records", lines=True).splitlines()
elif lotus.settings.serialization_format == SerializationFormat.XML:
@@ -298,33 +337,71 @@ def clean_and_escape_column_name(column_name: str) -> str:
"You can install it with the following command:\n\n"
" pip install 'lotus-ai[xml]'"
)
- projected_df = projected_df.rename(columns=lambda x: clean_and_escape_column_name(x))
- full_xml = projected_df.to_xml(root_name="data", row_name="row", pretty_print=False, index=False)
+ projected_df = projected_df.rename(
+ columns=lambda x: clean_and_escape_column_name(x)
+ )
+ full_xml = projected_df.to_xml(
+ root_name="data", row_name="row", pretty_print=False, index=False
+ )
root = ET.fromstring(full_xml)
- formatted_rows = [ET.tostring(row, encoding="unicode", method="xml") for row in root.findall("row")]
+ formatted_rows = [
+ ET.tostring(row, encoding="unicode", method="xml")
+ for row in root.findall("row")
+ ]
return formatted_rows
-def df2multimodal_info(df: pd.DataFrame, cols: list[str]) -> list[dict[str, Any]]:
+def pdf2text(file_path: str) -> str:
+ try:
+ with fitz.open(file_path) as doc:
+ return " ".join(page.get_text() for page in doc)
+ except Exception as e:
+ lotus.logger.debug(f"Error while processing pdf at file path {file_path}: {e}")
+ return ""
+
+
+def df2multimodal_info(
+ df: pd.DataFrame, cols: list[str], infer_pdfs: bool = False
+) -> list[dict[str, Any]]:
"""
Formats the given DataFrame into a string containing info from cols.
Return a list of dictionaries, each containing text and image data.
"""
+
+    # Replace PDF paths with extracted text on a copy, leaving the source DataFrame untouched
+ df_copy = df.copy()
+
image_cols = [col for col in cols if isinstance(df[col].dtype, ImageDtype)]
+ if infer_pdfs:
+ pdf_cols = [
+ col
+ for col in cols
+ if col not in image_cols
+ and df[col].apply(lambda x: isinstance(x, str) and x.endswith(".pdf")).all()
+ ]
+ for col in pdf_cols:
+ df_copy[col] = df_copy[col].apply(pdf2text)
+
text_cols = [col for col in cols if col not in image_cols]
- text_rows = df2text(df, text_cols)
+
+ text_rows = df2text(df_copy, text_cols)
multimodal_data = [
{
"text": text_rows[i],
- "image": {col.capitalize(): df[col].array.get_image(i, "base64") for col in image_cols},
+ "image": {
+ col.capitalize(): df[col].array.get_image(i, "base64")
+ for col in image_cols
+ },
}
for i in range(len(df))
]
return multimodal_data
-def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, Any]]) -> list[dict[str, Any]]:
+def merge_multimodal_info(
+ first: list[dict[str, Any]], second: list[dict[str, Any]]
+) -> list[dict[str, Any]]:
"""
Merges two multimodal info lists into one. Each row of first is merged with each row of second.
@@ -337,9 +414,11 @@ def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, An
"""
return [
{
- "text": f"{first[i]['text']}\n{second[j]['text']}"
- if first[i]["text"] != "" and second[j]["text"] != ""
- else first[i]["text"] + second[j]["text"],
+ "text": (
+ f"{first[i]['text']}\n{second[j]['text']}"
+ if first[i]["text"] != "" and second[j]["text"] != ""
+ else first[i]["text"] + second[j]["text"]
+ ),
"image": {**first[i]["image"], **second[j]["image"]},
}
for i in range(len(first))
diff --git a/requirements.txt b/requirements.txt
index 226370bc..901350aa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ numpy==1.26.4
pandas==2.2.2
sentence-transformers==3.0.1
tiktoken==0.7.0
-tqdm==4.66.4
\ No newline at end of file
+tqdm==4.66.4
+PyMuPDF==1.25.1