diff --git a/README.md b/README.md index 049c1c6f..54d63029 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,8 @@ pip install streamlit surya_gui ``` +Pass the `--math` command line argument to use the math detection model instead of the default model. This will detect math better, but will be worse at everything else. + ## OCR (text recognition) You can OCR text in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page. @@ -81,6 +83,7 @@ The `results.json` file will contain a json dictionary where the keys are the in - `text_lines` - the detected text and bounding boxes for each line - `text` - the text in the line + - `confidence` - the confidence of the model in the detected text (0-1) - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. - `languages` - the languages specified for the page @@ -120,12 +123,14 @@ surya_detect DATA_PATH --images - `--images` will save images of the pages and detected text lines (optional) - `--max` specifies the maximum number of pages to process if you don't want to process everything - `--results_dir` specifies the directory to save results to instead of the default +- `--math` uses a specialized math detection model instead of the default model. This will be better at math, but worse at everything else. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: - `bboxes` - detected bounding boxes for text - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. + - `confidence` - the confidence of the model in the detected text (0-1) - `vertical_lines` - vertical lines detected in the document - `bbox` - the axis-aligned line coordinates. 
- `horizontal_lines` - horizontal lines detected in the document diff --git a/benchmark/detection.py b/benchmark/detection.py index a75432e2..69bfcb53 100644 --- a/benchmark/detection.py +++ b/benchmark/detection.py @@ -8,7 +8,7 @@ from surya.benchmark.tesseract import tesseract_parallel from surya.model.detection.segformer import load_model, load_processor from surya.input.processing import open_pdf, get_page_images -from surya.detection import batch_detection +from surya.detection import batch_text_detection from surya.postprocessing.heatmap import draw_polys_on_image from surya.postprocessing.util import rescale_bbox from surya.settings import settings @@ -54,7 +54,7 @@ def main(): correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes]) start = time.time() - predictions = batch_detection(images, model, processor) + predictions = batch_text_detection(images, model, processor) surya_time = time.time() - start start = time.time() diff --git a/benchmark/gcloud_label.py b/benchmark/gcloud_label.py new file mode 100644 index 00000000..5c9012df --- /dev/null +++ b/benchmark/gcloud_label.py @@ -0,0 +1,149 @@ +import argparse +import json +from collections import defaultdict + +import datasets +from surya.settings import settings +from google.cloud import vision +import hashlib +import os +from tqdm import tqdm +import io + +DATA_DIR = os.path.join(settings.BASE_DIR, settings.DATA_DIR) +RESULT_DIR = os.path.join(settings.BASE_DIR, settings.RESULT_DIR) + +rtl_langs = ["ar", "fa", "he", "ur", "ps", "sd", "yi", "ug"] + +def polygon_to_bbox(polygon): + x = [vertex["x"] for vertex in polygon["vertices"]] + y = [vertex["y"] for vertex in polygon["vertices"]] + return (min(x), min(y), max(x), max(y)) + + +def text_with_break(text, property, is_rtl=False): + break_type = None + prefix = False + if property: + if "detectedBreak" in property: + if "type" in property["detectedBreak"]: + break_type = property["detectedBreak"]["type"] + if "isPrefix" in property["detectedBreak"]: + prefix = property["detectedBreak"]["isPrefix"] + break_char = "" + if break_type == 1: + break_char = " " + if break_type == 5: + break_char = "\n" + + if is_rtl: + prefix = not prefix + + if prefix: + text = break_char + text + else: + text = text + break_char + return text + + +def bbox_overlap_pct(box1, box2): + x1, y1, x2, y2 = box1 + x3, y3, x4, y4 = box2 + dx = min(x2, x4) - max(x1, x3) + dy = min(y2, y4) - max(y1, y3) + if (dx >= 0) and (dy >= 0): + return dx * dy / ((x2 - x1) * (y2 - y1)) + return 0 + + +def annotate_image(img, client, language, cache_dir): + img_byte_arr = io.BytesIO() + img.save(img_byte_arr, format=img.format) + img_byte_arr = img_byte_arr.getvalue() + + img_hash = hashlib.sha256(img_byte_arr).hexdigest() + cache_path = os.path.join(cache_dir, f"{img_hash}.json") + if os.path.exists(cache_path): + with open(cache_path, "r") as f: + response = json.load(f) + return response + + gc_image = vision.Image(content=img_byte_arr) + context = vision.ImageContext(language_hints=[language]) + response = client.document_text_detection(image=gc_image, image_context=context) + response_json = vision.AnnotateImageResponse.to_json(response) + loaded_response = json.loads(response_json) + with open(cache_path, "w+") as f: + json.dump(loaded_response, f) + return loaded_response + + +def get_line_text(response, lines, is_rtl=False): + document = response["fullTextAnnotation"] + + bounds = [] + for page in document["pages"]: + for block in page["blocks"]: + for paragraph in block["paragraphs"]: + 
for word in paragraph["words"]: + for symbol in word["symbols"]: + bounds.append((symbol["boundingBox"], symbol["text"], symbol.get("property"))) + + bboxes = [(polygon_to_bbox(b[0]), text_with_break(b[1], b[2], is_rtl)) for b in bounds] + line_boxes = defaultdict(list) + for i, bbox in enumerate(bboxes): + max_overlap_pct = 0 + max_overlap_idx = None + for j, line in enumerate(lines): + overlap = bbox_overlap_pct(bbox[0], line) + if overlap > max_overlap_pct: + max_overlap_pct = overlap + max_overlap_idx = j + if max_overlap_idx is not None: + line_boxes[max_overlap_idx].append(bbox) + + ocr_lines = [] + for j, line in enumerate(lines): + ocr_bboxes = sorted(line_boxes[j], key=lambda x: x[0][0]) + if is_rtl: + ocr_bboxes = list(reversed(ocr_bboxes)) + ocr_text = "".join([b[1] for b in ocr_bboxes]) + ocr_lines.append(ocr_text) + + assert len(ocr_lines) == len(lines) + return ocr_lines + + +def main(): + parser = argparse.ArgumentParser(description="Label text in dataset with google cloud vision.") + parser.add_argument("--project_id", type=str, help="Google cloud project id.", required=True) + parser.add_argument("--service_account", type=str, help="Path to service account json.", required=True) + parser.add_argument("--max", type=int, help="Maximum number of pages to label.", default=None) + args = parser.parse_args() + + cache_dir = os.path.join(DATA_DIR, "gcloud_cache") + os.makedirs(cache_dir, exist_ok=True) + + dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split="train") + client = vision.ImageAnnotatorClient.from_service_account_json(args.service_account) + + all_gc_lines = [] + for i in tqdm(range(len(dataset))): + img = dataset[i]["image"] + lines = dataset[i]["bboxes"] + language = dataset[i]["language"] + + response = annotate_image(img, client, language, cache_dir) + ocr_lines = get_line_text(response, lines, is_rtl=language in rtl_langs) + + all_gc_lines.append(ocr_lines) + + if args.max is not None and i >= args.max: + break + + with open(os.path.join(RESULT_DIR, "gcloud_ocr.json"), "w+") as f: + json.dump(all_gc_lines, f) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmark/recognition.py b/benchmark/recognition.py index 8355df9f..c6ce9ca8 100644 --- a/benchmark/recognition.py +++ b/benchmark/recognition.py @@ -148,9 +148,9 @@ def main(): pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png" ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png" pred_text = [l.text for l in pred.text_lines] - pred_image = draw_text_on_image(bbox, pred_text, image.size) + pred_image = draw_text_on_image(bbox, pred_text, image.size, lang) pred_image.save(os.path.join(result_path, pred_image_name)) - ref_image = draw_text_on_image(bbox, ref_text, image.size) + ref_image = draw_text_on_image(bbox, ref_text, image.size, lang) ref_image.save(os.path.join(result_path, ref_image_name)) image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png")) diff --git a/detect_layout.py b/detect_layout.py new file mode 100644 index 00000000..8eeab91e --- /dev/null +++ b/detect_layout.py @@ -0,0 +1,98 @@ +import argparse +import copy +import json +from collections import defaultdict + +from surya.detection import batch_text_detection +from surya.input.load import load_from_folder, load_from_file +from surya.layout import batch_layout_detection +from surya.model.detection.segformer import load_model, load_processor +from surya.postprocessing.heatmap import draw_polys_on_image +from surya.settings import settings +import os + + +def main(): 
+ parser = argparse.ArgumentParser(description="Detect layout of an input file or folder (PDFs or image).") + parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect layout in.") + parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya")) + parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) + parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False) + parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False) + args = parser.parse_args() + + print("Layout detection is currently in beta! There may be issues with the output.") + + model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + det_model = load_model() + det_processor = load_processor() + + if os.path.isdir(args.input_path): + images, names = load_from_folder(args.input_path, args.max) + folder_name = os.path.basename(args.input_path) + else: + images, names = load_from_file(args.input_path, args.max) + folder_name = os.path.basename(args.input_path).split(".")[0] + + line_predictions = batch_text_detection(images, det_model, det_processor) + + layout_predictions = batch_layout_detection(images, model, processor, line_predictions) + result_path = os.path.join(args.results_dir, folder_name) + os.makedirs(result_path, exist_ok=True) + + for idx, (layout_pred, line_pred, name) in enumerate(zip(layout_predictions, line_predictions, names)): + blocks = layout_pred.bboxes + for line in line_pred.vertical_lines: + new_blocks = [] + for block in blocks: + block_modified = False + + if line.bbox[0] > block.bbox[0] and line.bbox[2] < block.bbox[2]: + overlap_pct = (min(line.bbox[3], block.bbox[3]) - max(line.bbox[1], block.bbox[1])) / ( + block.bbox[3] - block.bbox[1]) + if overlap_pct > 0.5: + block1 = copy.deepcopy(block) + block2 = copy.deepcopy(block) + block1.bbox[2] = line.bbox[0] + block2.bbox[0] = line.bbox[2] + new_blocks.append(block1) + new_blocks.append(block2) + block_modified = True + if not block_modified: + new_blocks.append(block) + blocks = new_blocks + layout_pred.bboxes = blocks + + if args.images: + for idx, (image, layout_pred, line_pred, name) in enumerate(zip(images, layout_predictions, line_predictions, names)): + polygons = [p.polygon for p in layout_pred.bboxes] + labels = [p.label for p in layout_pred.bboxes] + bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels) + bbox_image.save(os.path.join(result_path, f"{name}_{idx}_layout.png")) + + if args.debug: + heatmap = layout_pred.segmentation_map + heatmap.save(os.path.join(result_path, f"{name}_{idx}_segmentation.png")) + + predictions_by_page = defaultdict(list) + for idx, (pred, name, image) in enumerate(zip(layout_predictions, names, images)): + out_pred = pred.model_dump(exclude=["segmentation_map"]) + out_pred["page"] = len(predictions_by_page[name]) + 1 + predictions_by_page[name].append(out_pred) + + with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: + json.dump(predictions_by_page, f, ensure_ascii=False) + + print(f"Wrote results to {result_path}") + + +if __name__ == "__main__": + main() + + + + + + + diff --git a/detect_text.py b/detect_text.py index 88aab6d7..8a0a81bd 100644 --- a/detect_text.py +++ b/detect_text.py @@ -5,7 +5,7 @@ from surya.input.load 
import load_from_folder, load_from_file from surya.model.detection.segformer import load_model, load_processor -from surya.detection import batch_detection +from surya.detection import batch_text_detection from surya.postprocessing.affinity import draw_lines_on_image from surya.postprocessing.heatmap import draw_polys_on_image from surya.settings import settings @@ -20,10 +20,12 @@ def main(): parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False) parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False) + parser.add_argument("--math", action="store_true", help="Use math model for detection", default=False) args = parser.parse_args() - model = load_model() - processor = load_processor() + checkpoint = settings.DETECTOR_MATH_MODEL_CHECKPOINT if args.math else settings.DETECTOR_MODEL_CHECKPOINT + model = load_model(checkpoint=checkpoint) + processor = load_processor(checkpoint=checkpoint) if os.path.isdir(args.input_path): images, names = load_from_folder(args.input_path, args.max) @@ -32,7 +34,7 @@ def main(): images, names = load_from_file(args.input_path, args.max) folder_name = os.path.basename(args.input_path).split(".")[0] - predictions = batch_detection(images, model, processor) + predictions = batch_text_detection(images, model, processor) result_path = os.path.join(args.results_dir, folder_name) os.makedirs(result_path, exist_ok=True) @@ -58,7 +60,7 @@ def main(): out_pred["page"] = len(predictions_by_page[name]) + 1 predictions_by_page[name].append(out_pred) - with open(os.path.join(result_path, "results.json"), "w+") as f: + with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(predictions_by_page, f, ensure_ascii=False) print(f"Wrote results to {result_path}") diff --git a/ocr_app.py b/ocr_app.py index 26bd61c6..1994ad19 100644 --- a/ocr_app.py +++ b/ocr_app.py @@ -1,8 +1,11 @@ +import argparse import io +from typing import List import pypdfium2 import streamlit as st -from surya.detection import batch_detection +from surya.detection import batch_text_detection +from surya.layout import batch_layout_detection from surya.model.detection.segformer import load_model, load_processor from surya.model.recognition.model import load_model as load_rec_model from surya.model.recognition.processor import load_processor as load_rec_processor @@ -12,12 +15,22 @@ from PIL import Image from surya.languages import CODE_TO_LANGUAGE from surya.input.langs import replace_lang_with_code -from surya.schema import OCRResult, DetectionResult - +from surya.schema import OCRResult, TextDetectionResult, LayoutResult +from surya.settings import settings +import os + +parser = argparse.ArgumentParser(description="Run OCR on an image or PDF.") +parser.add_argument("--math", action="store_true", help="Use math model for detection", default=False) +try: + args = parser.parse_args() +except SystemExit as e: + print(f"Error parsing arguments: {e}") + os._exit(e.code) @st.cache_resource() def load_det_cached(): - return load_model(), load_processor() + checkpoint = settings.DETECTOR_MATH_MODEL_CHECKPOINT if args.math else settings.DETECTOR_MODEL_CHECKPOINT + return load_model(checkpoint=checkpoint), load_processor(checkpoint=checkpoint) @st.cache_resource() @@ -25,21 +38,35 @@ def load_rec_cached(): return load_rec_model(), load_rec_processor() -def text_detection(img) -> DetectionResult: - 
pred = batch_detection([img], det_model, det_processor)[0] +@st.cache_resource() +def load_layout_cached(): + return load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT), load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + + +def text_detection(img) -> TextDetectionResult: + pred = batch_text_detection([img], det_model, det_processor)[0] polygons = [p.polygon for p in pred.bboxes] det_img = draw_polys_on_image(polygons, img.copy()) return det_img, pred +def layout_detection(img) -> LayoutResult: + _, det_pred = text_detection(img) + pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0] + polygons = [p.polygon for p in pred.bboxes] + labels = [p.label for p in pred.bboxes] + layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels) + return layout_img, pred + + # Function for OCR -def ocr(img, langs) -> OCRResult: +def ocr(img, langs: List[str]) -> OCRResult: replace_lang_with_code(langs) img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor)[0] bboxes = [l.bbox for l in img_pred.text_lines] text = [l.text for l in img_pred.text_lines] - rec_img = draw_text_on_image(bboxes, text, img.size) + rec_img = draw_text_on_image(bboxes, text, img.size, langs, has_math="_math" in langs) return rec_img, img_pred @@ -72,6 +99,7 @@ def page_count(pdf_file): det_model, det_processor = load_det_cached() rec_model, rec_processor = load_rec_cached() +layout_model, layout_processor = load_layout_cached() st.markdown(""" @@ -106,6 +134,7 @@ def page_count(pdf_file): text_det = st.sidebar.button("Run Text Detection") text_rec = st.sidebar.button("Run OCR") +layout_det = st.sidebar.button("Run Layout Detection [BETA]") # Run Text Detection if text_det and pil_image is not None: @@ -114,6 +143,14 @@ def page_count(pdf_file): st.image(det_img, caption="Detected Text", use_column_width=True) st.json(pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True) + +# Run layout +if layout_det and pil_image is not None: + layout_img, pred = layout_detection(pil_image) + with col1: + st.image(layout_img, caption="Detected Layout", use_column_width=True) + st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True) + # Run OCR if text_rec and pil_image is not None: rec_img, pred = ocr(pil_image, languages) diff --git a/ocr_text.py b/ocr_text.py index 50f4c7af..75271066 100644 --- a/ocr_text.py +++ b/ocr_text.py @@ -59,10 +59,10 @@ def main(): predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor) if args.images: - for idx, (name, image, pred) in enumerate(zip(names, images, predictions_by_image)): + for idx, (name, image, pred, langs) in enumerate(zip(names, images, predictions_by_image, image_langs)): bboxes = [l.bbox for l in pred.text_lines] pred_text = [l.text for l in pred.text_lines] - page_image = draw_text_on_image(bboxes, pred_text, image.size) + page_image = draw_text_on_image(bboxes, pred_text, image.size, langs, has_math="_math" in langs) page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png")) out_preds = defaultdict(list) @@ -71,7 +71,7 @@ def main(): out_pred["page"] = len(out_preds[name]) + 1 out_preds[name].append(out_pred) - with open(os.path.join(result_path, "results.json"), "w+") as f: + with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(out_preds, f, ensure_ascii=False) print(f"Wrote results to {result_path}") diff --git a/poetry.lock b/poetry.lock index e85ab6a9..f0937061 100644 --- 
a/poetry.lock +++ b/poetry.lock @@ -933,6 +933,20 @@ smb = ["smbprotocol"] ssh = ["paramiko"] tqdm = ["tqdm"] +[[package]] +name = "ftfy" +version = "6.1.3" +description = "Fixes mojibake and other problems with Unicode, after the fact" +optional = false +python-versions = ">=3.8,<4" +files = [ + {file = "ftfy-6.1.3-py3-none-any.whl", hash = "sha256:e49c306c06a97f4986faa7a8740cfe3c13f3106e85bcec73eb629817e671557c"}, + {file = "ftfy-6.1.3.tar.gz", hash = "sha256:693274aead811cff24c1e8784165aa755cd2f6e442a5ec535c7d697f6422a422"}, +] + +[package.dependencies] +wcwidth = ">=0.2.12,<0.3.0" + [[package]] name = "gitdb" version = "4.0.11" @@ -964,6 +978,77 @@ gitdb = ">=4.0.1,<5" [package.extras] test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "sumtypes"] +[[package]] +name = "greenlet" +version = "3.0.3" +description = "Lightweight in-process concurrent programming" +optional = false +python-versions = ">=3.7" +files = [ + {file = "greenlet-3.0.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9da2bd29ed9e4f15955dd1595ad7bc9320308a3b766ef7f837e23ad4b4aac31a"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d353cadd6083fdb056bb46ed07e4340b0869c305c8ca54ef9da3421acbdf6881"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dca1e2f3ca00b84a396bc1bce13dd21f680f035314d2379c4160c98153b2059b"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ed7fb269f15dc662787f4119ec300ad0702fa1b19d2135a37c2c4de6fadfd4a"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd4f49ae60e10adbc94b45c0b5e6a179acc1736cf7a90160b404076ee283cf83"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:73a411ef564e0e097dbe7e866bb2dda0f027e072b04da387282b02c308807405"}, + {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7f362975f2d179f9e26928c5b517524e89dd48530a0202570d55ad6ca5d8a56f"}, + {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:649dde7de1a5eceb258f9cb00bdf50e978c9db1b996964cd80703614c86495eb"}, + {file = "greenlet-3.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:68834da854554926fbedd38c76e60c4a2e3198c6fbed520b106a8986445caaf9"}, + {file = "greenlet-3.0.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:b1b5667cced97081bf57b8fa1d6bfca67814b0afd38208d52538316e9422fc61"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52f59dd9c96ad2fc0d5724107444f76eb20aaccb675bf825df6435acb7703559"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:afaff6cf5200befd5cec055b07d1c0a5a06c040fe5ad148abcd11ba6ab9b114e"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe754d231288e1e64323cfad462fcee8f0288654c10bdf4f603a39ed923bef33"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2797aa5aedac23af156bbb5a6aa2cd3427ada2972c828244eb7d1b9255846379"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7f009caad047246ed379e1c4dbcb8b020f0a390667ea74d2387be2998f58a22"}, + {file = 
"greenlet-3.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c5e1536de2aad7bf62e27baf79225d0d64360d4168cf2e6becb91baf1ed074f3"}, + {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:894393ce10ceac937e56ec00bb71c4c2f8209ad516e96033e4b3b1de270e200d"}, + {file = "greenlet-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:1ea188d4f49089fc6fb283845ab18a2518d279c7cd9da1065d7a84e991748728"}, + {file = "greenlet-3.0.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:70fb482fdf2c707765ab5f0b6655e9cfcf3780d8d87355a063547b41177599be"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4d1ac74f5c0c0524e4a24335350edad7e5f03b9532da7ea4d3c54d527784f2e"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:149e94a2dd82d19838fe4b2259f1b6b9957d5ba1b25640d2380bea9c5df37676"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15d79dd26056573940fcb8c7413d84118086f2ec1a8acdfa854631084393efcc"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b7db1ebff4ba09aaaeae6aa491daeb226c8150fc20e836ad00041bcb11230"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fcd2469d6a2cf298f198f0487e0a5b1a47a42ca0fa4dfd1b6862c999f018ebbf"}, + {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1f672519db1796ca0d8753f9e78ec02355e862d0998193038c7073045899f305"}, + {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2516a9957eed41dd8f1ec0c604f1cdc86758b587d964668b5b196a9db5bfcde6"}, + {file = "greenlet-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:bba5387a6975598857d86de9eac14210a49d554a77eb8261cc68b7d082f78ce2"}, + {file = "greenlet-3.0.3-cp37-cp37m-macosx_11_0_universal2.whl", hash = "sha256:5b51e85cb5ceda94e79d019ed36b35386e8c37d22f07d6a751cb659b180d5274"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:daf3cb43b7cf2ba96d614252ce1684c1bccee6b2183a01328c98d36fcd7d5cb0"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99bf650dc5d69546e076f413a87481ee1d2d09aaaaaca058c9251b6d8c14783f"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dd6e660effd852586b6a8478a1d244b8dc90ab5b1321751d2ea15deb49ed414"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3391d1e16e2a5a1507d83e4a8b100f4ee626e8eca43cf2cadb543de69827c4c"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e1f145462f1fa6e4a4ae3c0f782e580ce44d57c8f2c7aae1b6fa88c0b2efdb41"}, + {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1a7191e42732df52cb5f39d3527217e7ab73cae2cb3694d241e18f53d84ea9a7"}, + {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0448abc479fab28b00cb472d278828b3ccca164531daab4e970a0458786055d6"}, + {file = "greenlet-3.0.3-cp37-cp37m-win32.whl", hash = "sha256:b542be2440edc2d48547b5923c408cbe0fc94afb9f18741faa6ae970dbcb9b6d"}, + {file = "greenlet-3.0.3-cp37-cp37m-win_amd64.whl", hash = "sha256:01bc7ea167cf943b4c802068e178bbf70ae2e8c080467070d01bfa02f337ee67"}, + {file = "greenlet-3.0.3-cp38-cp38-macosx_11_0_universal2.whl", hash = 
"sha256:1996cb9306c8595335bb157d133daf5cf9f693ef413e7673cb07e3e5871379ca"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ddc0f794e6ad661e321caa8d2f0a55ce01213c74722587256fb6566049a8b04"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9db1c18f0eaad2f804728c67d6c610778456e3e1cc4ab4bbd5eeb8e6053c6fc"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7170375bcc99f1a2fbd9c306f5be8764eaf3ac6b5cb968862cad4c7057756506"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b66c9c1e7ccabad3a7d037b2bcb740122a7b17a53734b7d72a344ce39882a1b"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:098d86f528c855ead3479afe84b49242e174ed262456c342d70fc7f972bc13c4"}, + {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:81bb9c6d52e8321f09c3d165b2a78c680506d9af285bfccbad9fb7ad5a5da3e5"}, + {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fd096eb7ffef17c456cfa587523c5f92321ae02427ff955bebe9e3c63bc9f0da"}, + {file = "greenlet-3.0.3-cp38-cp38-win32.whl", hash = "sha256:d46677c85c5ba00a9cb6f7a00b2bfa6f812192d2c9f7d9c4f6a55b60216712f3"}, + {file = "greenlet-3.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:419b386f84949bf0e7c73e6032e3457b82a787c1ab4a0e43732898a761cc9dbf"}, + {file = "greenlet-3.0.3-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:da70d4d51c8b306bb7a031d5cff6cc25ad253affe89b70352af5f1cb68e74b53"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:086152f8fbc5955df88382e8a75984e2bb1c892ad2e3c80a2508954e52295257"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d73a9fe764d77f87f8ec26a0c85144d6a951a6c438dfe50487df5595c6373eac"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7dcbe92cc99f08c8dd11f930de4d99ef756c3591a5377d1d9cd7dd5e896da71"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1551a8195c0d4a68fac7a4325efac0d541b48def35feb49d803674ac32582f61"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:64d7675ad83578e3fc149b617a444fab8efdafc9385471f868eb5ff83e446b8b"}, + {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b37eef18ea55f2ffd8f00ff8fe7c8d3818abd3e25fb73fae2ca3b672e333a7a6"}, + {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:77457465d89b8263bca14759d7c1684df840b6811b2499838cc5b040a8b5b113"}, + {file = "greenlet-3.0.3-cp39-cp39-win32.whl", hash = "sha256:57e8974f23e47dac22b83436bdcf23080ade568ce77df33159e019d161ce1d1e"}, + {file = "greenlet-3.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:c5ee858cfe08f34712f548c3c363e807e7186f03ad7a5039ebadb29e8c6be067"}, + {file = "greenlet-3.0.3.tar.gz", hash = "sha256:43374442353259554ce33599da8b692d5aa96f8976d567d4badf263371fbe491"}, +] + +[package.extras] +docs = ["Sphinx", "furo"] +test = ["objgraph", "psutil"] + [[package]] name = "huggingface-hub" version = "0.20.3" @@ -2368,6 +2453,26 @@ files = [ docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] 
+[[package]] +name = "playwright" +version = "1.41.2" +description = "A high-level API to automate web browsers" +optional = false +python-versions = ">=3.8" +files = [ + {file = "playwright-1.41.2-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:cf68335a5dfa4038fa797a4ba0105faee0094ebbb372547d7a27feec5b23c672"}, + {file = "playwright-1.41.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:431e3a05f8c99147995e2b3e8475d07818745294fd99f1510b61756e73bdcf68"}, + {file = "playwright-1.41.2-py3-none-macosx_11_0_universal2.whl", hash = "sha256:0608717cbf291a625ba6f751061af0fc0cc9bdace217e69d87b1eb1383b03406"}, + {file = "playwright-1.41.2-py3-none-manylinux1_x86_64.whl", hash = "sha256:4bf214d812092cf5b9b9648ba84611aa35e28685519911342a7da3a3031f9ed6"}, + {file = "playwright-1.41.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eaa17ab44622c447de26ed8f7d99912719568d8dbc3a9db0e07f0ae1487709d9"}, + {file = "playwright-1.41.2-py3-none-win32.whl", hash = "sha256:edb210a015e70bb0d328bf1c9b65fa3a08361f33e4d7c4ddd1ad2adb6d9b4479"}, + {file = "playwright-1.41.2-py3-none-win_amd64.whl", hash = "sha256:71ead0f33e00f5a8533c037c647938b99f219436a1b27d4ba4de4e6bf0567278"}, +] + +[package.dependencies] +greenlet = "3.0.3" +pyee = "11.0.1" + [[package]] name = "prometheus-client" version = "0.19.0" @@ -2683,6 +2788,23 @@ numpy = ">=1.16.4" carto = ["pydeck-carto"] jupyter = ["ipykernel (>=5.1.2)", "ipython (>=5.8.0)", "ipywidgets (>=7,<8)", "traitlets (>=4.3.2)"] +[[package]] +name = "pyee" +version = "11.0.1" +description = "A rough port of Node.js's EventEmitter to Python with a few tricks of its own" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyee-11.0.1-py3-none-any.whl", hash = "sha256:9bcc9647822234f42c228d88de63d0f9ffa881e87a87f9d36ddf5211f6ac977d"}, + {file = "pyee-11.0.1.tar.gz", hash = "sha256:a642c51e3885a33ead087286e35212783a4e9b8d6514a10a5db4e57ac57b2b29"}, +] + +[package.dependencies] +typing-extensions = "*" + +[package.extras] +dev = ["black", "flake8", "flake8-black", "isort", "jupyter-console", "mkdocs", "mkdocs-include-markdown-plugin", "mkdocstrings[python]", "pytest", "pytest-asyncio", "pytest-trio", "toml", "tox", "trio", "trio", "trio-typing", "twine", "twisted", "validate-pyproject[all]"] + [[package]] name = "pygments" version = "2.17.2" @@ -4623,4 +4745,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13,!=3.9.7" -content-hash = "b6abaf81bb850c204b073e638c539f47a0c2bf1cfb46dbce2482265beed73198" +content-hash = "9725cc159fb131e6b37e7a52c6f362d888ca6a61423acef8b62bd12c4fedaa33" diff --git a/pyproject.toml b/pyproject.toml index 710ee004..8aff4834 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "surya-ocr" -version = "0.2.2" +version = "0.2.3" description = "OCR and line detection in 90+ languages" authors = ["Vik Paruchuri "] readme = "README.md" @@ -14,7 +14,8 @@ include = [ "detect_text.py", "ocr_text.py", "ocr_app.py", - "run_ocr_app.py" + "run_ocr_app.py", + "detect_layout.py" ] [tool.poetry.dependencies] @@ -29,6 +30,7 @@ pypdfium2 = "^4.25.0" opencv-python = "^4.9.0.80" tabulate = "^0.9.0" filetype = "^1.2.0" +ftfy = "^6.1.3" [tool.poetry.group.dev.dependencies] jupyter = "^1.0.0" @@ -39,10 +41,12 @@ datasets = "^2.16.1" rapidfuzz = "^3.6.1" arabic-reshaper = "^3.0.0" streamlit = "^1.31.0" +playwright = "^1.41.2" [tool.poetry.scripts] surya_detect = "detect_text:main" surya_ocr = "ocr_text:main" 
+surya_layout = "detect_layout:main" surya_gui = "run_ocr_app:run_app" [build-system] diff --git a/run_ocr_app.py b/run_ocr_app.py index 27235fed..51c13b1e 100644 --- a/run_ocr_app.py +++ b/run_ocr_app.py @@ -1,8 +1,17 @@ +import argparse import subprocess import os def run_app(): + parser = argparse.ArgumentParser(description="Run the streamlit OCR app") + parser.add_argument("--math", action="store_true", help="Use math model for detection", default=False) + args = parser.parse_args() + cur_dir = os.path.dirname(os.path.abspath(__file__)) ocr_app_path = os.path.join(cur_dir, "ocr_app.py") - subprocess.run(["streamlit", "run", ocr_app_path]) \ No newline at end of file + cmd = ["streamlit", "run", ocr_app_path] + if args.math: + cmd.append("--") + cmd.append("--math") + subprocess.run(cmd) \ No newline at end of file diff --git a/surya/detection.py b/surya/detection.py index 6a9efa6f..db138bd6 100644 --- a/surya/detection.py +++ b/surya/detection.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Tuple import cv2 import torch @@ -7,7 +7,7 @@ from surya.postprocessing.heatmap import get_and_clean_boxes from surya.postprocessing.affinity import get_vertical_lines, get_horizontal_lines from surya.input.processing import prepare_image, split_image -from surya.schema import DetectionResult +from surya.schema import TextDetectionResult from surya.settings import settings from tqdm import tqdm @@ -15,15 +15,16 @@ def get_batch_size(): batch_size = settings.DETECTOR_BATCH_SIZE if batch_size is None: - batch_size = 8 + batch_size = 6 if settings.TORCH_DEVICE_MODEL == "cuda": - batch_size = 32 + batch_size = 24 return batch_size -def batch_detection(images: List, model, processor) -> List[DetectionResult]: +def batch_detection(images: List, model, processor) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]: assert all([isinstance(image, Image.Image) for image in images]) batch_size = get_batch_size() + heatmap_count = model.config.num_labels images = [image.convert("RGB") for image in images] orig_sizes = [image.size for image in images] @@ -51,41 +52,42 @@ def batch_detection(images: List, model, processor) -> List[DetectionResult]: logits = pred.logits for j in range(logits.shape[0]): - heatmap = logits[j, 0, :, :].detach().cpu().numpy().astype(np.float32) - affinity_map = logits[j, 1, :, :].detach().cpu().numpy().astype(np.float32) + heatmaps = [] + for k in range(heatmap_count): + heatmap = logits[j, k, :, :].detach().cpu().numpy().astype(np.float32) + heatmap_shape = list(heatmap.shape) - heatmap_shape = list(heatmap.shape) - correct_shape = [processor.size["height"], processor.size["width"]] - cv2_size = list(reversed(correct_shape)) # opencv uses (width, height) instead of (height, width) + correct_shape = [processor.size["height"], processor.size["width"]] + cv2_size = list(reversed(correct_shape)) # opencv uses (width, height) instead of (height, width) + if heatmap_shape != correct_shape: + heatmap = cv2.resize(heatmap, cv2_size, interpolation=cv2.INTER_LINEAR) - if heatmap_shape != correct_shape: - heatmap = cv2.resize(heatmap, cv2_size, interpolation=cv2.INTER_LINEAR) - - affinity_shape = list(affinity_map.shape) - if affinity_shape != correct_shape: - affinity_map = cv2.resize(affinity_map, cv2_size, interpolation=cv2.INTER_LINEAR) - - pred_parts.append((heatmap, affinity_map)) + heatmaps.append(heatmap) + pred_parts.append(heatmaps) preds = [] for i, (idx, height) in enumerate(zip(split_index, split_heights)): if len(preds) <= idx: preds.append(pred_parts[i]) 
else: - heatmap, affinity_map = preds[idx] - pred_heatmap = pred_parts[i][0] - pred_affinity = pred_parts[i][1] + heatmaps = preds[idx] + pred_heatmaps = [pred_parts[i][k] for k in range(heatmap_count)] if height < processor.size["height"]: # Cut off padding to get original height - pred_heatmap = pred_heatmap[:height, :] - pred_affinity = pred_affinity[:height, :] + pred_heatmaps = [pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps] - heatmap = np.vstack([heatmap, pred_heatmap]) - affinity_map = np.vstack([affinity_map, pred_affinity]) - preds[idx] = (heatmap, affinity_map) + for k in range(heatmap_count): + heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]]) + preds[idx] = heatmaps assert len(preds) == len(images) + assert all([len(pred) == heatmap_count for pred in preds]) + return preds, orig_sizes + + +def batch_text_detection(images: List, model, processor) -> List[TextDetectionResult]: + preds, orig_sizes = batch_detection(images, model, processor) results = [] for i in range(len(images)): heatmap, affinity_map = preds[i] @@ -98,7 +100,7 @@ def batch_detection(images: List, model, processor) -> List[DetectionResult]: vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes[i]) horizontal_lines = get_horizontal_lines(affinity_map, affinity_size, orig_sizes[i]) - result = DetectionResult( + result = TextDetectionResult( bboxes=bboxes, vertical_lines=vertical_lines, horizontal_lines=horizontal_lines, @@ -112,7 +114,3 @@ def batch_detection(images: List, model, processor) -> List[DetectionResult]: return results - - - - diff --git a/surya/input/processing.py b/surya/input/processing.py index 17787e67..6efbe20c 100644 --- a/surya/input/processing.py +++ b/surya/input/processing.py @@ -88,7 +88,7 @@ def slice_and_pad_poly(image: Image.Image, coordinates, idx): # Extract the polygonal area from the image polygon_image = np.array(image) - polygon_image[mask == 0] = 0 + polygon_image[mask == 0] = settings.RECOGNITION_PAD_VALUE polygon_image = Image.fromarray(polygon_image) rectangle = Image.new('RGB', (bbox[2] - bbox[0], bbox[3] - bbox[1]), 'white') diff --git a/surya/languages.py b/surya/languages.py index 79f65aa2..83667cf8 100644 --- a/surya/languages.py +++ b/surya/languages.py @@ -91,7 +91,7 @@ 'vi': 'Vietnamese', 'xh': 'Xhosa', 'yi': 'Yiddish', - 'zh': 'Chinese' + 'zh': 'Chinese', } LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()} diff --git a/surya/layout.py b/surya/layout.py new file mode 100644 index 00000000..a05ac126 --- /dev/null +++ b/surya/layout.py @@ -0,0 +1,184 @@ +import math +from typing import List, Optional +from PIL import Image +import numpy as np + +from surya.detection import batch_detection +from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes, \ + clean_contained_boxes +from surya.schema import LayoutResult, LayoutBox, TextDetectionResult + + +def compute_integral_image(arr): + return arr.cumsum(axis=0).cumsum(axis=1) + + +def bbox_avg(integral_image, x1, y1, x2, y2): + total = integral_image[y2, x2] + above = integral_image[y1 - 1, x2] if y1 > 0 else 0 + left = integral_image[y2, x1 - 1] if x1 > 0 else 0 + above_left = integral_image[y1 - 1, x1 - 1] if (x1 > 0 and y1 > 0) else 0 + bbox_sum = total - above - left + above_left + bbox_area = (x2 - x1) * (y2 - y1) + if bbox_area == 0: + return 0 + return bbox_sum / bbox_area + + +def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[Image.Image], orig_size, id2label, segment_assignment, 
vertical_line_width=20) -> List[LayoutBox]: + logits = np.stack(heatmaps, axis=0) + vertical_line_bboxes = [line for line in detection_result.vertical_lines] + line_bboxes = [line for line in detection_result.bboxes] + + # Scale back to processor size + for line in vertical_line_bboxes: + line.rescale_bbox(orig_size, list(reversed(heatmaps[0].shape))) + + for line in line_bboxes: + line.rescale(orig_size, list(reversed(heatmaps[0].shape))) + + for bbox in vertical_line_bboxes: + # Give some width to the vertical lines + vert_bbox = list(bbox.bbox) + vert_bbox[2] = min(heatmaps[0].shape[0], vert_bbox[2] + vertical_line_width) + + logits[:, vert_bbox[1]:vert_bbox[3], vert_bbox[0]:vert_bbox[2]] = 0 # zero out where the column lines are + + logits[:, logits[0] >= .5] = 0 # zero out where blanks are + + # Zero out where other segments are + for i in range(logits.shape[0]): + logits[i, segment_assignment != i] = 0 + + detected_boxes = [] + done_maps = set() + for iteration in range(100): # detect up to 100 boxes + bbox = None + confidence = None + for heatmap_idx in range(1, len(id2label)): # Skip the blank class + if heatmap_idx in done_maps: + continue + heatmap = logits[heatmap_idx] + bboxes = get_detected_boxes(heatmap, text_threshold=.9) + bboxes = [bbox for bbox in bboxes if bbox.area > 25] + for bb in bboxes: + bb.fit_to_bounds([0, 0, heatmap.shape[1] - 1, heatmap.shape[0] - 1]) + + if len(bboxes) == 0: + done_maps.add(heatmap_idx) + continue + + integral_image = compute_integral_image(heatmap) + bbox_confidences = [bbox_avg(integral_image, *[int(b) for b in bbox.bbox]) for bbox in bboxes] + + max_confidence = max(bbox_confidences) + max_confidence_idx = bbox_confidences.index(max_confidence) + if max_confidence >= .15 and (confidence is None or max_confidence > confidence): + bbox = LayoutBox(polygon=bboxes[max_confidence_idx].polygon, label=id2label[heatmap_idx]) + elif max_confidence < .15: + done_maps.add(heatmap_idx) + + if bbox is None: + break + + # Expand bbox to cover intersecting lines + remove_indices = [] + covered_lines = [] + for line_idx, line_bbox in enumerate(line_bboxes): + if line_bbox.intersection_pct(bbox) >= .5: + remove_indices.append(line_idx) + covered_lines.append(line_bbox.bbox) + + logits[:, int(bbox.bbox[1]):int(bbox.bbox[3]), int(bbox.bbox[0]):int(bbox.bbox[2])] = 0 # zero out where the detected bbox is + if len(covered_lines) == 0 and bbox.label not in ["Picture", "Formula"]: + continue + + if len(covered_lines) > 0 and bbox.label == "Picture": + bbox.label = "Figure" + + if len(covered_lines) > 0 and bbox.label not in ["Picture"]: + min_x = min([line[0] for line in covered_lines]) + min_y = min([line[1] for line in covered_lines]) + max_x = max([line[2] for line in covered_lines]) + max_y = max([line[3] for line in covered_lines]) + + min_x_box = min([b[0] for b in bbox.polygon]) + min_y_box = min([b[1] for b in bbox.polygon]) + max_x_box = max([b[0] for b in bbox.polygon]) + max_y_box = max([b[1] for b in bbox.polygon]) + + min_x = min(min_x, min_x_box) + min_y = min(min_y, min_y_box) + max_x = max(max_x, max_x_box) + max_y = max(max_y, max_y_box) + + bbox.polygon[0][0] = min_x + bbox.polygon[0][1] = min_y + bbox.polygon[1][0] = max_x + bbox.polygon[1][1] = min_y + bbox.polygon[2][0] = max_x + bbox.polygon[2][1] = max_y + bbox.polygon[3][0] = min_x + bbox.polygon[3][1] = max_y + + # Remove "used" overlap lines + line_bboxes = [line_bboxes[i] for i in range(len(line_bboxes)) if i not in remove_indices] + detected_boxes.append(bbox) + + logits[:, 
int(bbox.bbox[1]):int(bbox.bbox[3]), int(bbox.bbox[0]):int(bbox.bbox[2])] = 0 # zero out where the new box is + + if len(line_bboxes) > 0: + for bbox in line_bboxes: + detected_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text")) + + for bbox in detected_boxes: + bbox.rescale(list(reversed(heatmap.shape)), orig_size) + + detected_boxes = [bbox for bbox in detected_boxes if bbox.area > 16] + detected_boxes = clean_contained_boxes(detected_boxes) + return detected_boxes + + +def get_regions(heatmaps: List[Image.Image], orig_size, id2label, segment_assignment) -> List[LayoutBox]: + bboxes = [] + for i in range(1, len(id2label)): # Skip the blank class + heatmap = heatmaps[i] + assert heatmap.shape == segment_assignment.shape + heatmap[segment_assignment != i] = 0 # zero out where another segment is + bbox = get_and_clean_boxes(heatmap, list(reversed(heatmap.shape)), orig_size, low_text=.7, text_threshold=.8) + for bb in bbox: + bboxes.append(LayoutBox(polygon=bb.polygon, label=id2label[i])) + heatmaps.append(heatmap) + + bboxes = keep_largest_boxes(bboxes) + return bboxes + + +def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None) -> List[LayoutResult]: + preds, orig_sizes = batch_detection(images, model, processor) + id2label = model.config.id2label + + results = [] + for i in range(len(images)): + heatmaps = preds[i] + orig_size = orig_sizes[i] + logits = np.stack(heatmaps, axis=0) + segment_assignment = logits.argmax(axis=0) + + if detection_results: + bboxes = get_regions_from_detection_result(detection_results[i], heatmaps, orig_size, id2label, segment_assignment) + else: + bboxes = get_regions(heatmaps, orig_size, id2label, segment_assignment) + + segmentation_img = Image.fromarray(segment_assignment.astype(np.uint8)) + + result = LayoutResult( + bboxes=bboxes, + segmentation_map=segmentation_img, + heatmaps=heatmaps, + image_bbox=[0, 0, orig_size[0], orig_size[1]] + ) + + results.append(result) + + return results \ No newline at end of file diff --git a/surya/model/detection/segformer.py b/surya/model/detection/segformer.py index a76d1e44..9fc77635 100644 --- a/surya/model/detection/segformer.py +++ b/surya/model/detection/segformer.py @@ -1,3 +1,4 @@ +import math from typing import Optional, Tuple, Union from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerImageProcessor, \ @@ -63,6 +64,39 @@ def __init__(self, config): self.config = config + def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor: + batch_size = encoder_hidden_states[-1].shape[0] + + all_hidden_states = () + for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c): + if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3: + height = width = int(math.sqrt(encoder_hidden_state.shape[-1])) + encoder_hidden_state = ( + encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous() + ) + + # unify channel dimension + height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3] + encoder_hidden_state = mlp(encoder_hidden_state) + encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1) + encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width) + # upsample + encoder_hidden_state = encoder_hidden_state.contiguous() + encoder_hidden_state = nn.functional.interpolate( + encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False + ) + 
all_hidden_states += (encoder_hidden_state,) + + hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1)) + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(hidden_states) + + # logits are of shape (batch_size, num_labels, height/4, width/4) + logits = self.classifier(hidden_states) + + return logits + class SegformerForRegressionMask(SegformerForSemanticSegmentation): def __init__(self, config): diff --git a/surya/model/recognition/decoder.py b/surya/model/recognition/decoder.py index bb64243e..01befe2f 100644 --- a/surya/model/recognition/decoder.py +++ b/surya/model/recognition/decoder.py @@ -15,9 +15,13 @@ class MBartExpertMLP(nn.Module): - def __init__(self, config: MBartConfig): + def __init__(self, config: MBartConfig, is_lg=False, is_xl=False): super().__init__() self.ffn_dim = config.d_expert + if is_lg: + self.ffn_dim = config.d_expert_lg + if is_xl: + self.ffn_dim = config.d_expert_xl self.hidden_dim = config.d_model self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) @@ -41,10 +45,17 @@ def __init__(self, config): self.dropout = nn.Dropout(config.activation_dropout) self.hidden_dim = config.d_model + + self.lg_lang_codes = [] + self.xl_lang_codes = [] + if hasattr(config, "lg_langs"): + self.lg_lang_codes = sorted(config.lg_langs.values()) + if hasattr(config, "xl_langs"): + self.xl_lang_codes = sorted(config.xl_langs.values()) self.lang_codes = sorted(config.langs.values()) self.num_experts = len(self.lang_codes) - self.experts = nn.ModuleDict({str(lang): MBartExpertMLP(config) for lang in self.lang_codes}) + self.experts = nn.ModuleDict({str(lang): MBartExpertMLP(config, is_lg=(lang in self.lg_lang_codes), is_xl=(lang in self.xl_lang_codes)) for lang in self.lang_codes}) def forward(self, hidden_states: torch.Tensor, langs: torch.LongTensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape diff --git a/surya/ocr.py b/surya/ocr.py index 5344dbf0..c246fc7f 100644 --- a/surya/ocr.py +++ b/surya/ocr.py @@ -5,7 +5,7 @@ import torch from PIL import Image -from surya.detection import batch_detection +from surya.detection import batch_text_detection from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image from surya.postprocessing.text import truncate_repetitions, sort_text_lines from surya.recognition import batch_recognition @@ -27,7 +27,7 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model all_slices.extend(slices) all_langs.extend([lang] * len(slices)) - rec_predictions = batch_recognition(all_slices, all_langs, rec_model, rec_processor) + rec_predictions, _ = batch_recognition(all_slices, all_langs, rec_model, rec_processor) predictions_by_image = [] slice_start = 0 @@ -60,7 +60,7 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor) -> List[OCRResult]: - det_predictions = batch_detection(images, det_model, det_processor) + det_predictions = batch_text_detection(images, det_model, det_processor) if det_model.device == "cuda": torch.cuda.empty_cache() # Empty cache from first model run @@ -74,25 +74,25 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr all_slices.extend(slices) all_langs.extend([lang] * len(slices)) - rec_predictions = batch_recognition(all_slices, all_langs, rec_model, rec_processor) + 
rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor) predictions_by_image = [] slice_start = 0 for idx, (image, det_pred, lang) in enumerate(zip(images, det_predictions, langs)): slice_end = slice_start + slice_map[idx] image_lines = rec_predictions[slice_start:slice_end] + line_confidences = confidence_scores[slice_start:slice_end] slice_start = slice_end assert len(image_lines) == len(det_pred.bboxes) - # Remove repeated characters - image_lines = [truncate_repetitions(l) for l in image_lines] lines = [] - for text_line, bbox in zip(image_lines, det_pred.bboxes): + for text_line, confidence, bbox in zip(image_lines, line_confidences, det_pred.bboxes): lines.append(TextLine( text=text_line, polygon=bbox.polygon, - bbox=bbox.bbox + bbox=bbox.bbox, + confidence=confidence )) lines = sort_text_lines(lines) diff --git a/surya/postprocessing/fonts.py b/surya/postprocessing/fonts.py new file mode 100644 index 00000000..9e309c47 --- /dev/null +++ b/surya/postprocessing/fonts.py @@ -0,0 +1,24 @@ +from typing import List +import os +import requests + +from surya.settings import settings + + +def get_font_path(langs: List[str] | None = None) -> str: + font_path = settings.RECOGNITION_RENDER_FONTS["all"] + if langs is not None: + for k in settings.RECOGNITION_RENDER_FONTS: + if k in langs and len(langs) == 1: + font_path = settings.RECOGNITION_RENDER_FONTS[k] + break + + if not os.path.exists(font_path): + os.makedirs(os.path.dirname(font_path), exist_ok=True) + font_dl_path = f"{settings.RECOGNITION_FONT_DL_BASE}/{os.path.basename(font_path)}" + with requests.get(font_dl_path, stream=True) as r, open(font_path, 'wb') as f: + r.raise_for_status() + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + return font_path \ No newline at end of file diff --git a/surya/postprocessing/heatmap.py b/surya/postprocessing/heatmap.py index 49f75652..3d45e693 100644 --- a/surya/postprocessing/heatmap.py +++ b/surya/postprocessing/heatmap.py @@ -1,15 +1,41 @@ -from typing import List +from typing import List, Tuple import numpy as np import cv2 import math -from PIL import ImageDraw +from PIL import ImageDraw, ImageFont +from surya.postprocessing.fonts import get_font_path from surya.postprocessing.util import rescale_bbox from surya.schema import PolygonBox from surya.settings import settings +def keep_largest_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]: + new_boxes = [] + for box_obj in boxes: + box = box_obj.bbox + box_area = (box[2] - box[0]) * (box[3] - box[1]) + contained = False + for other_box_obj in boxes: + if other_box_obj.polygon == box_obj.polygon: + continue + + other_box = other_box_obj.bbox + other_box_area = (other_box[2] - other_box[0]) * (other_box[3] - other_box[1]) + if box == other_box: + continue + # find overlap percentage + overlap = max(0, min(box[2], other_box[2]) - max(box[0], other_box[0])) * max(0, min(box[3], other_box[3]) - max(box[1], other_box[1])) + overlap = overlap / box_area + if overlap > .9 and box_area < other_box_area: + contained = True + break + if not contained: + new_boxes.append(box_obj) + return new_boxes + + def clean_contained_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]: new_boxes = [] for box_obj in boxes: @@ -64,6 +90,8 @@ def detect_boxes(linemap, text_threshold, low_text): label_count, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4) det = [] + confidences = [] + max_confidence = 0 for k in range(1, label_count): # size 
filtering size = stats[k, cv2.CC_STAT_AREA] @@ -113,24 +141,45 @@ def detect_boxes(linemap, text_threshold, low_text): box = np.roll(box, 4-startidx, 0) box = np.array(box) + mask = np.zeros_like(linemap).astype(np.uint8) + cv2.fillPoly(mask, [np.int32(box)], 255) + mask = mask.astype(np.float16) / 255 + + roi = np.where(mask == 1, linemap, 0) + confidence = np.mean(roi[roi != 0]) + + if confidence > max_confidence: + max_confidence = confidence + + confidences.append(confidence) det.append(box) - return det, labels + if max_confidence > 0: + confidences = [c / max_confidence for c in confidences] + return det, labels, confidences + +def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]: + if text_threshold is None: + text_threshold = settings.DETECTOR_TEXT_THRESHOLD + + if low_text is None: + low_text = settings.DETECTOR_BLANK_THRESHOLD -def get_detected_boxes(textmap, text_threshold=settings.DETECTOR_TEXT_THRESHOLD, low_text=settings.DETECTOR_BLANK_THRESHOLD) -> List[PolygonBox]: textmap = textmap.copy() textmap = textmap.astype(np.float32) - boxes, labels = detect_boxes(textmap, text_threshold, low_text) + boxes, labels, confidences = detect_boxes(textmap, text_threshold, low_text) # From point form to box form - boxes = [PolygonBox(polygon=box) for box in boxes] + boxes = [PolygonBox(polygon=box, confidence=confidence) for box, confidence in zip(boxes, confidences)] return boxes -def get_and_clean_boxes(textmap, processor_size, image_size) -> List[PolygonBox]: - bboxes = get_detected_boxes(textmap) +def get_and_clean_boxes(textmap, processor_size, image_size, text_threshold=None, low_text=None) -> List[PolygonBox]: + bboxes = get_detected_boxes(textmap, text_threshold, low_text) for bbox in bboxes: bbox.rescale(processor_size, image_size) + bbox.fit_to_bounds([0, 0, image_size[0], image_size[1]]) + bboxes = clean_contained_boxes(bboxes) return bboxes @@ -144,13 +193,20 @@ def draw_bboxes_on_image(bboxes, image): return image -def draw_polys_on_image(corners, image): +def draw_polys_on_image(corners, image, labels=None): draw = ImageDraw.Draw(image) + font_path = get_font_path() + label_font = ImageFont.truetype(font_path, 16) - for poly in corners: - poly = [(p[0], p[1]) for p in poly] + for i in range(len(corners)): + poly = corners[i] + poly = [(int(p[0]), int(p[1])) for p in poly] draw.polygon(poly, outline='red', width=1) + if labels is not None: + label = labels[i] + draw.text((min([p[0] for p in poly]), min([p[1] for p in poly])), label, fill="blue", font=label_font) + return image diff --git a/surya/postprocessing/math/latex.py b/surya/postprocessing/math/latex.py new file mode 100644 index 00000000..b07e5fb8 --- /dev/null +++ b/surya/postprocessing/math/latex.py @@ -0,0 +1,125 @@ +import re +from ftfy import fix_text + + +def contains_math(text): + return text.startswith("$") or text.endswith("$") + + +def fix_math(text): + # Fix any issues with the text + text = fix_text(text) + + # Remove LaTeX labels and references + text = remove_labels(text) + text = replace_katex_invalid(text) + text = fix_fences(text) + return text + + +def remove_labels(text): + pattern = r'\\label\{[^}]*\}' + text = re.sub(pattern, '', text) + + ref_pattern = r'\\ref\{[^}]*\}' + text = re.sub(ref_pattern, '', text) + + pageref_pattern = r'\\pageref\{[^}]*\}' + text = re.sub(pageref_pattern, '', text) + return text + + +def replace_katex_invalid(string): + # KaTeX cannot render all LaTeX, so we need to replace some things + string = re.sub(r'\\tag\{.*?\}', '', string) + 
+    string = re.sub(r'\\(?:Bigg?|bigg?)\{(.*?)\}', r'\1', string)
+    string = re.sub(r'\\quad\\mbox\{(.*?)\}', r'\1', string)
+    string = re.sub(r'\\mbox\{(.*?)\}', r'\1', string)
+    string = remove_inner_dollars(string)
+    return string
+
+
+def remove_inner_dollars(text):
+    def replace_dollar(match):
+        # Replace single $ with nothing, keep $$ intact
+        math_block = match.group(1)
+        return '$$' + math_block.replace('$', '') + '$$'
+
+    pattern = r'\$\$(.*?)\$\$'
+    return re.sub(pattern, replace_dollar, text, flags=re.DOTALL)
+
+
+def extract_latex_with_positions(text):
+    pattern = r'(\$\$.*?\$\$|\$.*?\$)'
+    matches = []
+    for match in re.finditer(pattern, text, re.DOTALL):
+        matches.append((match.group(), match.start(), match.end()))
+    return matches
+
+
+def slice_latex(text):
+    # Extract LaTeX blocks along with their positions
+    latex_blocks_with_positions = extract_latex_with_positions(text)
+
+    chunks = []
+    last_position = 0
+    for block, start, end in latex_blocks_with_positions:
+        # Add text before the current LaTeX block, if any
+        if start > last_position:
+            chunks.append({"text": text[last_position:start], "type": "text"})
+        # Add the LaTeX block
+        chunks.append({"text": block, "type": "latex"})
+        last_position = end
+    # Add remaining text after the last LaTeX block, if any
+    if last_position < len(text):
+        chunks.append({"text": text[last_position:], "type": "text"})
+
+    return chunks
+
+
+def is_latex(text):
+    latex_patterns = [
+        r'\\(?:begin|end)\{[a-zA-Z]*\}',
+        r'\$.*?\$',
+        r'\$\$.*?\$\$',
+        r'\\[a-zA-Z]+',
+        r'\\[^a-zA-Z]',
+    ]
+
+    combined_pattern = '|'.join(latex_patterns)
+    if re.search(combined_pattern, text, re.DOTALL):
+        return True
+
+    return False
+
+
+def fix_fences(text):
+    if text.startswith("$$") and not text.endswith("$$"):
+        if text[-1] == "$":
+            text += "$"
+        else:
+            text += "$$"
+
+    if text.endswith("$$") and not text.startswith("$$"):
+        if text[0] == "$":
+            text = "$" + text
+        else:
+            text = "$$" + text
+
+    if text.startswith("$") and not text.endswith("$"):
+        text = "$" + text + "$$"
+
+    if text.endswith("$") and not text.startswith("$"):
+        text = "$$" + text + "$"
+
+    return text
+
+
+def strip_fences(text):
+    while text.startswith("$"):
+        text = text[1:]
+    while text.endswith("$"):
+        text = text[:-1]
+    return text
+
+
diff --git a/surya/postprocessing/math/render.py b/surya/postprocessing/math/render.py
new file mode 100644
index 00000000..761334a0
--- /dev/null
+++ b/surya/postprocessing/math/render.py
@@ -0,0 +1,88 @@
+from playwright.sync_api import sync_playwright
+from PIL import Image
+import io
+
+
+def latex_to_pil(latex_code, target_width, target_height, fontsize=18):
+    html_template = """
+
+
+
+
+
+
+
+    {content}
+
+
+
+    """
+
+    formatted_latex = latex_code.replace('\n', '\\n').replace('"', '\\"')
+    with sync_playwright() as p:
+        browser = p.chromium.launch()
+        page = browser.new_page()
+        page.set_viewport_size({'width': target_width, 'height': target_height})
+
+        while fontsize <= 30:
+            html_content = html_template.replace("{content}", formatted_latex).replace("{fontsize}", str(fontsize))
+            page.set_content(html_content)
+
+            dimensions = page.evaluate("""() => {
+                const render = document.getElementById('content');
+                return {
+                    width: render.offsetWidth,
+                    height: render.offsetHeight
+                };
+            }""")
+
+            if dimensions['width'] >= target_width or dimensions['height'] >= target_height:
+                fontsize -= 1
+                break
+            else:
+                fontsize += 1
+
+        html_content = html_template.replace("{content}", formatted_latex).replace("{fontsize}", str(fontsize))
+        page.set_content(html_content)
+
+        screenshot_bytes = page.screenshot()
+        browser.close()
+
+    image_stream = io.BytesIO(screenshot_bytes)
+    pil_image = Image.open(image_stream)
+    pil_image.load()
+    return pil_image
\ No newline at end of file
diff --git a/surya/postprocessing/text.py b/surya/postprocessing/text.py
index ab868643..52fe5513 100644
--- a/surya/postprocessing/text.py
+++ b/surya/postprocessing/text.py
@@ -1,11 +1,14 @@
 import os
-from typing import List
+from typing import List, Tuple

 import requests
 from PIL import Image, ImageDraw, ImageFont

+from surya.postprocessing.fonts import get_font_path
 from surya.schema import TextLine
 from surya.settings import settings
+from surya.postprocessing.math.latex import is_latex
+from surya.postprocessing.math.render import latex_to_pil


 def sort_text_lines(lines: List[TextLine], tolerance=1.25):
@@ -65,39 +68,51 @@ def get_text_size(text, font):
     return width, height


-def draw_text_on_image(bboxes, texts, image_size=(1024, 1024), font_path=settings.RECOGNITION_RENDER_FONT, max_font_size=60, res_upscale=2):
+def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size):
+    font = ImageFont.truetype(font_path, box_font_size)
+    text_width, text_height = get_text_size(text, font)
+    while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6:
+        box_font_size = box_font_size - 1
+        font = ImageFont.truetype(font_path, box_font_size)
+        text_width, text_height = get_text_size(text, font)
+
+    # Calculate text position (centered in bbox)
+    text_width, text_height = get_text_size(text, font)
+    x = s_bbox[0]
+    y = s_bbox[1] + (bbox_height - text_height) / 2
+
+    draw.text((x, y), text, fill="black", font=font)
+
+
+def render_math(image, draw, text, s_bbox, bbox_width, bbox_height, font_path):
+    try:
+        box_font_size = max(10, min(int(.2 * bbox_height), 24))
+        img = latex_to_pil(text, bbox_width, bbox_height, fontsize=box_font_size)
+        img.thumbnail((bbox_width, bbox_height))
+        image.paste(img, (s_bbox[0], s_bbox[1]))
+    except Exception as e:
+        print(f"Failed to render math: {e}")
+        box_font_size = max(10, min(int(.75 * bbox_height), 24))
+        render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size)
+
+
+def draw_text_on_image(bboxes, texts, image_size: Tuple[int, int], langs: List[str], font_path=None, max_font_size=60, res_upscale=2, has_math=False):
+    if font_path is None:
+        font_path = get_font_path(langs)
     new_image_size = (image_size[0] * res_upscale, image_size[1] * res_upscale)
     image = Image.new('RGB', new_image_size, color='white')
     draw = ImageDraw.Draw(image)

     for bbox, text in zip(bboxes, texts):
-        s_bbox = [coord * res_upscale for coord in bbox]
+        s_bbox = [int(coord * res_upscale) for coord in bbox]
         bbox_width = s_bbox[2] - s_bbox[0]
         bbox_height = s_bbox[3] - s_bbox[1]

         # Shrink the text to fit in the bbox if needed
-        box_font_size = max(6, min(int(.75 * bbox_height), max_font_size))
-
-        # Download font if it doesn't exist
-        if not os.path.exists(font_path):
-            os.makedirs(os.path.dirname(font_path), exist_ok=True)
-            with requests.get(settings.RECOGNITION_FONT_DL_PATH, stream=True) as r, open(font_path, 'wb') as f:
-                r.raise_for_status()
-                for chunk in r.iter_content(chunk_size=8192):
-                    f.write(chunk)
-
-        font = ImageFont.truetype(font_path, box_font_size)
-        text_width, text_height = get_text_size(text, font)
-        while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6:
-            box_font_size = box_font_size - 1
-            font = ImageFont.truetype(font_path, box_font_size)
-            text_width, text_height = get_text_size(text, font)
-
-        # Calculate text position (centered in bbox)
-        text_width, text_height = get_text_size(text, font)
-        x = s_bbox[0]
-        y = s_bbox[1] + (bbox_height - text_height) / 2
-
-        draw.text((x, y), text, fill="black", font=font)
+        if has_math and is_latex(text):
+            render_math(image, draw, text, s_bbox, bbox_width, bbox_height, font_path)
+        else:
+            box_font_size = max(6, min(int(.75 * bbox_height), max_font_size))
+            render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size)

     return image
diff --git a/surya/recognition.py b/surya/recognition.py
index 113e9612..5cd9dcd3 100644
--- a/surya/recognition.py
+++ b/surya/recognition.py
@@ -1,9 +1,13 @@
 from typing import List

 import torch
 from PIL import Image
+
+from surya.postprocessing.math.latex import fix_math, contains_math
+from surya.postprocessing.text import truncate_repetitions
 from surya.settings import settings
 from tqdm import tqdm
 import numpy as np
+import torch.nn.functional as F


 def get_batch_size():
@@ -25,8 +29,10 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor):
     images = [image.convert("RGB") for image in images]

     output_text = []
+    confidences = []
     for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
         batch_langs = languages[i:i+batch_size]
+        has_math = ["_math" in lang for lang in batch_langs]
         batch_images = images[i:i+batch_size]
         model_inputs = processor(text=[""] * len(batch_langs), images=batch_images, lang=batch_langs)
@@ -39,17 +45,45 @@
         batch_decoder_input = torch.from_numpy(np.array(batch_decoder_input, dtype=np.int64)).to(model.device)

         with torch.inference_mode():
-            generated_ids = model.generate(
+            return_dict = model.generate(
                 pixel_values=batch_pixel_values,
                 decoder_input_ids=batch_decoder_input,
                 decoder_langs=batch_langs,
                 eos_token_id=processor.tokenizer.eos_id,
-                max_new_tokens=settings.RECOGNITION_MAX_TOKENS
+                max_new_tokens=settings.RECOGNITION_MAX_TOKENS,
+                output_scores=True,
+                return_dict_in_generate=True
             )
-
-        output_text.extend(processor.tokenizer.batch_decode(generated_ids))
-
-    return output_text
+            generated_ids = return_dict["sequences"]
+
+        # Find confidence scores
+        scores = return_dict["scores"]  # Scores is a tuple, one per new sequence position. Each tuple element is bs x vocab_size
+        sequence_scores = torch.zeros(generated_ids.shape[0])
+        sequence_lens = torch.where(
+            generated_ids > processor.tokenizer.eos_id,
+            torch.ones_like(generated_ids),
+            torch.zeros_like(generated_ids)
+        ).sum(axis=-1).cpu()
+        prefix_len = generated_ids.shape[1] - len(scores)  # Length of passed in tokens (bos, langs)
+        for token_idx, score in enumerate(scores):
+            probs = F.softmax(score, dim=-1)
+            max_probs = torch.max(probs, dim=-1).values
+            max_probs = torch.where(
+                generated_ids[:, token_idx + prefix_len] <= processor.tokenizer.eos_id,
+                torch.zeros_like(max_probs),
+                max_probs
+            ).cpu()
+            sequence_scores += max_probs
+        sequence_scores /= sequence_lens
+
+        detected_text = processor.tokenizer.batch_decode(generated_ids)
+        detected_text = [truncate_repetitions(dt) for dt in detected_text]
+        # Postprocess to fix LaTeX output (add $$ signs, etc)
+        detected_text = [fix_math(text) if math and contains_math(text) else text for text, math in zip(detected_text, has_math)]
+        output_text.extend(detected_text)
+        confidences.extend(sequence_scores.tolist())
+
+    return output_text, confidences
diff --git a/surya/schema.py b/surya/schema.py
index 35d8c6bb..2261b88b 100644
--- a/surya/schema.py
+++ b/surya/schema.py
@@ -1,5 +1,5 @@
 import copy
-from typing import List, Tuple, Any
+from typing import List, Tuple, Any, Optional

 from pydantic import BaseModel, field_validator, computed_field

@@ -8,6 +8,7 @@ class PolygonBox(BaseModel):
     polygon: List[List[float]]
+    confidence: Optional[float] = None

     @field_validator('polygon')
     @classmethod
@@ -22,11 +23,11 @@ def check_elements(cls, v: List[List[float]]) -> List[List[float]]:

     @property
     def height(self):
-        return self.polygon[1][1] - self.polygon[0][1]
+        return self.bbox[3] - self.bbox[1]

     @property
     def width(self):
-        return self.polygon[1][0] - self.polygon[0][0]
+        return self.bbox[2] - self.bbox[0]

     @property
     def area(self):
@@ -56,6 +57,24 @@ def rescale(self, processor_size, image_size):
             corner[1] = int(corner[1] * height_scaler)
         self.polygon = new_corners

+    def fit_to_bounds(self, bounds):
+        new_corners = copy.deepcopy(self.polygon)
+        for corner in new_corners:
+            corner[0] = max(min(corner[0], bounds[2]), bounds[0])
+            corner[1] = max(min(corner[1], bounds[3]), bounds[1])
+        self.polygon = new_corners
+
+    def intersection_area(self, other):
+        x_overlap = max(0, min(self.bbox[2], other.bbox[2]) - max(self.bbox[0], other.bbox[0]))
+        y_overlap = max(0, min(self.bbox[3], other.bbox[3]) - max(self.bbox[1], other.bbox[1]))
+        return x_overlap * y_overlap
+
+    def intersection_pct(self, other):
+        if self.area == 0:
+            return 0
+
+        intersection = self.intersection_area(other)
+        return intersection / self.area


 class Bbox(BaseModel):
@@ -87,6 +106,10 @@ def area(self):
         return self.width * self.height


+class LayoutBox(PolygonBox):
+    label: str
+
+
 class ColumnLine(Bbox):
     vertical: bool
     horizontal: bool
@@ -94,6 +117,7 @@ class ColumnLine(Bbox):

 class TextLine(PolygonBox):
     text: str
+    confidence: Optional[float] = None


 class OCRResult(BaseModel):
@@ -102,10 +126,16 @@ class OCRResult(BaseModel):
     image_bbox: List[float]


-class DetectionResult(BaseModel):
+class TextDetectionResult(BaseModel):
     bboxes: List[PolygonBox]
     vertical_lines: List[ColumnLine]
     horizontal_lines: List[ColumnLine]
     heatmap: Any
     affinity_map: Any
     image_bbox: List[float]
+
+
+class LayoutResult(BaseModel):
+    bboxes: List[LayoutBox]
+    segmentation_map: Any
+    image_bbox: List[float]
diff --git a/surya/settings.py b/surya/settings.py
index b0dcdf79..7a542424 100644
--- a/surya/settings.py
+++ b/surya/settings.py
@@ -44,20 +44,30 @@ def TORCH_DEVICE_DETECTION(self) -> str:

     # Text detection
     DETECTOR_BATCH_SIZE: Optional[int] = None  # Defaults to 2 for CPU, 32 otherwise
-    DETECTOR_MODEL_CHECKPOINT: str = "vikp/surya_det"
+    DETECTOR_MODEL_CHECKPOINT: str = "vikp/surya_det2"
+    DETECTOR_MATH_MODEL_CHECKPOINT: str = "vikp/surya_det_math"
     DETECTOR_BENCH_DATASET_NAME: str = "vikp/doclaynet_bench"
-    DETECTOR_IMAGE_CHUNK_HEIGHT: int = 1280  # Height at which to slice images vertically
+    DETECTOR_IMAGE_CHUNK_HEIGHT: int = 1400  # Height at which to slice images vertically
     DETECTOR_TEXT_THRESHOLD: float = 0.6  # Threshold for text detection (above this is considered text)
     DETECTOR_BLANK_THRESHOLD: float = 0.35  # Threshold for blank space (below this is considered blank)

     # Text recognition
     RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec"
-    RECOGNITION_MAX_TOKENS: int = 160
+    RECOGNITION_MAX_TOKENS: int = 175
     RECOGNITION_BATCH_SIZE: Optional[int] = None  # Defaults to 8 for CPU/MPS, 256 otherwise
     RECOGNITION_IMAGE_SIZE: Dict = {"height": 196, "width": 896}
-    RECOGNITION_RENDER_FONT: str = os.path.join(FONT_DIR, "GoNotoKurrent-Regular.ttf")
-    RECOGNITION_FONT_DL_PATH: str = "https://github.com/satbyy/go-noto-universal/releases/download/v7.0/GoNotoKurrent-Regular.ttf"
+    RECOGNITION_RENDER_FONTS: Dict[str, str] = {
+        "all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"),
+        "zh": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
+        "ja": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
+        "ko": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
+    }
+    RECOGNITION_FONT_DL_BASE: str = "https://github.com/satbyy/go-noto-universal/releases/download/v7.0"
     RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench"
+    RECOGNITION_PAD_VALUE: int = 0  # Should be 0 or 255
+
+    # Layout
+    LAYOUT_MODEL_CHECKPOINT: str = "vikp/surya_layout"

     # Tesseract (for benchmarks only)
     TESSDATA_PREFIX: Optional[str] = None