Skip to content

Commit

Permalink
Refactor to use pydantic, add in sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Feb 16, 2024
1 parent f730955 commit e5c2946
Show file tree
Hide file tree
Showing 11 changed files with 194 additions and 115 deletions.
4 changes: 2 additions & 2 deletions benchmark/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def main():

page_metrics = collections.OrderedDict()
for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)):
surya_boxes = sb["bboxes"]
surya_polys = sb["polygons"]
surya_boxes = [s.bbox for s in sb.bboxes]
surya_polys = [s.polygon for s in sb.bboxes]

surya_metrics = precision_recall(surya_boxes, cb)
tess_metrics = precision_recall(tb, cb)
Expand Down
6 changes: 4 additions & 2 deletions benchmark/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def main():
surya_scores = defaultdict(list)
img_surya_scores = []
for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, lang_list)):
image_score = overlap_score(pred["text_lines"], ref_text)
pred_text = [l.text for l in pred.text_lines]
image_score = overlap_score(pred_text, ref_text)
img_surya_scores.append(image_score)
for l in lang:
surya_scores[CODE_TO_LANGUAGE[l]].append(image_score)
Expand Down Expand Up @@ -146,7 +147,8 @@ def main():
for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)):
pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
pred_image = draw_text_on_image(bbox, pred["text_lines"], image.size)
pred_text = [l.text for l in pred.text_lines]
pred_image = draw_text_on_image(bbox, pred_text, image.size)
pred_image.save(os.path.join(result_path, pred_image_name))
ref_image = draw_text_on_image(bbox, ref_text, image.size)
ref_image.save(os.path.join(result_path, ref_image_name))
Expand Down
16 changes: 7 additions & 9 deletions detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ def main():

if args.images:
for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
bbox_image = draw_polys_on_image(pred["polygons"], copy.deepcopy(image))
polygons = [p.polygon for p in pred.bboxes]
bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image))
bbox_image.save(os.path.join(result_path, f"{name}_{idx}_bbox.png"))

column_image = draw_lines_on_image(pred["vertical_lines"], copy.deepcopy(image))
vertical_lines = [l.bbox for l in pred.vertical_lines]
column_image = draw_lines_on_image(vertical_lines, copy.deepcopy(image))
column_image.save(os.path.join(result_path, f"{name}_{idx}_column.png"))

if args.debug:
Expand All @@ -51,15 +53,11 @@ def main():
affinity_map = pred["affinity_map"]
affinity_map.save(os.path.join(result_path, f"{name}_{idx}_affinity.png"))

# Remove all the images from the predictions
for pred in predictions:
pred.pop("heatmap", None)
pred.pop("affinity_map", None)

predictions_by_page = defaultdict(list)
for idx, (pred, name) in enumerate(zip(predictions, names)):
pred["page_number"] = len(predictions_by_page[name]) + 1
predictions_by_page[name].append(pred)
out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"])
out_pred["page_number"] = len(predictions_by_page[name]) + 1
predictions_by_page[name].append(out_pred)

with open(os.path.join(result_path, "results.json"), "w+") as f:
json.dump(predictions_by_page, f, ensure_ascii=False)
Expand Down
31 changes: 18 additions & 13 deletions ocr_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from PIL import Image
from surya.languages import CODE_TO_LANGUAGE
from surya.input.langs import replace_lang_with_code
from surya.schema import OCRResult, DetectionResult


@st.cache_resource()
Expand All @@ -24,18 +25,22 @@ def load_rec_cached():
return load_rec_model(), load_rec_processor()


def text_detection(img):
preds = batch_detection([img], det_model, det_processor)[0]
det_img = draw_polys_on_image(preds["polygons"], img.copy())
return det_img, preds
def text_detection(img) -> "tuple[Image.Image, DetectionResult]":
    """Run text detection on a single PIL image.

    Returns a ``(det_img, pred)`` tuple: a copy of ``img`` with the detected
    text polygons drawn on it, and the raw ``DetectionResult`` prediction.
    (The previous annotation claimed ``DetectionResult`` alone, but the
    function has always returned a 2-tuple — callers unpack two values.)
    """
    # batch_detection expects a list of images; we pass one and take its result.
    pred = batch_detection([img], det_model, det_processor)[0]
    polygons = [p.polygon for p in pred.bboxes]
    det_img = draw_polys_on_image(polygons, img.copy())
    return det_img, pred


# Function for OCR
def ocr(img, langs):
def ocr(img, langs) -> "tuple[Image.Image, OCRResult]":
    """Run full OCR (detection + recognition) on a single PIL image.

    ``langs`` is a list of language names/codes; it is mutated in place by
    ``replace_lang_with_code`` to normalized language codes before inference.

    Returns a ``(rec_img, img_pred)`` tuple: an image with the recognized
    text rendered at the detected line positions, and the ``OCRResult``
    prediction. (The previous annotation claimed ``OCRResult`` alone, but
    the function returns a 2-tuple — callers unpack two values.)
    """
    replace_lang_with_code(langs)
    # run_ocr is batched; wrap the single image/lang in lists and unwrap.
    img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor)[0]

    bboxes = [l.bbox for l in img_pred.text_lines]
    text = [l.text for l in img_pred.text_lines]
    rec_img = draw_text_on_image(bboxes, text, img.size)
    return rec_img, img_pred


def open_pdf(pdf_file):
Expand Down Expand Up @@ -104,21 +109,21 @@ def page_count(pdf_file):

# Run Text Detection
if text_det and pil_image is not None:
det_img, preds = text_detection(pil_image)
det_img, pred = text_detection(pil_image)
with col1:
st.image(det_img, caption="Detected Text", use_column_width=True)
st.json(preds, expanded=True)
st.json(pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True)

# Run OCR
if text_rec and pil_image is not None:
rec_img, pred = ocr(pil_image, languages)
with col1:
st.image(rec_img, caption="OCR Result", use_column_width=True)
json_tab, text_tab = st.tabs(["JSON", "Full Text"])
json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"])
with json_tab:
st.json(pred, expanded=True)
st.json(pred.model_dump(), expanded=True)
with text_tab:
st.text("\n".join(pred["text_lines"]))
st.text("\n".join([p.text for p in pred.text_lines]))

with col2:
st.image(pil_image, caption="Uploaded Image", use_column_width=True)
21 changes: 13 additions & 8 deletions ocr_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,24 @@ def main():

predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)

page_num = defaultdict(int)
for i, pred in enumerate(predictions_by_image):
pred["name"] = names[i]
pred["page"] = page_num[names[i]]
page_num[names[i]] += 1

if args.images:
for idx, (name, image, pred) in enumerate(zip(names, images, predictions_by_image)):
page_image = draw_text_on_image(pred["bboxes"], pred["text_lines"], image.size)
bboxes = [l.bbox for l in pred.text_lines]
pred_text = [l.text for l in pred.text_lines]
page_image = draw_text_on_image(bboxes, pred_text, image.size)
page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png"))

out_preds = []
page_num = defaultdict(int)
for i, pred in enumerate(predictions_by_image):
out_pred = pred.model_dump()
out_pred["name"] = names[i]
out_pred["page"] = page_num[names[i]]
page_num[names[i]] += 1
out_preds.append(out_pred)

with open(os.path.join(result_path, "results.json"), "w+") as f:
json.dump(predictions_by_image, f, ensure_ascii=False)
json.dump(out_preds, f, ensure_ascii=False)

print(f"Wrote results to {result_path}")

Expand Down
22 changes: 12 additions & 10 deletions surya/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.postprocessing.affinity import get_vertical_lines, get_horizontal_lines
from surya.input.processing import prepare_image, split_image
from surya.schema import DetectionResult
from surya.settings import settings
from tqdm import tqdm

Expand All @@ -20,7 +21,7 @@ def get_batch_size():
return batch_size


def batch_detection(images: List, model, processor):
def batch_detection(images: List, model, processor) -> List[DetectionResult]:
assert all([isinstance(image, Image.Image) for image in images])
batch_size = get_batch_size()

Expand Down Expand Up @@ -94,18 +95,19 @@ def batch_detection(images: List, model, processor):
affinity_size = list(reversed(affinity_map.shape))
heatmap_size = list(reversed(heatmap.shape))
bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes[i])
bbox_data = [bbox.model_dump() for bbox in bboxes]
vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes[i])
horizontal_lines = get_horizontal_lines(affinity_map, affinity_size, orig_sizes[i])

results.append({
"bboxes": [bbd["bbox"] for bbd in bbox_data],
"polygons": [bbd["corners"] for bbd in bbox_data],
"vertical_lines": vertical_lines,
"horizontal_lines": horizontal_lines,
"heatmap": heat_img,
"affinity_map": aff_img,
})
result = DetectionResult(
bboxes=bboxes,
vertical_lines=vertical_lines,
horizontal_lines=horizontal_lines,
heatmap=heat_img,
affinity_map=aff_img

)

results.append(result)

return results

Expand Down
59 changes: 38 additions & 21 deletions surya/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

from surya.detection import batch_detection
from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image
from surya.postprocessing.text import truncate_repetitions
from surya.postprocessing.text import truncate_repetitions, sort_text_lines
from surya.recognition import batch_recognition
from surya.schema import TextLine, OCRResult


def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None):
def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None) -> List[OCRResult]:
# Polygons need to be in corner format - [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], bboxes in [x1, y1, x2, y2] format
assert bboxes is not None or polygons is not None
slice_map = []
Expand All @@ -35,22 +36,29 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model
image_lines = rec_predictions[slice_start:slice_end]
slice_start = slice_end

pred = {
"text_lines": image_lines,
"language": lang
}

if polygons is not None:
pred["polys"] = polygons[idx]
else:
pred["bboxes"] = bboxes[idx]

text_lines = []
for i in range(len(image_lines)):
if polygons is not None:
poly = polygons[idx][i]
else:
bbox = bboxes[idx][i]
poly = [[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]]

text_lines.append(TextLine(
text=image_lines[i],
polygon=poly
))

pred = OCRResult(
text_lines=text_lines,
languages=lang
)
predictions_by_image.append(pred)

return predictions_by_image


def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor):
def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor) -> List[OCRResult]:
det_predictions = batch_detection(images, det_model, det_processor)
if det_model.device == "cuda":
torch.cuda.empty_cache() # Empty cache from first model run
Expand All @@ -59,7 +67,8 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
all_slices = []
all_langs = []
for idx, (image, det_pred, lang) in enumerate(zip(images, det_predictions, langs)):
slices = slice_polys_from_image(image, det_pred["polygons"])
polygons = [p.polygon for p in det_pred.bboxes]
slices = slice_polys_from_image(image, polygons)
slice_map.append(len(slices))
all_slices.extend(slices)
all_langs.extend([lang] * len(slices))
Expand All @@ -73,15 +82,23 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
image_lines = rec_predictions[slice_start:slice_end]
slice_start = slice_end

assert len(image_lines) == len(det_pred["polygons"]) == len(det_pred["bboxes"])
assert len(image_lines) == len(det_pred.bboxes)

# Remove repeated characters
image_lines = [truncate_repetitions(l) for l in image_lines]
predictions_by_image.append({
"text_lines": image_lines,
"polys": det_pred["polygons"],
"bboxes": det_pred["bboxes"],
"language": lang
})
lines = []
for text_line, bbox in zip(image_lines, det_pred.bboxes):
lines.append(TextLine(
text=text_line,
polygon=bbox.polygon,
bbox=bbox.bbox
))

lines = sort_text_lines(lines)

predictions_by_image.append(OCRResult(
text_lines=lines,
languages=lang
))

return predictions_by_image
Loading

0 comments on commit e5c2946

Please sign in to comment.