-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TLDR-709 TLDR-714 update text and classifier extraction benchmark (#456)
- Loading branch information
Showing
16 changed files
with
1,163 additions
and
1,123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
|
||
Orientation predictions: | ||
+-------+-----------+--------+-------+-------+ | ||
| Class | Precision | Recall | F1 | Count | | ||
+=======+===========+========+=======+=======+ | ||
| 0 | 0.998 | 1 | 0.999 | 537 | | ||
+-------+-----------+--------+-------+-------+ | ||
| 90 | 1 | 0.998 | 0.999 | 537 | | ||
+-------+-----------+--------+-------+-------+ | ||
| 180 | 1 | 0.998 | 0.999 | 537 | | ||
+-------+-----------+--------+-------+-------+ | ||
| 270 | 0.998 | 1 | 0.999 | 537 | | ||
+-------+-----------+--------+-------+-------+ | ||
| AVG | 0.999 | 0.999 | 0.999 | None | | ||
+-------+-----------+--------+-------+-------+ | ||
Column predictions: | ||
+-------+-----------+--------+-------+-------+ | ||
| Class | Precision | Recall | F1 | Count | | ||
+=======+===========+========+=======+=======+ | ||
| 1 | 1 | 0.999 | 0.999 | 1692 | | ||
+-------+-----------+--------+-------+-------+ | ||
| 2 | 0.996 | 1 | 0.998 | 456 | | ||
+-------+-----------+--------+-------+-------+ | ||
| AVG | 0.999 | 0.999 | 0.999 | None | | ||
+-------+-----------+--------+-------+-------+ |
This file was deleted.
Oops, something went wrong.
473 changes: 473 additions & 0 deletions
473
resources/benchmarks/tesseract_benchmark_Correction.SAGE_CORRECTION.txt
Large diffs are not rendered by default.
Oops, something went wrong.
443 changes: 443 additions & 0 deletions
443
resources/benchmarks/tesseract_benchmark_Correction.WITHOUT_CORRECTION.txt
Large diffs are not rendered by default.
Oops, something went wrong.
359 changes: 0 additions & 359 deletions
359
resources/benchmarks/tesseract_benchmark_sage-correction.txt
This file was deleted.
Oops, something went wrong.
318 changes: 0 additions & 318 deletions
318
resources/benchmarks/tesseract_benchmark_textblob-correction.txt
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import os | ||
import re | ||
from typing import List, Tuple | ||
|
||
from texttable import Texttable | ||
|
||
|
||
def __parse_ocr_errors(lines: List[str]) -> List: | ||
ocr_errors = [] | ||
matched_errors = [(line_num, line) for line_num, line in enumerate(lines) if "Errors Marked Correct-Generated" in line][0] | ||
for line in lines[matched_errors[0] + 1:]: | ||
# example line: " 2 0 { 6}-{б}" | ||
errors = re.findall(r"(\d+)", line)[0] | ||
chars = re.findall(r"{(.*)}-{(.*)}", line)[0] | ||
ocr_errors.append([errors, chars[0], chars[1]]) | ||
|
||
return ocr_errors | ||
|
||
|
||
def __parse_symbol_info(lines: List[str]) -> Tuple[List, int]: | ||
symbols_info = [] | ||
matched_symbols = [(line_num, line) for line_num, line in enumerate(lines) if "Count Missed %Right" in line][-1] | ||
start_block_line = matched_symbols[0] | ||
|
||
for line in lines[start_block_line + 1:]: | ||
# example line: "1187 11 99.07 {<\n>}" | ||
row_values = [value.strip() for value in re.findall(r"\d+.\d*|{\S+|\W+}", line)] | ||
row_values[-1] = row_values[-1][1:-1] # get symbol value | ||
symbols_info.append(row_values) | ||
# Sort errors | ||
symbols_info = sorted(symbols_info, key=lambda row: int(row[1]), reverse=True) # by missed | ||
|
||
return symbols_info, start_block_line | ||
|
||
|
||
def get_summary_symbol_error(path_reports: str) -> Texttable: | ||
# 1 - call accsum for get summary of all reports | ||
accuracy_script_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "accsum")) | ||
|
||
if os.path.exists(f"{path_reports}/../accsum_report.txt"): | ||
os.remove(f"{path_reports}/../accsum_report.txt") | ||
|
||
file_reports = " ".join([os.path.join(path_reports, f) for f in os.listdir(path_reports) if os.path.isfile(os.path.join(path_reports, f))]) | ||
|
||
command = f"{accuracy_script_path} {file_reports} >> {path_reports}/../accsum_report.txt" | ||
os.system(command) | ||
accsum_report_path = os.path.join(path_reports, "..", "accsum_report.txt") | ||
|
||
# 2 - parse report info | ||
with open(accsum_report_path, "r") as f: | ||
lines = f.readlines() | ||
|
||
symbols_info, start_symbol_block_line = __parse_symbol_info(lines) | ||
ocr_errors = __parse_ocr_errors(lines[:start_symbol_block_line - 1]) | ||
|
||
# 3 - calculate ocr errors for a symbol | ||
ocr_errors_by_symbol = {} | ||
for symbol_info in symbols_info: | ||
ocr_errors_by_symbol[symbol_info[-1]] = [] | ||
for ocr_err in ocr_errors: | ||
if ocr_err[-1] == "" or len(ocr_err[-2]) > 3 or len(ocr_err[-1]) > 3: # to ignore errors with long text (len > 3) or without text | ||
continue | ||
if symbol_info[-1] in ocr_err[-2]: | ||
ocr_errors_by_symbol[symbol_info[-1]].append(f"{ocr_err[0]} & <{ocr_err[1]}> -> <{ocr_err[2]}>") | ||
|
||
# 4 - create table with OCR errors | ||
ocr_err_by_symbol_table = Texttable() | ||
title = [["Symbol", "Cnt Errors & Correct-Generated"]] | ||
ocr_err_by_symbol_table.add_rows(title) | ||
for symbol, value in ocr_errors_by_symbol.items(): | ||
if len(value) != 0: | ||
ocr_err_by_symbol_table.add_row([symbol, value]) | ||
|
||
return ocr_err_by_symbol_table |
41 changes: 41 additions & 0 deletions
41
scripts/text_extraction_benchmark/text_correction/sage_corrector.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import os | ||
|
||
import torch | ||
from sage.spelling_correction import AvailableCorrectors | ||
from sage.spelling_correction import RuM2M100ModelForSpellingCorrection | ||
|
||
|
||
""" | ||
Install sage library (for ocr correction step): | ||
git clone https://github.com/ai-forever/sage.git | ||
cd sage | ||
pip install . | ||
pip install -r requirements.txt | ||
Note: sage use 5.2 Gb GPU ...... | ||
""" | ||
|
||
|
||
class SageCorrector: | ||
|
||
def __init__(self, cache_dir: str, use_gpu: bool = True) -> None: | ||
self.corrected_path = os.path.join(cache_dir, "result_corrected") | ||
os.makedirs(self.corrected_path, exist_ok=True) | ||
|
||
self.corrector = RuM2M100ModelForSpellingCorrection.from_pretrained(AvailableCorrectors.m2m100_1B.value) # 4.49 Gb model (pytorch_model.bin) | ||
self._init_device(use_gpu) | ||
|
||
def _init_device(self, use_gpu: bool) -> None: | ||
if torch.cuda.is_available() and use_gpu: | ||
self.corrector.model.to(torch.device("cuda:0")) | ||
print("use CUDA") | ||
else: | ||
print("use CPU") | ||
|
||
def correction(self, text: str) -> str: | ||
corrected_lines = [] | ||
for line in text.split("\n"): | ||
corrected_lines.append(self.corrector.correct(line)[0]) | ||
corrected_text = "\n".join(corrected_lines) | ||
|
||
return corrected_text |
Oops, something went wrong.