Skip to content

Commit

Permalink
Add faster precision calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Mar 25, 2024
1 parent daf4e6d commit 2751dcf
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 6 deletions.
2 changes: 1 addition & 1 deletion benchmark/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def main():

print(tabulate(table_data, headers=table_headers, tablefmt="github"))
print(f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total.")
print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. There is a precision penalty for multiple boxes overlapping reference lines.")
print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold.")
print(f"Wrote results to {result_path}")


Expand Down
21 changes: 19 additions & 2 deletions surya/benchmark/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ def calculate_coverage(box, other_boxes, penalize_double=False):
return covered_pixels_count / box_area


def calculate_coverage_fast(box, other_boxes, penalize_double=False):
box_area = (box[2] - box[0]) * (box[3] - box[1])
if box_area == 0:
return 0

total_intersect = 0
for other_box in other_boxes:
total_intersect += intersection_area(box, other_box)

return min(1, total_intersect / box_area)


def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True):
if len(references) == 0:
return {
Expand All @@ -68,10 +80,15 @@ def precision_recall(preds, references, threshold=.5, workers=8, penalize_double
"recall": 0,
}

# If we're not penalizing double coverage, we can use a faster calculation
coverage_func = calculate_coverage_fast
if penalize_double:
coverage_func = calculate_coverage

with ProcessPoolExecutor(max_workers=workers) as executor:
precision_func = partial(calculate_coverage, penalize_double=penalize_double)
precision_func = partial(coverage_func, penalize_double=penalize_double)
precision_iou = executor.map(precision_func, preds, repeat(references))
reference_iou = executor.map(calculate_coverage, references, repeat(preds))
reference_iou = executor.map(coverage_func, references, repeat(preds))

precision_classes = [1 if i > threshold else 0 for i in precision_iou]
precision = sum(precision_classes) / len(precision_classes)
Expand Down
4 changes: 2 additions & 2 deletions surya/postprocessing/fonts.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import List
from typing import List, Optional
import os
import requests

from surya.settings import settings


def get_font_path(langs: List[str] | None = None) -> str:
def get_font_path(langs: Optional[List[str]] = None) -> str:
font_path = settings.RECOGNITION_RENDER_FONTS["all"]
if langs is not None:
for k in settings.RECOGNITION_RENDER_FONTS:
Expand Down
2 changes: 1 addition & 1 deletion surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def TORCH_DEVICE_DETECTION(self) -> str:
DETECTOR_IMAGE_CHUNK_HEIGHT: int = 1400 # Height at which to slice images vertically
DETECTOR_TEXT_THRESHOLD: float = 0.6 # Threshold for text detection (above this is considered text)
DETECTOR_BLANK_THRESHOLD: float = 0.35 # Threshold for blank space (below this is considered blank)
DETECTOR_POSTPROCESSING_CPU_WORKERS: int = 8 # Number of workers for postprocessing
DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(8, os.cpu_count()) # Number of workers for postprocessing

# Text recognition
RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec"
Expand Down

0 comments on commit 2751dcf

Please sign in to comment.