diff --git a/detect_text.py b/detect_text.py index 98f803e8..3821b6dd 100644 --- a/detect_text.py +++ b/detect_text.py @@ -62,7 +62,7 @@ def main(): predictions_by_page[name].append(pred) with open(os.path.join(result_path, "results.json"), "w+") as f: - json.dump(predictions_by_page, f) + json.dump(predictions_by_page, f, ensure_ascii=False) print(f"Wrote results to {result_path}") diff --git a/ocr_text.py b/ocr_text.py index edd576a0..cdfa54e7 100644 --- a/ocr_text.py +++ b/ocr_text.py @@ -70,7 +70,7 @@ def main(): page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png")) with open(os.path.join(result_path, "results.json"), "w+") as f: - json.dump(predictions_by_image, f) + json.dump(predictions_by_image, f, ensure_ascii=False) print(f"Wrote results to {result_path}") diff --git a/surya/ocr.py b/surya/ocr.py index 3c04f04d..f4625ca7 100644 --- a/surya/ocr.py +++ b/surya/ocr.py @@ -7,6 +7,7 @@ from surya.detection import batch_detection from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image +from surya.postprocessing.text import truncate_repetitions from surya.recognition import batch_recognition @@ -73,6 +74,9 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr slice_start = slice_end assert len(image_lines) == len(det_pred["polygons"]) == len(det_pred["bboxes"]) + + # Remove repeated characters + image_lines = [truncate_repetitions(l) for l in image_lines] predictions_by_image.append({ "text_lines": image_lines, "polys": det_pred["polygons"], diff --git a/surya/postprocessing/text.py b/surya/postprocessing/text.py index 0c0b0924..ea71d873 100644 --- a/surya/postprocessing/text.py +++ b/surya/postprocessing/text.py @@ -5,6 +5,37 @@ from surya.settings import settings +def truncate_repetitions(text: str, min_len=15): + # From nougat, with some cleanup + if len(text) < 2 * min_len: + return text + + # try to find a length at which the tail is repeating + max_rep_len = None + for rep_len in range(min_len, int(len(text) / 2)): + # check if there is a repetition at the end + same = True + for i in range(0, rep_len): + if text[len(text) - rep_len - i - 1] != text[len(text) - i - 1]: + same = False + break + + if same: + max_rep_len = rep_len + + if max_rep_len is None: + return text + + lcs = text[-max_rep_len:] + + # remove all but the last repetition + text_to_truncate = text + while text_to_truncate.endswith(lcs): + text_to_truncate = text_to_truncate[:-max_rep_len] + + return text[:len(text_to_truncate)] + + def get_text_size(text, font): im = Image.new(mode="P", size=(0, 0)) draw = ImageDraw.Draw(im)