Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alto confidence refactor #45

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 113 additions & 116 deletions pero_ocr/document_ocr/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,94 @@ def to_pagexml(self, file_name, version=PAGEVersion.PAGE_2019_07_15):
with open(file_name, 'w', encoding='utf-8') as out_f:
out_f.write(xml_string)

def alto_get_visual_span(self, line, logprob_len, crop_engine, aligned_word):
line_coords = crop_engine.get_crop_inputs(line.baseline, line.heights, 16)
lm_const = line_coords.shape[1] / logprob_len
extension = 2

while line_coords.size > 0 and extension < 40:
all_x = line_coords[:, max(0, int((aligned_word[0]-extension) * lm_const)):int((aligned_word[1]+extension) * lm_const), 0]
all_y = line_coords[:, max(0, int((aligned_word[0]-extension) * lm_const)):int((aligned_word[1]+extension) * lm_const), 1]

if all_x.size == 0 or all_y.size == 0:
extension += 1
else:
break

if line_coords.size == 0 or all_x.size == 0 or all_y.size == 0:
all_x = line.baseline[:, 0]
all_y = np.concatenate([line.baseline[:, 1] - line.heights[0], line.baseline[:, 1] + line.heights[1]])

return np.min(all_x), np.max(all_x), np.min(all_y), np.max(all_y)

def alto_get_word_confidence(self, confidences, line, word_start, word_len):
word_confidence = None
if line.transcription_confidence == 1:
word_confidence = 1
else:
if confidences.size != 0:
word_confidence = np.quantile(confidences[word_start:word_start + word_len], .50)

return word_confidence

def alto_create_word_child(self, parent, word, confidence, x_min, x_max, y_min, y_max, is_arabic, arabic_helper):
string = ET.SubElement(parent, "String")

if is_arabic:
string.set("CONTENT", arabic_helper.label_form_to_string(word))
else:
string.set("CONTENT", word)

string.set("HEIGHT", str(int((y_max - y_min))))
string.set("WIDTH", str(int((x_max - x_min))))
string.set("VPOS", str(int(y_min)))
string.set("HPOS", str(int(x_min)))

if confidence is not None:
string.set("WC", str(round(confidence, 2)))

def alto_create_space_child(self, parent, x_max, y_min):
space = ET.SubElement(parent, "SP")

space.set("WIDTH", str(4))
space.set("VPOS", str(int(y_min)))
space.set("HPOS", str(int(x_max)))

def alto_get_numeric_labels(self, line):
blank_idx = line.logits.shape[1] - 1
label = []
char_to_num = {c: i for i, c in enumerate(line.characters)}
for item in line.transcription:
if item in char_to_num.keys():
if char_to_num[item] >= blank_idx:
label.append(0)
else:
label.append(char_to_num[item])
else:
label.append(0)

return label

def alto_get_aligned_words(self, line, aligned_letters):
words = []
space_idxs = [pos for pos, char in enumerate(line.transcription) if char == ' ']
space_idxs = [-1] + space_idxs + [len(aligned_letters)]
for i in range(len(space_idxs[1:])):
if space_idxs[i] != space_idxs[i+1]-1:
words.append([aligned_letters[space_idxs[i]+1], aligned_letters[space_idxs[i+1]-1]])

return words

def alto_set_hwvh(self, elem, height, width, v_pos, h_pos):
elem.set("HEIGHT", str(int(height)))
elem.set("WIDTH", str(int(width)))
elem.set("VPOS", str(int(v_pos)))
elem.set("HPOS", str(int(h_pos)))

def to_altoxml_string(self, ocr_processing=None, page_uuid=None, min_line_confidence=0):
arabic_helper = ArabicHelper()
crop_engine = EngineLineCropper(poly=2)

NSMAP = {"xlink": 'http://www.w3.org/1999/xlink',
"xsi": 'http://www.w3.org/2001/XMLSchema-instance'}
root = ET.Element("alto", nsmap=NSMAP)
Expand All @@ -297,6 +383,7 @@ def to_altoxml_string(self, ocr_processing=None, page_uuid=None, min_line_confid
else:
ocr_processing = create_ocr_processing_element()
description.append(ocr_processing)

layout = ET.SubElement(root, "Layout")
page = ET.SubElement(layout, "Page")
if page_uuid is not None:
Expand All @@ -318,15 +405,12 @@ def to_altoxml_string(self, ocr_processing=None, page_uuid=None, min_line_confid
print_space_vpos = self.page_size[0]
print_space_hpos = self.page_size[1]

for b, block in enumerate(self.regions):
for block in self.regions:
text_block = ET.SubElement(print_space, "TextBlock")
text_block.set("ID", 'block_{}' .format(block.id))

text_block_height, text_block_width, text_block_vpos, text_block_hpos = get_hwvh(block.polygon)
text_block.set("HEIGHT", str(int(text_block_height)))
text_block.set("WIDTH", str(int(text_block_width)))
text_block.set("VPOS", str(int(text_block_vpos)))
text_block.set("HPOS", str(int(text_block_hpos)))
self.alto_set_hwvh(text_block, text_block_height, text_block_width, text_block_vpos, text_block_hpos)

print_space_height = max([print_space_vpos + print_space_height, text_block_vpos + text_block_height])
print_space_width = max([print_space_hpos + print_space_width, text_block_hpos + text_block_width])
Expand All @@ -335,141 +419,54 @@ def to_altoxml_string(self, ocr_processing=None, page_uuid=None, min_line_confid
print_space_height = print_space_height - print_space_vpos
print_space_width = print_space_width - print_space_hpos

for l, line in enumerate(block.lines):
for line in block.lines:
if not line.transcription:
continue
arabic_line = False
if arabic_helper.is_arabic_line(line.transcription):
arabic_line = True

arabic_line = arabic_helper.is_arabic_line(line.transcription)

text_line = ET.SubElement(text_block, "TextLine")
text_line_baseline = int(np.average(np.array(line.baseline)[:, 1]))
text_line.set("BASELINE", str(text_line_baseline))

text_line_height, text_line_width, text_line_vpos, text_line_hpos = get_hwvh(line.polygon)

text_line.set("VPOS", str(int(text_line_vpos)))
text_line.set("HPOS", str(int(text_line_hpos)))
text_line.set("HEIGHT", str(int(text_line_height)))
text_line.set("WIDTH", str(int(text_line_width)))
self.alto_set_hwvh(text_line, text_line_height, text_line_width, text_line_vpos, text_line_hpos)

try:
chars = [i for i in range(len(line.characters))]
char_to_num = dict(zip(line.characters, chars))

blank_idx = line.logits.shape[1] - 1

label = []
for item in line.transcription:
if item in char_to_num.keys():
if char_to_num[item] >= blank_idx:
label.append(0)
else:
label.append(char_to_num[item])
else:
label.append(0)

logits = line.get_dense_logits()[line.logit_coords[0]:line.logit_coords[1]]
label = self.alto_get_numeric_labels(line)
logprobs = line.get_full_logprobs()[line.logit_coords[0]:line.logit_coords[1]]
aligned_letters = align_text(-logprobs, np.array(label), blank_idx)
except (ValueError, IndexError, TypeError) as e:
print(f'Error: Alto export, unable to align line {line.id} due to exception {e}.')
line.transcription_confidence = 0
average_word_width = (text_line_hpos + text_line_width) / len(line.transcription.split())
for w, word in enumerate(line.transcription.split()):
line_transcription_confidence = 0
average_word_width = (text_line_hpos + text_line_width) / len(line.transcription.split()) # TODO: should be difference??
for w_id, word in enumerate(line.transcription.split()):
string = ET.SubElement(text_line, "String")
string.set("CONTENT", word)

string.set("HEIGHT", str(int(text_line_height)))
string.set("WIDTH", str(int(average_word_width)))
string.set("VPOS", str(int(text_line_vpos)))
string.set("HPOS", str(int(text_line_hpos + (w * average_word_width))))
self.alto_set_hwvh(string, text_line_height, average_word_width, text_line_vpos, text_line_hpos + (w_id*average_word_width))
else:
crop_engine = EngineLineCropper(poly=2)
line_coords = crop_engine.get_crop_inputs(line.baseline, line.heights, 16)
space_idxs = [pos for pos, char in enumerate(line.transcription) if char == ' ']

words = []
space_idxs = [-1] + space_idxs + [len(aligned_letters)]
for i in range(len(space_idxs[1:])):
if space_idxs[i] != space_idxs[i+1]-1:
words.append([aligned_letters[space_idxs[i]+1], aligned_letters[space_idxs[i+1]-1]])
words = self.alto_get_aligned_words(line, aligned_letters)
splitted_transcription = line.transcription.split()
lm_const = line_coords.shape[1] / logits.shape[0]
letter_counter = 0
confidences = get_line_confidence(line, np.array(label), aligned_letters, logprobs)
#if line.transcription_confidence is None:
line.transcription_confidence = np.quantile(confidences, .50)
for w, word in enumerate(words):
extension = 2
while line_coords.size > 0 and extension < 40:
all_x = line_coords[:, max(0, int((words[w][0]-extension) * lm_const)):int((words[w][1]+extension) * lm_const), 0]
all_y = line_coords[:, max(0, int((words[w][0]-extension) * lm_const)):int((words[w][1]+extension) * lm_const), 1]

if all_x.size == 0 or all_y.size == 0:
extension += 1
else:
break
line_transcription_confidence = np.quantile(confidences, .50)

if line_coords.size == 0 or all_x.size == 0 or all_y.size == 0:
all_x = line.baseline[:, 0]
all_y = np.concatenate([line.baseline[:, 1] - line.heights[0], line.baseline[:, 1] + line.heights[1]])
for w_id, (aligned_word, text_word) in enumerate(zip(words, splitted_transcription)):
x_min, x_max, y_min, y_max = self.alto_get_visual_span(line, logprobs.shape[0], crop_engine, aligned_word)
word_confidence = self.alto_get_word_confidence(confidences, line, letter_counter, len(text_word))
self.alto_create_word_child(text_line, text_word, word_confidence, x_min, x_max, y_min, y_max, arabic_line, arabic_helper)

word_confidence = None
if line.transcription_confidence == 1:
word_confidence = 1
else:
if confidences.size != 0:
word_confidence = np.quantile(confidences[letter_counter:letter_counter+len(splitted_transcription[w])], .50)
if w_id != len(line.transcription.split()) - 1:
self.alto_create_space_child(text_line, x_max, y_min)

string = ET.SubElement(text_line, "String")
letter_counter += len(text_word) + 1

if arabic_line:
string.set("CONTENT", arabic_helper.label_form_to_string(splitted_transcription[w]))
else:
string.set("CONTENT", splitted_transcription[w])

string.set("HEIGHT", str(int((np.max(all_y) - np.min(all_y)))))
string.set("WIDTH", str(int((np.max(all_x) - np.min(all_x)))))
string.set("VPOS", str(int(np.min(all_y))))
string.set("HPOS", str(int(np.min(all_x))))

if word_confidence is not None:
string.set("WC", str(round(word_confidence, 2)))

if w != (len(line.transcription.split())-1):
space = ET.SubElement(text_line, "SP")

space.set("WIDTH", str(4))
space.set("VPOS", str(int(np.min(all_y))))
space.set("HPOS", str(int(np.max(all_x))))
letter_counter += len(splitted_transcription[w])+1
if line.transcription_confidence is not None:
if line.transcription_confidence < min_line_confidence:
text_block.remove(text_line)
top_margin.set("HEIGHT", "{}" .format(int(print_space_vpos)))
top_margin.set("WIDTH", "{}" .format(int(self.page_size[1])))
top_margin.set("VPOS", "0")
top_margin.set("HPOS", "0")

left_margin.set("HEIGHT", "{}" .format(int(self.page_size[0])))
left_margin.set("WIDTH", "{}" .format(int(print_space_hpos)))
left_margin.set("VPOS", "0")
left_margin.set("HPOS", "0")

right_margin.set("HEIGHT", "{}" .format(int(self.page_size[0])))
right_margin.set("WIDTH", "{}" .format(int(self.page_size[1] - (print_space_hpos + print_space_width))))
right_margin.set("VPOS", "0")
right_margin.set("HPOS", "{}" .format(int(print_space_hpos + print_space_width)))

bottom_margin.set("HEIGHT", "{}" .format(int(self.page_size[0] - (print_space_vpos + print_space_height))))
bottom_margin.set("WIDTH", "{}" .format(int(self.page_size[1])))
bottom_margin.set("VPOS", "{}" .format(int(print_space_vpos + print_space_height)))
bottom_margin.set("HPOS", "0")

print_space.set("HEIGHT", str(int(print_space_height)))
print_space.set("WIDTH", str(int(print_space_width)))
print_space.set("VPOS", str(int(print_space_vpos)))
print_space.set("HPOS", str(int(print_space_hpos)))
self.alto_set_hwvh(top_margin, print_space_vpos, self.page_size[1], 0, 0)
self.alto_set_hwvh(left_margin, self.page_size[0], print_space_hpos, 0, 0)
self.alto_set_hwvh(right_margin, self.page_size[0], self.page_size[1] - (print_space_hpos + print_space_width), 0, print_space_hpos + print_space_width)
self.alto_set_hwvh(bottom_margin, self.page_size[0] - (print_space_vpos + print_space_height), self.page_size[1], print_space_vpos + print_space_height, 0)
self.alto_set_hwvh(print_space, print_space_height, print_space_width, print_space_vpos, print_space_hpos)

return ET.tostring(root, pretty_print=True, encoding="utf-8").decode("utf-8")

Expand Down
12 changes: 5 additions & 7 deletions pero_ocr/document_ocr/page_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,16 +462,14 @@ def __init__(self, config, config_path=''):

@staticmethod
def compute_line_confidence(line, threshold=None):
logits = line.get_dense_logits()
log_probs = logits - np.logaddexp.reduce(logits, axis=1)[:, np.newaxis]
log_probs = line.get_full_logprobs()
best_ids = np.argmax(log_probs, axis=-1)
best_probs = np.exp(np.max(log_probs, axis=-1))
worst_best_prob = get_prob(best_ids, best_probs)
# print(worst_best_prob, np.sum(np.exp(best_probs) < threshold), best_probs.shape, np.nonzero(np.exp(best_probs) < threshold))
# for i in np.nonzero(np.exp(best_probs) < threshold)[0]:
# print(best_probs[i-1:i+2], best_ids[i-1:i+2])

return worst_best_prob
blank_id = log_probs.shape[1] - 1
char_probs = [p for p, c in zip(best_ids, best_probs) if c != blank_id]

return np.mean(char_probs)

def update_confidences(self, page_layout):
for line in page_layout.lines_iterator():
Expand Down