Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NastyBoget committed Dec 17, 2024
1 parent 549dd76 commit 2396e71
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 35 deletions.
24 changes: 10 additions & 14 deletions dedoc/api/api_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, Iterator, List, Optional, Set

from dedoc.api.schema import LineMetadata, ParsedDocument, Table, TreeNode
from dedoc.data_structures.concrete_annotations.attach_annotation import AttachAnnotation
from dedoc.data_structures.concrete_annotations.bold_annotation import BoldAnnotation
from dedoc.data_structures.concrete_annotations.italic_annotation import ItalicAnnotation
Expand All @@ -10,10 +11,6 @@
from dedoc.data_structures.concrete_annotations.table_annotation import TableAnnotation
from dedoc.data_structures.concrete_annotations.underlined_annotation import UnderlinedAnnotation
from dedoc.data_structures.hierarchy_level import HierarchyLevel
from dedoc.data_structures.line_metadata import LineMetadata
from dedoc.data_structures.parsed_document import ParsedDocument
from dedoc.data_structures.table import Table
from dedoc.data_structures.tree_node import TreeNode
from dedoc.extensions import converted_mimes, recognized_mimes


Expand All @@ -39,7 +36,7 @@ def _node2tree(paragraph: TreeNode, depth: int, depths: Set[int] = None) -> str:
space = "".join(space)
node_result = []

node_result.append(f" {space} {paragraph.metadata.hierarchy_level.line_type}&nbsp{paragraph.node_id} ")
node_result.append(f" {space} {paragraph.metadata.paragraph_type}&nbsp{paragraph.node_id} ")
for text in __prettify_text(paragraph.text):
space = [space_symbol] * 4 * (depth - 1) + 4 * [space_symbol]
space = "".join(space)
Expand Down Expand Up @@ -98,7 +95,7 @@ def json2tree(paragraph: TreeNode) -> str:
depths = {d for d in depths if d <= depth}
space = [space_symbol] * 4 * (depth - 1) + 4 * ["-"]
space = __add_vertical_line(depths, space)
node_result.append(f"<p> <tt> <em> {space} {node.metadata.hierarchy_level.line_type}&nbsp{node.node_id} </em> </tt> </p>")
node_result.append(f"<p> <tt> <em> {space} {node.metadata.paragraph_type}&nbsp{node.node_id} </em> </tt> </p>")
for text in __prettify_text(node.text):
space = [space_symbol] * 4 * (depth - 1) + 4 * [space_symbol]
space = __add_vertical_line(depths, space)
Expand Down Expand Up @@ -136,14 +133,14 @@ def json2html(text: str,

ptext = __annotations2html(paragraph=paragraph, table2id=table2id, attach2id=attach2id, tabs=tabs)

if paragraph.metadata.hierarchy_level.line_type in [HierarchyLevel.header, HierarchyLevel.root]:
if paragraph.metadata.paragraph_type in [HierarchyLevel.header, HierarchyLevel.root]:
ptext = f"<strong>{ptext.strip()}</strong>"
elif paragraph.metadata.hierarchy_level.line_type == HierarchyLevel.list_item:
elif paragraph.metadata.paragraph_type == HierarchyLevel.list_item:
ptext = f"<em>{ptext.strip()}</em>"
else:
ptext = ptext.strip()

ptext = f'<p> {"&nbsp;" * tabs} {ptext} <sub> id = {paragraph.node_id} ; type = {paragraph.metadata.hierarchy_level.line_type} </sub></p>'
ptext = f'<p> {"&nbsp;" * tabs} {ptext} <sub> id = {paragraph.node_id} ; type = {paragraph.metadata.paragraph_type} </sub></p>'
if hasattr(paragraph.metadata, "uid"):
ptext = f'<div id="{paragraph.metadata.uid}">{ptext}</div>'
text += ptext
Expand Down Expand Up @@ -259,11 +256,10 @@ def table2html(table: Table, table2id: Dict[str, int]) -> str:
text += ' style="display: none" '
cell_node = TreeNode(
node_id="0",
text=cell.get_text(),
annotations=cell.get_annotations(),
metadata=LineMetadata(page_id=table.metadata.page_id, line_id=0),
subparagraphs=[],
parent=None
text="\n".join([line.text for line in cell.lines]),
annotations=cell.lines[0].annotations if cell.lines else [],
metadata=LineMetadata(page_id=0, line_id=0, paragraph_type=HierarchyLevel.raw_text),
subparagraphs=[]
)
text += f' colspan="{cell.colspan}" rowspan="{cell.rowspan}">{__annotations2html(cell_node, {}, {})}</td>\n'

Expand Down
19 changes: 12 additions & 7 deletions dedoc/api/dedoc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import importlib
import json
import os
import traceback
import tempfile
from typing import Optional

from fastapi import Depends, FastAPI, File, Request, Response, UploadFile
Expand All @@ -18,6 +18,7 @@
from dedoc.common.exceptions.dedoc_error import DedocError
from dedoc.common.exceptions.missing_file_error import MissingFileError
from dedoc.config import get_config
from dedoc.utils.utils import save_upload_file

config = get_config()
logger = config["logger"]
Expand Down Expand Up @@ -64,7 +65,10 @@ async def upload(request: Request, file: UploadFile = File(...), query_params: Q
if not file or file.filename == "":
raise MissingFileError("Error: Missing content in request_post file parameter", version=dedoc.version.__version__)

document_tree = await process_handler.handle(request=request, parameters=parameters, file=file)
with tempfile.TemporaryDirectory() as tmpdir:
file_path = save_upload_file(file, tmpdir)
document_tree = await process_handler.handle(request=request, parameters=parameters, file_path=file_path, tmpdir=tmpdir)

if document_tree is None:
return JSONResponse(status_code=499, content={})

Expand All @@ -88,24 +92,25 @@ async def upload(request: Request, file: UploadFile = File(...), query_params: Q
return HTMLResponse(content=html_content)

if return_format == "ujson":
return UJSONResponse(content=document_tree.to_api_schema().model_dump())
return UJSONResponse(content=document_tree.model_dump())

if return_format == "collapsed_tree":
html_content = json2collapsed_tree(paragraph=document_tree.content.structure)
return HTMLResponse(content=html_content)

if return_format == "pretty_json":
return PlainTextResponse(content=json.dumps(document_tree.to_api_schema().model_dump(), ensure_ascii=False, indent=2))
return PlainTextResponse(content=json.dumps(document_tree.model_dump(), ensure_ascii=False, indent=2))

logger.info(f"Send result. File {file.filename} with parameters {parameters}")
return ORJSONResponse(content=document_tree.to_api_schema().model_dump())
return ORJSONResponse(content=document_tree.model_dump())


@app.get("/upload_example")
async def upload_example(request: Request, file_name: str, return_format: Optional[str] = None) -> Response:
file_path = os.path.join(static_path, "examples", file_name)
parameters = {} if return_format is None else {"return_format": return_format}
document_tree = await process_handler.handle(request=request, parameters=parameters, file=file_path)
with tempfile.TemporaryDirectory() as tmpdir:
document_tree = await process_handler.handle(request=request, parameters=parameters, file_path=file_path, tmpdir=tmpdir)

if return_format == "html":
html_page = json2html(
Expand All @@ -116,7 +121,7 @@ async def upload_example(request: Request, file_name: str, return_format: Option
tabs=0
)
return HTMLResponse(content=html_page)
return ORJSONResponse(content=document_tree.to_api_schema().model_dump(), status_code=200)
return ORJSONResponse(content=document_tree.model_dump(), status_code=200)


@app.exception_handler(DedocError)
Expand Down
23 changes: 9 additions & 14 deletions dedoc/api/process_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,18 @@
import os
import pickle
import signal
import tempfile
import traceback
from multiprocessing import Process, Queue
from typing import Optional, Union
from typing import Optional
from urllib.request import Request

from anyio import get_cancelled_exc_class
from fastapi import UploadFile

from dedoc import DedocManager
from dedoc.api.cancellation import cancel_on_disconnect
from dedoc.api.schema import ParsedDocument
from dedoc.common.exceptions.dedoc_error import DedocError
from dedoc.config import get_config
from dedoc.data_structures import ParsedDocument
from dedoc.utils.utils import save_upload_file


class ProcessHandler:
Expand All @@ -41,7 +38,7 @@ def __init__(self, logger: logging.Logger) -> None:
self.process = Process(target=self.__parse_file, args=[self.input_queue, self.output_queue])
self.process.start()

async def handle(self, request: Request, parameters: dict, file: Union[UploadFile, str]) -> Optional[ParsedDocument]:
async def handle(self, request: Request, parameters: dict, file_path: str, tmpdir: str) -> Optional[ParsedDocument]:
"""
Handle request in a separate process.
Checks for client disconnection and terminate the child process if client disconnected.
Expand All @@ -50,7 +47,7 @@ async def handle(self, request: Request, parameters: dict, file: Union[UploadFil
self.__init__(logger=self.logger)

self.logger.info("Putting file to the input queue")
self.input_queue.put(pickle.dumps((parameters, file)), block=True)
self.input_queue.put(pickle.dumps((parameters, file_path, tmpdir)), block=True)

loop = asyncio.get_running_loop()
async with cancel_on_disconnect(request, self.logger):
Expand Down Expand Up @@ -88,17 +85,15 @@ def __parse_file(self, input_queue: Queue, output_queue: Queue) -> None:

while True:
try:
parameters, file = pickle.loads(input_queue.get(block=True))
parameters, file_path, tmp_dir = pickle.loads(input_queue.get(block=True))
manager.logger.info("Parsing process got task from the input queue")
return_format = str(parameters.get("return_format", "json")).lower()
with tempfile.TemporaryDirectory() as tmpdir:
file_path = file if isinstance(file, str) else save_upload_file(file, tmpdir)
document_tree = manager.parse(file_path, parameters={**dict(parameters), "attachments_dir": tmpdir})
document_tree = manager.parse(file_path, parameters={**dict(parameters), "attachments_dir": tmp_dir})

if return_format == "html":
self.__add_base64_info_to_attachments(document_tree, tmpdir)
if return_format == "html":
self.__add_base64_info_to_attachments(document_tree, tmp_dir)

output_queue.put(pickle.dumps(document_tree), block=True)
output_queue.put(pickle.dumps(document_tree.to_api_schema()), block=True)
manager.logger.info("Parsing process put task to the output queue")
except Exception as e:
tb = traceback.format_exc()
Expand Down

0 comments on commit 2396e71

Please sign in to comment.