diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py index ed9619e..b0c2d3e 100644 --- a/lotus/sem_ops/sem_extract.py +++ b/lotus/sem_ops/sem_extract.py @@ -6,7 +6,11 @@ from lotus.cache import operator_cache from lotus.models import LM from lotus.templates import task_instructions -from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput +from lotus.types import ( + LMOutput, + SemanticExtractOutput, + SemanticExtractPostprocessOutput, +) from lotus.utils import show_safe_mode from .postprocessors import extract_postprocess @@ -17,7 +21,9 @@ def sem_extract( model: LM, output_cols: dict[str, str | None], extract_quotes: bool = False, - postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess, + postprocessor: Callable[ + [list[str]], SemanticExtractPostprocessOutput + ] = extract_postprocess, safe_mode: bool = False, progress_bar_desc: str = "Extracting", ) -> SemanticExtractOutput: @@ -39,7 +45,9 @@ def sem_extract( for doc in docs: prompt = task_instructions.extract_formatter(doc, output_cols, extract_quotes) lotus.logger.debug(f"input to model: {prompt}") - lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}") + lotus.logger.debug( + f"inputs content to model: {[x.get('content') for x in prompt]}" + ) inputs.append(prompt) # check if safe_mode is enabled @@ -49,7 +57,11 @@ def sem_extract( show_safe_mode(estimated_cost, estimated_LM_calls) # call model - lm_output: LMOutput = model(inputs, response_format={"type": "json_object"}, progress_bar_desc=progress_bar_desc) + lm_output: LMOutput = model( + inputs, + response_format={"type": "json_object"}, + progress_bar_desc=progress_bar_desc, + ) # post process results postprocess_output = postprocessor(lm_output.outputs) @@ -78,8 +90,11 @@ def __call__( input_cols: list[str], output_cols: dict[str, str | None], extract_quotes: bool = False, - postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess, + postprocessor: Callable[ + [list[str]], SemanticExtractPostprocessOutput + ] = extract_postprocess, return_raw_outputs: bool = False, + infer_pdfs: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Extracting", ) -> pd.DataFrame: @@ -106,7 +121,9 @@ def __call__( if column not in self._obj.columns: raise ValueError(f"Column {column} not found in DataFrame") - multimodal_data = task_instructions.df2multimodal_info(self._obj, input_cols) + multimodal_data = task_instructions.df2multimodal_info( + self._obj, input_cols, infer_pdfs + ) out = sem_extract( docs=multimodal_data, diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index fc30efd..4030f14 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -1,6 +1,7 @@ import re from typing import Any +import fitz import pandas as pd import lotus @@ -77,7 +78,9 @@ def filter_formatter_cot( cot = cot_reasoning[idx] messages.extend( [ - user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), + user_message_formatter( + ex_multimodal_data, f"Claim: {user_instruction}" + ), { "role": "assistant", "content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}", @@ -85,7 +88,9 @@ def filter_formatter_cot( ] ) - messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) + messages.append( + user_message_formatter(multimodal_data, f"Claim: {user_instruction}") + ) return messages @@ -102,7 +107,9 @@ def filter_formatter_zs_cot( {"role": "system", "content": sys_instruction}, ] - messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) + messages.append( + user_message_formatter(multimodal_data, f"Claim: {user_instruction}") + ) return messages @@ -117,7 +124,11 @@ def filter_formatter( if cot_reasoning: assert examples_multimodal_data is not None and examples_answer is not None return filter_formatter_cot( - multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning + multimodal_data, + user_instruction, + examples_multimodal_data, + examples_answer, + cot_reasoning, ) elif strategy == "zs-cot": return filter_formatter_zs_cot(multimodal_data, user_instruction) @@ -133,18 +144,24 @@ def filter_formatter( if examples_multimodal_data: assert examples_answer is not None - assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list) + assert isinstance(examples_multimodal_data, list) and isinstance( + examples_answer, list + ) for i in range(len(examples_multimodal_data)): ex_multimodal_data = examples_multimodal_data[i] ex_ans = examples_answer[i] messages.extend( [ - user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), + user_message_formatter( + ex_multimodal_data, f"Claim: {user_instruction}" + ), {"role": "assistant", "content": str(ex_ans)}, ] ) - messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) + messages.append( + user_message_formatter(multimodal_data, f"Claim: {user_instruction}") + ) return messages @@ -178,7 +195,9 @@ def map_formatter_cot( ] ) - messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) + messages.append( + user_message_formatter(multimodal_data, f"Instruction: {user_instruction}") + ) return messages @@ -195,7 +214,9 @@ def map_formatter_zs_cot( {"role": "system", "content": sys_instruction}, ] - messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) + messages.append( + user_message_formatter(multimodal_data, f"Instruction: {user_instruction}") + ) return messages @@ -210,7 +231,11 @@ def map_formatter( if cot_reasoning: assert examples_multimodal_data is not None and examples_answer is not None return map_formatter_cot( - multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning + multimodal_data, + user_instruction, + examples_multimodal_data, + examples_answer, + cot_reasoning, ) elif strategy == "zs-cot": return map_formatter_zs_cot(multimodal_data, user_instruction) @@ -228,21 +253,29 @@ def map_formatter( for ex_df_txt, ex_ans in zip(examples_multimodal_data, examples_answer): messages.extend( [ - user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"), + user_message_formatter( + ex_df_txt, f"Instruction: {user_instruction}" + ), {"role": "assistant", "content": str(ex_ans)}, ] ) - messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) + messages.append( + user_message_formatter(multimodal_data, f"Instruction: {user_instruction}") + ) return messages def extract_formatter( - multimodal_data: dict[str, Any], output_cols: dict[str, str | None], extract_quotes: bool = True + multimodal_data: dict[str, Any], + output_cols: dict[str, str | None], + extract_quotes: bool = True, ) -> list[dict[str, str]]: output_col_names = list(output_cols.keys()) # Set the description to be the key if no value is provided - output_cols_with_desc: dict[str, str] = {col: col if desc is None else desc for col, desc in output_cols.items()} + output_cols_with_desc: dict[str, str] = { + col: col if desc is None else desc for col, desc in output_cols.items() + } all_fields = output_col_names if extract_quotes: @@ -271,10 +304,14 @@ def df2text(df: pd.DataFrame, cols: list[str]) -> list[str]: """Formats the given DataFrame into a string containing info from cols.""" def custom_format_row(x: pd.Series, cols: list[str]) -> str: - return "".join([f"[{cols[i].capitalize()}]: «{x[cols[i]]}»\n" for i in range(len(cols))]) + return "".join( + [f"[{cols[i].capitalize()}]: «{x[cols[i]]}»\n" for i in range(len(cols))] + ) def clean_and_escape_column_name(column_name: str) -> str: - clean_name = re.sub(r"[^\w]", "", column_name) # Remove spaces and special characters + clean_name = re.sub( + r"[^\w]", "", column_name + ) # Remove spaces and special characters return clean_name # take cols that are in df @@ -286,7 +323,9 @@ def clean_and_escape_column_name(column_name: str) -> str: formatted_rows: list[str] = [] if lotus.settings.serialization_format == SerializationFormat.DEFAULT: - formatted_rows = projected_df.apply(lambda x: custom_format_row(x, cols), axis=1).tolist() + formatted_rows = projected_df.apply( + lambda x: custom_format_row(x, cols), axis=1 + ).tolist() elif lotus.settings.serialization_format == SerializationFormat.JSON: formatted_rows = projected_df.to_json(orient="records", lines=True).splitlines() elif lotus.settings.serialization_format == SerializationFormat.XML: @@ -298,33 +337,71 @@ def clean_and_escape_column_name(column_name: str) -> str: "You can install it with the following command:\n\n" " pip install 'lotus-ai[xml]'" ) - projected_df = projected_df.rename(columns=lambda x: clean_and_escape_column_name(x)) - full_xml = projected_df.to_xml(root_name="data", row_name="row", pretty_print=False, index=False) + projected_df = projected_df.rename( + columns=lambda x: clean_and_escape_column_name(x) + ) + full_xml = projected_df.to_xml( + root_name="data", row_name="row", pretty_print=False, index=False + ) root = ET.fromstring(full_xml) - formatted_rows = [ET.tostring(row, encoding="unicode", method="xml") for row in root.findall("row")] + formatted_rows = [ + ET.tostring(row, encoding="unicode", method="xml") + for row in root.findall("row") + ] return formatted_rows -def df2multimodal_info(df: pd.DataFrame, cols: list[str]) -> list[dict[str, Any]]: +def pdf2text(file_path: str) -> str: + try: + with fitz.open(file_path) as doc: + return " ".join(page.get_text() for page in doc) + except Exception as e: + lotus.logger.debug(f"Error while processing pdf at file path {file_path}: {e}") + return "" + + +def df2multimodal_info( + df: pd.DataFrame, cols: list[str], infer_pdfs: bool = False +) -> list[dict[str, Any]]: """ Formats the given DataFrame into a string containing info from cols. Return a list of dictionaries, each containing text and image data. """ + + # We want to modify `df` to replace PDF paths with their text, but without modifying the source + df_copy = df.copy() + image_cols = [col for col in cols if isinstance(df[col].dtype, ImageDtype)] + if infer_pdfs: + pdf_cols = [ + col + for col in cols + if col not in image_cols + and df[col].apply(lambda x: isinstance(x, str) and x.endswith(".pdf")).all() + ] + for col in pdf_cols: + df_copy[col] = df_copy[col].apply(pdf2text) + text_cols = [col for col in cols if col not in image_cols] - text_rows = df2text(df, text_cols) + + text_rows = df2text(df_copy, text_cols) multimodal_data = [ { "text": text_rows[i], - "image": {col.capitalize(): df[col].array.get_image(i, "base64") for col in image_cols}, + "image": { + col.capitalize(): df[col].array.get_image(i, "base64") + for col in image_cols + }, } for i in range(len(df)) ] return multimodal_data -def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, Any]]) -> list[dict[str, Any]]: +def merge_multimodal_info( + first: list[dict[str, Any]], second: list[dict[str, Any]] +) -> list[dict[str, Any]]: """ Merges two multimodal info lists into one. Each row of first is merged with each row of second. @@ -337,9 +414,11 @@ def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, An """ return [ { - "text": f"{first[i]['text']}\n{second[j]['text']}" - if first[i]["text"] != "" and second[j]["text"] != "" - else first[i]["text"] + second[j]["text"], + "text": ( + f"{first[i]['text']}\n{second[j]['text']}" + if first[i]["text"] != "" and second[j]["text"] != "" + else first[i]["text"] + second[j]["text"] + ), "image": {**first[i]["image"], **second[j]["image"]}, } for i in range(len(first)) diff --git a/requirements.txt b/requirements.txt index 226370b..901350a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ numpy==1.26.4 pandas==2.2.2 sentence-transformers==3.0.1 tiktoken==0.7.0 -tqdm==4.66.4 \ No newline at end of file +tqdm==4.66.4 +PyMuPDF==1.25.1