# NOTE(review): SOURCE is a whitespace-mangled git diff; this block reconstructs
# the post-patch content of lotus/dtype_extensions/__init__.py and the new
# lotus/dtype_extensions/document.py. Concrete fixes over the patched code:
#   1. `_concat_same_type` was missing `@classmethod` — pandas calls it on the
#      class, so `cls` would have received the list of arrays and the call
#      failed with a TypeError.
#   2. pymupdf Documents are iterated directly (`for page in doc`) and text is
#      read with `page.get_text()`; the original mixed in the pypdf API
#      (`doc.pages` attribute, `page.extract_text()`), which raises
#      AttributeError on pymupdf objects.
#   3. `take` used `result[indices == -1] = fill_value`, a scalar comparison
#      when `indices` is a plain list; replaced with pandas' take helper.
#   4. `to_numpy` skipped entries that are neither Document nor str (e.g. NA),
#      so `result[:] = documents` failed on arrays containing missing values.
# (The one-line import change this diff makes to task_instructions.py —
# `from lotus.dtype_extensions import DocumentDtype, ImageDtype` — is part of
# that file's top-of-file import block and is reproduced there.)

# --- lotus/dtype_extensions/__init__.py (post-patch) ---
from lotus.dtype_extensions.image import ImageDtype, ImageArray
from lotus.dtype_extensions.document import DocumentDtype, DocumentArray

import pandas as pd

pd.api.extensions.register_extension_dtype(ImageDtype)
pd.api.extensions.register_extension_dtype(DocumentDtype)


def convert_to_base_data(data: pd.Series | list) -> list:
    """Return plain Python scalars for a Series backed by an extension dtype.

    Image- and document-backed Series are materialized through their arrays'
    `get_image` / `get_document` accessors; everything else falls back to
    `Series.tolist()`. Non-Series input is returned unchanged.
    """
    if isinstance(data, pd.Series):
        if isinstance(data.dtype, ImageDtype):
            return [data.array.get_image(i) for i in range(len(data))]
        elif isinstance(data.dtype, DocumentDtype):
            return [data.array.get_document(i) for i in range(len(data))]
        return data.tolist()
    return data


__all__ = ["ImageDtype", "ImageArray", "DocumentDtype", "DocumentArray", "convert_to_base_data"]


# --- lotus/dtype_extensions/document.py (new file, post-patch) ---
import sys
from typing import Sequence, Union

import numpy as np
import pandas as pd
import pymupdf
from pandas.api.extensions import ExtensionArray, ExtensionDtype
from pandas.api.extensions import take as _pd_take
from pymupdf import Document

from lotus.utils import fetch_document


class DocumentDtype(ExtensionDtype):
    """Pandas extension dtype whose scalars are pymupdf Documents or path strings."""

    name = "document"
    type = Document
    na_value = None

    @classmethod
    def construct_array_type(cls):
        return DocumentArray


class DocumentArray(ExtensionArray):
    """Object-backed ExtensionArray holding pymupdf Documents or path strings.

    Loaded representations are memoized in `_cached_documents`, keyed by
    (position, requested representation) where the representation is one of
    `allowed_document_types` ("Document" or "string").
    """

    def __init__(self, values):
        self._data = np.asarray(values, dtype=object)
        self._dtype = DocumentDtype()
        self.allowed_document_types = ["Document", "string"]
        # Cache for loaded documents: (index, doc_type) -> Document | str | None.
        self._cached_documents: dict[tuple[int, str], str | Document | None] = {}

    def __getitem__(self, item: int | slice | Sequence[int]):
        result = self._data[item]
        if isinstance(item, (int, np.integer)):
            # Scalar access returns the raw stored value (path or Document),
            # not a loaded document — loading is explicit via get_document().
            return result
        return DocumentArray(result)

    def __setitem__(self, key, value) -> None:
        """Set one or more values inplace, invalidating cached loads."""
        if isinstance(key, np.ndarray):
            if key.dtype == bool:
                key = np.where(key)[0]
            key = key.tolist()
        if isinstance(key, (int, np.integer)):
            key = [key]
        if isinstance(key, slice):
            key = range(*key.indices(len(self)))
        if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
            for idx, val in zip(key, value):
                self._data[idx] = val
                self._invalidate_cache(idx)
        else:
            for idx in key:
                self._data[idx] = value
                self._invalidate_cache(idx)

    def _invalidate_cache(self, idx: int) -> None:
        """Drop every cached representation of position ``idx``."""
        for doc_type in self.allowed_document_types:
            self._cached_documents.pop((idx, doc_type), None)

    def get_document(self, idx: int, doc_type: str = "Document") -> Union[Document, str, None]:
        """Fetch (and cache) the document at ``idx``.

        doc_type "Document" returns a pymupdf Document; "string" returns the
        metadata plus extracted page text as one string.
        """
        if (idx, doc_type) not in self._cached_documents:
            document_result = fetch_document(self._data[idx], doc_type)
            assert document_result is None or isinstance(document_result, (Document, str))
            self._cached_documents[(idx, doc_type)] = document_result
        return self._cached_documents[(idx, doc_type)]

    def isna(self) -> np.ndarray:
        return pd.isna(self._data)

    def take(self, indices: Sequence[int], allow_fill: bool = False, fill_value=None) -> "DocumentArray":
        """Take elements by position, honoring the ExtensionArray fill contract.

        FIX: the original's `result[indices == -1] = fill_value` compared a
        plain list to -1 (a scalar False); pandas' take helper implements the
        -1-means-fill semantics correctly.
        """
        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value
        result = _pd_take(self._data, indices, allow_fill=allow_fill, fill_value=fill_value)
        return DocumentArray(result)

    def copy(self) -> "DocumentArray":
        new_array = DocumentArray(self._data.copy())
        new_array._cached_documents = self._cached_documents.copy()
        return new_array

    @classmethod  # FIX: decorator was missing; pandas invokes this on the class.
    def _concat_same_type(cls, to_concat: Sequence["DocumentArray"]) -> "DocumentArray":
        """Concatenate multiple DocumentArray instances into a single one."""
        combined_data = np.concatenate([arr._data for arr in to_concat])
        return cls._from_sequence(combined_data)

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        if copy:
            scalars = np.array(scalars, dtype=object, copy=True)
        return cls(scalars)

    def __len__(self) -> int:
        return len(self._data)

    def __eq__(self, other) -> np.ndarray:  # type: ignore
        """Element-wise equality; returns a boolean ndarray (pandas convention)."""
        if isinstance(other, DocumentArray):
            return np.array(
                [_compare_documents(d1, d2) for d1, d2 in zip(self._data, other._data)],
                dtype=bool,
            )
        if hasattr(other, "__iter__") and not isinstance(other, str):
            if len(other) != len(self):
                return np.repeat(False, len(self))
            return np.array(
                [_compare_documents(d1, d2) for d1, d2 in zip(self._data, other)],
                dtype=bool,
            )
        return np.array([_compare_documents(doc, other) for doc in self._data], dtype=bool)

    @property
    def dtype(self) -> DocumentDtype:
        return self._dtype

    @property
    def nbytes(self) -> int:
        # Shallow size only — sys.getsizeof does not follow into page data.
        return sum(sys.getsizeof(doc) for doc in self._data if doc)

    def __repr__(self) -> str:
        # NOTE(review): the display text of the original f-strings was lost in
        # transit (angle-bracketed content stripped; they read f''/f"");
        # reconstructed equivalents — confirm against the upstream file.
        previews = ", ".join(
            "<Document>" if isinstance(doc, Document) else f"<{doc}>" for doc in self._data[:5]
        )
        return f"DocumentArray([{previews}, ...])"

    def _formatter(self, boxed: bool = False):
        # NOTE(review): original formatter strings also lost — see __repr__.
        return lambda x: "<Document>" if isinstance(x, Document) else f"<{x}>"

    def to_numpy(self, dtype=None, copy=False, na_value=None) -> np.ndarray:
        """Convert the DocumentArray to a numpy object array of extracted text."""
        documents = []
        for doc_data in self._data:
            if isinstance(doc_data, Document):
                text = str(doc_data.metadata) + "\n"
                # FIX: iterate the Document itself; `.pages` is a method, not
                # an iterable attribute, on pymupdf Documents.
                for page in doc_data:
                    text += page.get_text() + "\n"
                documents.append(text)
            elif isinstance(doc_data, str):
                doc = pymupdf.open(doc_data)
                text = str(doc.metadata) + "\n"
                for page in doc:
                    text += page.get_text() + "\n"
                documents.append(text)
            else:
                # FIX: the original skipped NA entries entirely, making the
                # list shorter than the array and breaking the assignment below.
                documents.append(na_value)
        result = np.empty(len(self), dtype=object)
        result[:] = documents
        return result

    def __array__(self, dtype=None) -> np.ndarray:
        """Numpy array interface."""
        return self.to_numpy(dtype=dtype)


def _compare_documents(doc1, doc2) -> bool:
    """Equality for document scalars: page-by-page text plus metadata."""
    if doc1 is None or doc2 is None:
        return doc1 is doc2

    # Only fetch/extract text when actually comparing two loaded Documents.
    if isinstance(doc1, Document) and isinstance(doc2, Document):
        if doc1.page_count != doc2.page_count:
            return False
        # FIX: pymupdf pages are iterated off the Document and read with
        # get_text(); the original used the pypdf API (`.pages` attribute,
        # `extract_text()`), which raises AttributeError on pymupdf objects.
        for page1, page2 in zip(doc1, doc2):
            if page1.get_text() != page2.get_text():
                return False
        return doc1.metadata == doc2.metadata
    return doc1 == doc2
# NOTE(review): reconstruction of the two functions this diff span changes.
# The `ret` hunks in lotus/utils.py visible here are unchanged-context
# fragments of a definition not fully in view and are left as they are.

# --- lotus/templates/task_instructions.py: df2multimodal_info (post-patch) ---
def df2multimodal_info(df: pd.DataFrame, cols: list[str]) -> list[dict[str, Any]]:
    """Serialize the given columns of ``df`` into per-row multimodal payloads.

    Plain and document columns are folded under "text" (documents rendered as
    extracted text via get_document(i, "string")); image columns go under
    "image" as base64. Returns one dict per row.
    """
    image_cols = [col for col in cols if isinstance(df[col].dtype, ImageDtype)]
    document_cols = [col for col in cols if isinstance(df[col].dtype, DocumentDtype)]
    text_cols = [col for col in cols if col not in image_cols and col not in document_cols]
    text_rows = df2text(df, text_cols)
    multimodal_data = [
        {
            "text": {
                "text": text_rows[i],
                **{col.capitalize(): df[col].array.get_document(i, "string") for col in document_cols},
            },
            "image": {col.capitalize(): df[col].array.get_image(i, "base64") for col in image_cols},
        }
        for i in range(len(df))
    ]
    # NOTE(review): the diff hunk's context ends at the comprehension; the
    # unchanged tail is assumed to be this return — confirm against the file.
    return multimodal_data


# --- lotus/utils.py: fetch_document (new, post-patch) ---
def fetch_document(document: Document | str, doc_type: str = "Document") -> Document | str | None:
    """Load ``document`` and return it in the requested representation.

    Args:
        document: a pymupdf Document, a path/URL string openable by pymupdf,
            or None (propagated unchanged).
        doc_type: "Document" returns the loaded Document; "string" returns the
            metadata plus per-page extracted text as one newline-joined string.

    Raises:
        TypeError: for unsupported input types. (FIX: the original silently
        left ``document_obj`` as None and then crashed with AttributeError in
        the "string" branch, or returned None for "Document".)
    """
    if document is None:
        return None

    if isinstance(document, Document):
        document_obj = document
    elif isinstance(document, str):
        document_obj = Document(document)
    else:
        raise TypeError(f"Unsupported document type: {type(document)}")

    if doc_type == "string":
        # FIX: iterate the pymupdf Document directly and use page.get_text();
        # the original used pypdf's `.pages` / `extract_text()`, which raise
        # AttributeError on pymupdf objects.
        res = str(document_obj.metadata) + "\n"
        for page in document_obj:
            res += page.get_text() + "\n"
        return res
    return document_obj
# --- requirements.txt (post-patch): qdrant-client==1.12.2 kept, PyMuPDF==1.25.2
#     added; consider ending the file with a newline (both versions lack one). ---

# --- tests/test_serialization.py: reconstructed fixture section (post-patch) ---
import pandas as pd
import pytest

from lotus.dtype_extensions import DocumentArray, DocumentDtype


@pytest.fixture
def sample_df():
    """FIX: the patch replaced this fixture with sample_df_with_pdfs, but
    test_df2text_xml_format, test_df2text_nonexistent_columns and
    test_df2text_all_columns still request ``sample_df`` — pytest collection
    fails with a missing-fixture error. Keep both fixtures."""
    return pd.DataFrame({"Name": ["Alice", "Bob"], "Age": [25, 30], "City": ["New York", "London"]})


@pytest.fixture
def sample_df_with_pdfs():
    """One-row frame whose Pdf column is backed by DocumentArray."""
    return pd.DataFrame(
        {
            "Title": ["Propositional Logic"],
            "Pdf": DocumentArray(["tests/assets/cs70-propositional-logic.pdf"]),
        }
    )

# NOTE(review): in test_df2text_xml_format the expected XML literals were
# stripped of their markup in transit (they now read "Alice25" / "Bob30");
# they cannot be reconstructed from this view — restore them from df2text's
# actual XML serialization of the fixture before merging.
#
# New test added by the patch (lives on the existing BaseTest subclass):
#   def test_assert_documentdtype_recognition(self, sample_df_with_pdfs):
#       print(sample_df_with_pdfs)
#       assert isinstance(sample_df_with_pdfs["Pdf"].dtype, DocumentDtype)