Skip to content

Commit

Permalink
New postprocess function label_regions
Browse files Browse the repository at this point in the history
The new function uses the new printspace estimator to classify each
region as either margin or printspace. It is currently implemented as a
volume postprocess method.
  • Loading branch information
viklofg committed Apr 4, 2024
1 parent 44f1482 commit 4d8e28e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/htrflow_core/utils/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,29 @@ def is_twopage(img, strip_width=0.1, threshold=0.2):
if np.min(strip) < np.sort(levels)[int(w * threshold)]:
return middle - half_strip + np.argmin(strip)
return None


class RegionLocation:
    """Labels for where a region lies relative to a page's printspace.

    The members are plain string constants (this class is not an
    `enum.Enum`), so a label compares equal to, and can be stored as,
    an ordinary string.
    """

    # Region is not in any of the four margins
    PRINTSPACE = "printspace"

    # Region is outside the printspace; one label per side of the page
    MARGIN_LEFT = "margin_left"
    MARGIN_RIGHT = "margin_right"
    MARGIN_TOP = "margin_top"
    MARGIN_BOTTOM = "margin_bottom"


def get_region_location(printspace: Bbox, region: Bbox) -> RegionLocation:
    """Get location of `region` relative to `printspace`

    The region is classified by where its center point falls. The side
    margins extend to the top and bottom of the page: if the region is
    located in a corner, it will be assigned to the left or right margin
    and not the top or bottom margin.

    Coordinates are interpreted in image convention, i.e. the origin is
    the top-left corner of the page and y grows downwards. A center
    above `printspace.ymin` is therefore in the top margin and a center
    below `printspace.ymax` is in the bottom margin.

    Arguments:
        printspace: Bounding box of the page's printspace.
        region: Bounding box of the region to classify.

    Returns:
        One of the string constants defined on `RegionLocation`.
    """
    center = region.center

    # Horizontal position is checked first so that corner regions end up
    # in a side margin rather than the top or bottom margin.
    if center.x < printspace.xmin:
        return RegionLocation.MARGIN_LEFT
    if center.x > printspace.xmax:
        return RegionLocation.MARGIN_RIGHT

    # Image coordinates: smaller y is higher up on the page.
    if center.y < printspace.ymin:
        return RegionLocation.MARGIN_TOP
    if center.y > printspace.ymax:
        return RegionLocation.MARGIN_BOTTOM

    return RegionLocation.PRINTSPACE
20 changes: 20 additions & 0 deletions src/htrflow_core/volume/postprocess.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy

from htrflow_core.utils.geometry import estimate_printspace, get_region_location
from htrflow_core.volume import volume


Expand Down Expand Up @@ -43,3 +44,22 @@ def is_noise(node: volume.BaseDocumentNode, threshold: float = 0.8):
conf = sum(child.get("text_result").top_score() for child in node) / len(node.children)
return conf < threshold
return False


def label_regions(volume: volume.Volume, key="region_location"):
    """Label volume's regions

    For every page in the volume, the printspace is estimated from the
    page image. Each node obtained by iterating over the page is then
    labelled with one of the five region types specified by
    geometry.RegionLocation, based on the node's bounding box. The label
    is saved in the node's data dictionary under `key`.

    Arguments:
        volume: Input volume
        key: Key used to save the region label. Defaults to
            "region_location".
    """
    for current_page in volume:
        # The printspace is estimated once per page and shared by all of
        # that page's nodes.
        page_printspace = estimate_printspace(current_page.image)
        for segment in current_page:
            location = get_region_location(page_printspace, segment.bbox)
            # `key` is only known at runtime, so it is passed as an
            # unpacked keyword argument.
            segment.add_data(**{key: location})

0 comments on commit 4d8e28e

Please sign in to comment.