Skip to content

Commit

Permalink
Stoli/feat/connection handling (#568)
Browse files Browse the repository at this point in the history
  • Loading branch information
apostolos-geyer authored Dec 27, 2024
1 parent 6338641 commit 530241d
Showing 1 changed file with 91 additions and 82 deletions.
173 changes: 91 additions & 82 deletions llama_parse/base.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
import os
import asyncio
from urllib.parse import urlparse

import httpx
import mimetypes
import os
import time
from pathlib import Path, PurePath, PurePosixPath
from typing import AsyncGenerator, Any, Dict, List, Optional, Union
from contextlib import asynccontextmanager
from copy import deepcopy
from io import BufferedIOBase
from pathlib import Path, PurePath, PurePosixPath
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from urllib.parse import urlparse

import httpx
from fsspec import AbstractFileSystem
from llama_index.core.async_utils import asyncio_run, run_jobs
from llama_index.core.bridge.pydantic import Field, field_validator
from llama_index.core.bridge.pydantic import Field, PrivateAttr, field_validator
from llama_index.core.constants import DEFAULT_BASE_URL
from llama_index.core.readers.base import BasePydanticReader
from llama_index.core.readers.file.base import get_default_fs
from llama_index.core.schema import Document

from llama_parse.utils import (
SUPPORTED_FILE_TYPES,
ResultType,
nest_asyncio_err,
nest_asyncio_msg,
ResultType,
SUPPORTED_FILE_TYPES,
)
from copy import deepcopy

# can put in a path to the file or the file bytes itself
# if passing as bytes or a buffer, must provide the file_name in extra_info
Expand All @@ -32,6 +32,11 @@
_DEFAULT_SEPARATOR = "\n---\n"


# Route templates for the LlamaCloud parsing API. These are paths relative
# to the client's configured base_url and are filled in with str.format()
# before each request (see `aclient`, which sets base_url on the client).
JOB_RESULT_URL = "/api/parsing/job/{job_id}/result/{result_type}"  # fetch a finished job's parsed output
JOB_STATUS_ROUTE = "/api/parsing/job/{job_id}"  # poll a job's status
JOB_UPLOAD_ROUTE = "/api/parsing/upload"  # create a new parsing job


class LlamaParse(BasePydanticReader):
"""A smart-parser for files."""

Expand All @@ -49,9 +54,11 @@ class LlamaParse(BasePydanticReader):
default=1,
description="The interval in seconds to check if the parsing is done.",
)

custom_client: Optional[httpx.AsyncClient] = Field(
default=None, description="A custom HTTPX client to use for sending requests."
)

ignore_errors: bool = Field(
default=True,
description="Whether or not to ignore and skip errors raised during parsing.",
Expand Down Expand Up @@ -303,6 +310,25 @@ def validate_base_url(cls, v: str) -> str:
url = os.getenv("LLAMA_CLOUD_BASE_URL", None)
return url or v or DEFAULT_BASE_URL

_aclient: Union[httpx.AsyncClient, None] = PrivateAttr(default=None, init=False)

@property
def aclient(self) -> httpx.AsyncClient:
    """Return the shared async HTTP client, creating it on first access.

    Uses ``custom_client`` when one was supplied, otherwise lazily builds
    a default ``httpx.AsyncClient``. The cached client is reused on every
    subsequent access.
    """
    client = self._aclient
    if not client:
        client = self.custom_client or httpx.AsyncClient()
        self._aclient = client

    # Re-apply connection settings on every access (not just at creation)
    # so that later mutation of base_url, api_key, or max_timeout on the
    # reader is reflected on the cached client. Unusual, but if a caller
    # does change them, a stale client would be very confusing.
    client.base_url = self.base_url
    client.headers["Authorization"] = f"Bearer {self.api_key}"
    client.timeout = self.max_timeout

    return client

@asynccontextmanager
async def client_context(self) -> AsyncGenerator[httpx.AsyncClient, None]:
"""Create a context for the HTTPX client."""
Expand Down Expand Up @@ -351,8 +377,6 @@ async def _create_job(
extra_info: Optional[dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> str:
headers = {"Authorization": f"Bearer {self.api_key}"}
url = f"{self.base_url}/api/parsing/upload"
files = None
file_handle = None
input_url = file_input if self._is_input_url(file_input) else None
Expand Down Expand Up @@ -570,70 +594,64 @@ async def _create_job(
data["gpt4o_api_key"] = self.gpt4o_api_key

try:
async with self.client_context() as client:
response = await client.post(
url,
files=files,
headers=headers,
data=data,
)
if not response.is_success:
raise Exception(f"Failed to parse the file: {response.text}")
job_id = response.json()["id"]
return job_id
resp = await self.aclient.post(JOB_UPLOAD_ROUTE, files=files, data=data) # type: ignore
resp.raise_for_status() # this raises if status is not 2xx
return resp.json()["id"]
except httpx.HTTPStatusError as err: # this catches it
msg = f"Failed to parse the file: {err.response.text}"
raise Exception(msg) from err # this preserves the exception context
finally:
if file_handle is not None:
file_handle.close()

async def _get_job_result(
self, job_id: str, result_type: str, verbose: bool = False
) -> Dict[str, Any]:
result_url = f"{self.base_url}/api/parsing/job/{job_id}/result/{result_type}"
status_url = f"{self.base_url}/api/parsing/job/{job_id}"
headers = {"Authorization": f"Bearer {self.api_key}"}

start = time.time()
tries = 0

# so we're not re-setting the headers & stuff on each
# usage... assume that there is not some other
# coro also modifying base_url and the other client related configs.
client = self.aclient
while True:
await asyncio.sleep(self.check_interval)
async with self.client_context() as client:
tries += 1

result = await client.get(status_url, headers=headers)

if result.status_code != 200:
end = time.time()
if end - start > self.max_timeout:
raise Exception(f"Timeout while parsing the file: {job_id}")
if verbose and tries % 10 == 0:
print(".", end="", flush=True)

await asyncio.sleep(self.check_interval)

continue

# Allowed values "PENDING", "SUCCESS", "ERROR", "CANCELED"
result_json = result.json()
status = result_json["status"]
if status == "SUCCESS":
parsed_result = await client.get(result_url, headers=headers)
return parsed_result.json()
elif status == "PENDING":
end = time.time()
if end - start > self.max_timeout:
raise Exception(f"Timeout while parsing the file: {job_id}")
if verbose and tries % 10 == 0:
print(".", end="", flush=True)

await asyncio.sleep(self.check_interval)
else:
error_code = result_json.get("error_code", "No error code found")
error_message = result_json.get(
"error_message", "No error message found"
)
tries += 1
result = await client.get(JOB_STATUS_ROUTE.format(job_id=job_id))
if result.status_code != 200:
end = time.time()
if end - start > self.max_timeout:
raise Exception(f"Timeout while parsing the file: {job_id}")
if verbose and tries % 10 == 0:
print(".", end="", flush=True)
await asyncio.sleep(self.check_interval)
continue

# Allowed values "PENDING", "SUCCESS", "ERROR", "CANCELED"
result_json = result.json()
status = result_json["status"]
if status == "SUCCESS":
parsed_result = await client.get(
JOB_RESULT_URL.format(job_id=job_id, result_type=result_type),
)
return parsed_result.json()

exception_str = f"Job ID: {job_id} failed with status: {status}, Error code: {error_code}, Error message: {error_message}"
raise Exception(exception_str)
elif status == "PENDING":
end = time.time()
if end - start > self.max_timeout:
raise Exception(f"Timeout while parsing the file: {job_id}")
if verbose and tries % 10 == 0:
print(".", end="", flush=True)
await asyncio.sleep(self.check_interval)

else:
error_code = result_json.get("error_code", "No error code found")
error_message = result_json.get(
"error_message", "No error message found"
)

exception_str = f"Job ID: {job_id} failed with status: {status}, Error code: {error_code}, Error message: {error_message}"
raise Exception(exception_str)

async def _aload_data(
self,
Expand Down Expand Up @@ -798,12 +816,11 @@ async def aget_assets(
self, json_result: List[dict], download_path: str, asset_key: str
) -> List[dict]:
"""Download assets (images or charts) from the parsed result."""
headers = {"Authorization": f"Bearer {self.api_key}"}

# Make the download path
if not os.path.exists(download_path):
os.makedirs(download_path)

client = self.aclient
try:
assets = []
for result in json_result:
Expand All @@ -828,18 +845,14 @@ async def aget_assets(

asset["path"] = asset_path
asset["job_id"] = job_id

asset["original_file_path"] = result.get("file_path", None)

asset["page_number"] = page["page"]

with open(asset_path, "wb") as f:
asset_url = f"{self.base_url}/api/parsing/job/{job_id}/result/image/{asset_name}"
async with self.client_context() as client:
res = await client.get(
asset_url, headers=headers, timeout=self.max_timeout
)
res.raise_for_status()
f.write(res.content)
resp = await client.get(asset_url)
resp.raise_for_status()
f.write(resp.content)
assets.append(asset)
return assets
except Exception as e:
Expand Down Expand Up @@ -899,11 +912,10 @@ async def aget_xlsx(
self, json_result: List[dict], download_path: str
) -> List[dict]:
"""Download xlsx from the parsed result."""
headers = {"Authorization": f"Bearer {self.api_key}"}

# make the download path
if not os.path.exists(download_path):
os.makedirs(download_path)
client = self.aclient
try:
xlsx_list = []
for result in json_result:
Expand All @@ -923,12 +935,9 @@ async def aget_xlsx(
xlsx_url = (
f"{self.base_url}/api/parsing/job/{job_id}/result/raw/xlsx"
)
async with self.client_context() as client:
res = await client.get(
xlsx_url, headers=headers, timeout=self.max_timeout
)
res.raise_for_status()
f.write(res.content)
res = await client.get(xlsx_url)
res.raise_for_status()
f.write(res.content)
xlsx_list.append(xlsx)
return xlsx_list

Expand Down

0 comments on commit 530241d

Please sign in to comment.