From 4d8e28eebb1043e17830ea80fe6d2b52269d2d58 Mon Sep 17 00:00:00 2001
From: viklofg
Date: Thu, 4 Apr 2024 14:09:13 +0000
Subject: [PATCH] New postprocess function label_regions

The new function uses the new printspace estimator to classify each
region as either margin or printspace. It is currently implemented as a
volume postprocess method.
---
 src/htrflow_core/utils/geometry.py     | 38 ++++++++++++++++++++++++++++++++++++++
 src/htrflow_core/volume/postprocess.py | 20 ++++++++++++++++++++
 2 files changed, 58 insertions(+)

diff --git a/src/htrflow_core/utils/geometry.py b/src/htrflow_core/utils/geometry.py
index 1957f18..8bd20cc 100644
--- a/src/htrflow_core/utils/geometry.py
+++ b/src/htrflow_core/utils/geometry.py
@@ -363,3 +363,41 @@ def is_twopage(img, strip_width=0.1, threshold=0.2):
     if np.min(strip) < np.sort(levels)[int(w * threshold)]:
         return middle - half_strip + np.argmin(strip)
     return None
+
+
+class RegionLocation:
+    """Labels for the location of a region relative to the printspace"""
+
+    PRINTSPACE = "printspace"
+    MARGIN_LEFT = "margin_left"
+    MARGIN_RIGHT = "margin_right"
+    MARGIN_TOP = "margin_top"
+    MARGIN_BOTTOM = "margin_bottom"
+
+
+def get_region_location(printspace: Bbox, region: Bbox) -> RegionLocation:
+    """Get location of `region` relative to `printspace`
+
+    The location is decided by the center point of `region`. The side
+    margins extend to the top and bottom of the page. If the region is
+    located in a corner, it will be assigned to the left or right
+    margin and not the top or bottom margin.
+
+    Arguments:
+        printspace: Bounding box of the printspace
+        region: Bounding box of the region to classify
+
+    Returns:
+        One of the labels defined by `RegionLocation`.
+    """
+    # NOTE: assumes image coordinates (origin in the top-left corner,
+    # y increasing downwards), so a smaller y is higher up on the page.
+    if region.center.x < printspace.xmin:
+        return RegionLocation.MARGIN_LEFT
+    elif region.center.x > printspace.xmax:
+        return RegionLocation.MARGIN_RIGHT
+    elif region.center.y < printspace.ymin:
+        return RegionLocation.MARGIN_TOP
+    elif region.center.y > printspace.ymax:
+        return RegionLocation.MARGIN_BOTTOM
+    return RegionLocation.PRINTSPACE
diff --git a/src/htrflow_core/volume/postprocess.py b/src/htrflow_core/volume/postprocess.py
index 9db3135..845acf9 100644
--- a/src/htrflow_core/volume/postprocess.py
+++ b/src/htrflow_core/volume/postprocess.py
@@ -1,5 +1,6 @@
 from copy import deepcopy
 
+from htrflow_core.utils.geometry import estimate_printspace, get_region_location
 from htrflow_core.volume import volume
 
 
@@ -43,3 +44,22 @@ def is_noise(node: volume.BaseDocumentNode, threshold: float = 0.8):
         conf = sum(child.get("text_result").top_score() for child in node) / len(node.children)
         return conf < threshold
     return False
+
+
+def label_regions(volume: volume.Volume, key: str = "region_location"):
+    """Label volume's regions
+
+    Labels each top-level segment of the volume as one of the five
+    region types specified by geometry.RegionLocation. Saves the label
+    in the node's data dictionary under `key`.
+
+    Arguments:
+        volume: Input volume
+        key: Key used to save the region label. Defaults to
+            "region_location".
+    """
+
+    for page in volume:
+        printspace = estimate_printspace(page.image)
+        for node in page:
+            node.add_data(**{key: get_region_location(printspace, node.bbox)})