Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow custom httpx client #384

Merged
merged 3 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions llama_parse/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import mimetypes
import time
from pathlib import Path
from typing import List, Optional, Union
from typing import AsyncGenerator, List, Optional, Union
from contextlib import asynccontextmanager
from io import BufferedIOBase

from llama_index.core.async_utils import run_jobs
Expand Down Expand Up @@ -141,6 +142,9 @@ class LlamaParse(BasePydanticReader):
default=False,
description="Whether to take screenshot of each page of the document.",
)
custom_client: Optional[httpx.AsyncClient] = Field(
default=None, description="A custom HTTPX client to use for sending requests."
)

@field_validator("api_key", mode="before", check_fields=True)
@classmethod
Expand All @@ -163,6 +167,15 @@ def validate_base_url(cls, v: str) -> str:
url = os.getenv("LLAMA_CLOUD_BASE_URL", None)
return url or v or DEFAULT_BASE_URL

@asynccontextmanager
async def client_context(self) -> AsyncGenerator[httpx.AsyncClient, None]:
    """Yield the HTTPX client that requests should be sent with.

    If ``self.custom_client`` is set, that client is yielded as-is and this
    context does not close it. Otherwise a new ``httpx.AsyncClient`` is
    opened with ``timeout=self.max_timeout`` and closed when the context
    exits.
    """
    if self.custom_client is None:
        # No caller-supplied client: own a short-lived one for this context.
        async with httpx.AsyncClient(timeout=self.max_timeout) as fresh_client:
            yield fresh_client
        return

    # Caller-supplied client: its lifetime stays with the caller.
    yield self.custom_client

# upload a document and get back a job_id
async def _create_job(
self, file_input: FileInput, extra_info: Optional[dict] = None
Expand Down Expand Up @@ -231,7 +244,7 @@ async def _create_job(
data["target_pages"] = self.target_pages

try:
async with httpx.AsyncClient(timeout=self.max_timeout) as client:
async with self.client_context() as client:
response = await client.post(
url,
files=files,
Expand All @@ -257,7 +270,7 @@ async def _get_job_result(
tries = 0
while True:
await asyncio.sleep(self.check_interval)
async with httpx.AsyncClient(timeout=self.max_timeout) as client:
async with self.client_context() as client:
tries += 1

result = await client.get(status_url, headers=headers)
Expand Down Expand Up @@ -447,7 +460,9 @@ def get_json_result(
else:
raise e

def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
async def aget_images(
self, json_result: List[dict], download_path: str
) -> List[dict]:
"""Download images from the parsed result."""
headers = {"Authorization": f"Bearer {self.api_key}"}

Expand Down Expand Up @@ -481,11 +496,12 @@ def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
image["page_number"] = page["page"]
with open(image_path, "wb") as f:
image_url = f"{self.base_url}/api/parsing/job/{job_id}/result/image/{image_name}"
f.write(
httpx.get(
async with self.client_context() as client:
res = await client.get(
image_url, headers=headers, timeout=self.max_timeout
).content
)
)
res.raise_for_status()
f.write(res.content)
images.append(image)
return images
except Exception as e:
Expand All @@ -495,6 +511,16 @@ def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
else:
raise e

def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
    """Download images from the parsed result (synchronous wrapper).

    Runs ``aget_images`` to completion with ``asyncio.run`` and returns
    its result.

    Args:
        json_result: Parsed JSON result, passed through to ``aget_images``.
        download_path: Download location, passed through to ``aget_images``.

    Returns:
        The value returned by ``aget_images``.

    Raises:
        RuntimeError: With ``nest_asyncio_msg`` when the underlying error
            text contains ``nest_asyncio_err``; the underlying error is
            attached as ``__cause__``. Any other ``RuntimeError`` is
            re-raised unchanged.
    """
    try:
        return asyncio.run(self.aget_images(json_result, download_path))
    except RuntimeError as e:
        if nest_asyncio_err in str(e):
            # Chain explicitly so the original error stays visible as the cause.
            raise RuntimeError(nest_asyncio_msg) from e
        # Bare raise re-raises the active exception with its traceback intact.
        raise

def _get_sub_docs(self, docs: List[Document]) -> List[Document]:
"""Split docs into pages, by separator."""
sub_docs = []
Expand Down
16 changes: 16 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pytest
from httpx import AsyncClient
from llama_parse import LlamaParse


Expand Down Expand Up @@ -93,3 +94,18 @@ def test_simple_page_progress_workers() -> None:
result = parser.load_data([filepath, filepath])
assert len(result) == 2
assert len(result[0].text) > 0


@pytest.mark.skipif(
    not os.environ.get("LLAMA_CLOUD_API_KEY", ""),
    reason="LLAMA_CLOUD_API_KEY not set",
)
def test_custom_client() -> None:
    """Parse a PDF with a caller-supplied ``AsyncClient`` and check the result."""
    tests_dir = os.path.dirname(__file__)
    pdf_path = os.path.join(tests_dir, "test_files/attention_is_all_you_need.pdf")

    # Client built with verify=False and timeout=10, handed to the parser
    # through the custom_client field.
    client = AsyncClient(verify=False, timeout=10)
    parser = LlamaParse(result_type="markdown", custom_client=client)

    documents = parser.load_data(pdf_path)
    assert len(documents) == 1
    assert len(documents[0].text) > 0
Loading