Add Document DType for document parsing/management #95

Open · wants to merge 3 commits into base: main
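A minimal usage sketch of the new dtype (illustrative only, not part of the diff; it relies on the sample PDF added under tests/assets/ in this PR and on the accessors introduced below):

```python
import pandas as pd

from lotus.dtype_extensions import DocumentArray, DocumentDtype

# Build a DataFrame with a document-typed column from a local PDF path.
df = pd.DataFrame(
    {
        "Title": ["Propositional Logic"],
        "Pdf": DocumentArray(["tests/assets/cs70-propositional-logic.pdf"]),
    }
)

assert isinstance(df["Pdf"].dtype, DocumentDtype)

# Lazily parse the PDF and return its metadata plus page text as one string.
text = df["Pdf"].array.get_document(0, "string")
```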
7 changes: 6 additions & 1 deletion lotus/dtype_extensions/__init__.py
@@ -1,7 +1,10 @@
from lotus.dtype_extensions.image import ImageDtype, ImageArray
from lotus.dtype_extensions.document import DocumentDtype, DocumentArray

import pandas as pd

pd.api.extensions.register_extension_dtype(ImageDtype)
pd.api.extensions.register_extension_dtype(DocumentDtype)


def convert_to_base_data(data: pd.Series | list) -> list:
@@ -13,9 +16,11 @@ def convert_to_base_data(data: pd.Series | list) -> list:
if isinstance(data, pd.Series):
if isinstance(data.dtype, ImageDtype):
return [data.array.get_image(i) for i in range(len(data))]
elif isinstance(data.dtype, DocumentDtype):
return [data.array.get_document(i) for i in range(len(data))]
return data.tolist()

return data


__all__ = ["ImageDtype", "ImageArray", "convert_to_base_data"]
__all__ = ["ImageDtype", "ImageArray", "DocumentDtype", "DocumentArray", "convert_to_base_data"]
174 changes: 174 additions & 0 deletions lotus/dtype_extensions/document.py
@@ -0,0 +1,174 @@
import sys
from typing import Sequence, Union

import numpy as np
import pandas as pd
import pymupdf
from pandas.api.extensions import ExtensionArray, ExtensionDtype
from pymupdf import Document

from lotus.utils import fetch_document


class DocumentDtype(ExtensionDtype):
name = "document"
type = Document
na_value = None

@classmethod
def construct_array_type(cls):
return DocumentArray


class DocumentArray(ExtensionArray):
def __init__(self, values):
self._data = np.asarray(values, dtype=object)
self._dtype = DocumentDtype()
self.allowed_document_types = ["Document", "string"]
self._cached_documents: dict[tuple[int, str], str | Document | None] = {} # Cache for loaded documents

def __getitem__(self, item: int | slice | Sequence[int]) -> np.ndarray:
result = self._data[item]

if isinstance(item, (int, np.integer)):
# Return the raw value for display purposes
return result

return DocumentArray(result)

def __setitem__(self, key, value) -> None:
"""Set one or more values inplace, with cache invalidation."""
if isinstance(key, np.ndarray):
if key.dtype == bool:
key = np.where(key)[0]
key = key.tolist()
if isinstance(key, (int, np.integer)):
key = [key]
if isinstance(key, slice):
key = range(*key.indices(len(self)))
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
for idx, val in zip(key, value):
self._data[idx] = val
self._invalidate_cache(idx)
else:
for idx in key:
self._data[idx] = value
self._invalidate_cache(idx)

def _invalidate_cache(self, idx: int) -> None:
"""Remove an item from the cache."""
for doc_type in self.allowed_document_types:
if (idx, doc_type) in self._cached_documents:
del self._cached_documents[(idx, doc_type)]

def get_document(self, idx: int, doc_type: str = "Document") -> Union[Document, str, None]:
"""Explicit method to fetch and return the actual document"""
if (idx, doc_type) not in self._cached_documents:
document_result = fetch_document(self._data[idx], doc_type)
assert document_result is None or isinstance(document_result, (Document, str))
self._cached_documents[(idx, doc_type)] = document_result
return self._cached_documents[(idx, doc_type)]

def isna(self) -> np.ndarray:
return pd.isna(self._data)

    def take(self, indices: Sequence[int], allow_fill: bool = False, fill_value=None) -> "DocumentArray":
        indices = np.asarray(indices)
        result = self._data.take(indices, axis=0)
        if allow_fill and fill_value is not None:
            result[indices == -1] = fill_value
        return DocumentArray(result)

def copy(self) -> "DocumentArray":
new_array = DocumentArray(self._data.copy())
new_array._cached_documents = self._cached_documents.copy()
return new_array

    @classmethod
    def _concat_same_type(cls, to_concat: Sequence["DocumentArray"]) -> "DocumentArray":
"""
Concatenate multiple DocumentArray instances into a single one.

Args:
to_concat (Sequence[DocumentArray]): A sequence of DocumentArray instances to concatenate.

Returns:
DocumentArray: A new DocumentArray containing all elements from the input arrays.
"""
combined_data = np.concatenate([arr._data for arr in to_concat])
return cls._from_sequence(combined_data)

@classmethod
def _from_sequence(cls, scalars, dtype=None, copy=False):
if copy:
scalars = np.array(scalars, dtype=object, copy=True)
return cls(scalars)

def __len__(self) -> int:
return len(self._data)

def __eq__(self, other) -> np.ndarray: # type: ignore
if isinstance(other, DocumentArray):
return np.array(
[_compare_documents(doc1, doc2) for doc1, doc2 in zip(self._data, other._data)],
dtype=bool,
)

if hasattr(other, "__iter__") and not isinstance(other, str):
if len(other) != len(self):
return np.repeat(False, len(self))
return np.array(
[_compare_documents(doc1, doc2) for doc1, doc2 in zip(self._data, other)],
dtype=bool,
)
return np.array([_compare_documents(doc, other) for doc in self._data], dtype=bool)

@property
def dtype(self) -> DocumentDtype:
return self._dtype

@property
def nbytes(self) -> int:
return sum(sys.getsizeof(doc) for doc in self._data if doc)

def __repr__(self) -> str:
return f"DocumentArray([{', '.join([f'<Document: {doc.metadata}>' if isinstance(doc, Document) else f'<Document: {doc}>' for doc in self._data[:5]])}, ...])"

def _formatter(self, boxed: bool = False):
        return lambda x: (f"<Document: {x.metadata}>" if isinstance(x, Document) else f"<Document: {x}>")

def to_numpy(self, dtype=None, copy=False, na_value=None) -> np.ndarray:
"""Convert the DocumentArray to a numpy array."""
documents = []
        for doc_data in self._data:
            if isinstance(doc_data, Document):
                text = str(doc_data.metadata) + "\n"
                for page in doc_data:
                    text += page.get_text() + "\n"
                documents.append(text)
            elif isinstance(doc_data, str):
                doc = pymupdf.open(doc_data)
                text = str(doc.metadata) + "\n"
                for page in doc:
                    text += page.get_text() + "\n"
                documents.append(text)
            else:
                # Keep the output aligned with the array length for NA entries.
                documents.append(self.dtype.na_value)
result = np.empty(len(self), dtype=object)
result[:] = documents
return result

def __array__(self, dtype=None) -> np.ndarray:
"""Numpy array interface."""
return self.to_numpy(dtype=dtype)


def _compare_documents(doc1, doc2) -> bool:
if doc1 is None or doc2 is None:
return doc1 is doc2

# Only fetch documents when actually comparing
    if isinstance(doc1, Document) and isinstance(doc2, Document):
        if doc1.page_count != doc2.page_count:
            return False
        for doc1_page, doc2_page in zip(doc1, doc2):
            if doc1_page.get_text() != doc2_page.get_text():
                return False
        return doc1.metadata == doc2.metadata
    return doc1 == doc2
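A small illustration of the caching behavior above (a sketch, not part of the diff; the replacement path is hypothetical): repeated get_document calls for the same index and doc_type reuse the cached result, and __setitem__ drops the stale entries for that index.

```python
from lotus.dtype_extensions import DocumentArray

arr = DocumentArray(["tests/assets/cs70-propositional-logic.pdf"])

first = arr.get_document(0, "string")   # parses the PDF and caches the text
second = arr.get_document(0, "string")  # served from the (0, "string") cache entry
assert first is second

arr[0] = "tests/assets/another-document.pdf"  # hypothetical path; invalidates index 0
```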
10 changes: 7 additions & 3 deletions lotus/templates/task_instructions.py
@@ -4,7 +4,7 @@
import pandas as pd

import lotus
from lotus.dtype_extensions import ImageDtype
from lotus.dtype_extensions import DocumentDtype, ImageDtype
from lotus.types import SerializationFormat


@@ -312,11 +312,15 @@ def df2multimodal_info(df: pd.DataFrame, cols: list[str]) -> list[dict[str, Any]]:
Return a list of dictionaries, each containing text and image data.
"""
image_cols = [col for col in cols if isinstance(df[col].dtype, ImageDtype)]
text_cols = [col for col in cols if col not in image_cols]
document_cols = [col for col in cols if isinstance(df[col].dtype, DocumentDtype)]
text_cols = [col for col in cols if col not in image_cols and col not in document_cols]
text_rows = df2text(df, text_cols)
multimodal_data = [
{
"text": text_rows[i],
"text": {
"text": text_rows[i],
**{col.capitalize(): df[col].array.get_document(i, "string") for col in document_cols},
},
"image": {col.capitalize(): df[col].array.get_image(i, "base64") for col in image_cols},
}
for i in range(len(df))
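For context, the nesting above means each record handed to the model now looks roughly like this (values illustrative, with no image columns present):

```python
record = {
    "text": {
        "text": "[Title]: «Propositional Logic»\n",
        "Pdf": "<document metadata followed by extracted page text>",
    },
    "image": {},
}
```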
23 changes: 20 additions & 3 deletions lotus/utils.py
@@ -7,6 +7,7 @@
import pandas as pd
import requests # type: ignore
from PIL import Image
from pymupdf import Document

import lotus

@@ -29,8 +30,6 @@ def ret(
verbose: bool = False,
method: str = "kmeans",
) -> list[int]:


import faiss

"""Cluster by column, and return a series in the dataframe with cluster-ids"""
@@ -64,7 +63,7 @@

# get nearest centroid to each vector
scores, indices = kmeans.index.search(vec_set, 1)

# get the cluster centroids
# centroids = kmeans.centroids
# return indices.flatten(), scores.flatten(), centroids
@@ -73,6 +72,24 @@
return ret


def fetch_document(document: Document | str, doc_type: str = "Document") -> Document | str | None:
if document is None:
return None

document_obj = None
if isinstance(document, Document):
document_obj = document
elif isinstance(document, str):
document_obj = Document(document)

    if doc_type == "string":
        res = str(document_obj.metadata) + "\n"
        for page in document_obj:
            res += page.get_text() + "\n"
        return res
return document_obj


def fetch_image(image: str | np.ndarray | Image.Image | None, image_type: str = "Image") -> Image.Image | str | None:
if image is None:
return None
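A usage sketch for the helper above (the path is illustrative): with doc_type="string" it returns the document metadata followed by the text of every page; otherwise it returns the opened pymupdf.Document.

```python
from lotus.utils import fetch_document

doc = fetch_document("tests/assets/cs70-propositional-logic.pdf")             # pymupdf.Document
text = fetch_document("tests/assets/cs70-propositional-logic.pdf", "string")  # metadata + page text
```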
3 changes: 2 additions & 1 deletion requirements.txt
@@ -9,4 +9,5 @@ tqdm==4.66.4
weaviate-client==4.10.2
pinecone==5.4.2
chromadb==0.6.2
qdrant-client==1.12.2
PyMuPDF==1.25.2
Binary file added tests/assets/cs70-propositional-logic.pdf
Binary file not shown.
19 changes: 16 additions & 3 deletions tests/test_serialization.py
@@ -2,14 +2,20 @@
import pytest

import lotus
from lotus.dtype_extensions import DocumentArray, DocumentDtype
from lotus.settings import SerializationFormat
from lotus.templates.task_instructions import df2text
from tests.base_test import BaseTest


@pytest.fixture
def sample_df():
return pd.DataFrame({"Name": ["Alice", "Bob"], "Age": [25, 30], "City": ["New York", "London"]})
def sample_df_with_pdfs():
return pd.DataFrame(
{
"Title": ["Propositional Logic"],
"Pdf": DocumentArray(["tests/assets/cs70-propositional-logic.pdf"]),
}
)


@pytest.fixture(autouse=True)
Expand All @@ -34,7 +40,10 @@ def test_df2text_xml_format(self, sample_df):
lotus.settings.serialization_format = SerializationFormat.XML
result = df2text(sample_df, ["Name", "Age"])
print(result)
expected = ["<row><Name>Alice</Name><Age>25</Age></row>", "<row><Name>Bob</Name><Age>30</Age></row>"]
expected = [
"<row><Name>Alice</Name><Age>25</Age></row>",
"<row><Name>Bob</Name><Age>30</Age></row>",
]
assert result == expected

def test_df2text_nonexistent_columns(self, sample_df):
@@ -53,3 +62,7 @@ def test_df2text_all_columns(self, sample_df):
"[Name]: «Bob»\n[Age]: «30»\n[City]: «London»\n",
]
assert result == expected

def test_assert_documentdtype_recognition(self, sample_df_with_pdfs):
print(sample_df_with_pdfs)
assert isinstance(sample_df_with_pdfs["Pdf"].dtype, DocumentDtype)
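A possible follow-up test (a sketch, not included in this PR) that exercises the string extraction path end to end with the same fixture:

```python
def test_get_document_string_extraction(sample_df_with_pdfs):
    # The parsed text for the sample asset is expected to be a non-empty string.
    text = sample_df_with_pdfs["Pdf"].array.get_document(0, "string")
    assert isinstance(text, str)
    assert len(text) > 0
```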