diff --git a/pero_ocr/char_confidences.py b/pero_ocr/char_confidences.py index 25c870e..92f0e91 100644 --- a/pero_ocr/char_confidences.py +++ b/pero_ocr/char_confidences.py @@ -1,6 +1,5 @@ import numpy as np - def greedy_filtration(line_probs, chars): idx = -1 text = "" @@ -8,9 +7,9 @@ def greedy_filtration(line_probs, chars): probs = [] for i, (char_index, max_prob) in enumerate(zip(np.argmax(line_probs, axis=1), np.max(line_probs, axis=1))): - if char_index != (line_probs.shape[1] - 1): - if (last_char != chars[char_index]): - text = text + chars[char_index] + if char_index != len(chars) - 1: + if last_char != chars[char_index]: + text += chars[char_index] probs.append([max_prob]) idx += 1 last_char = chars[char_index] @@ -21,6 +20,6 @@ def greedy_filtration(line_probs, chars): last_char = None for i, item in enumerate(probs): - probs[i] = sum(probs[i]) / len(probs[i]) + probs[i] = sum(item) / len(item) return text, probs diff --git a/pero_ocr/transcription_io.py b/pero_ocr/transcription_io.py index fd06e13..20becfe 100644 --- a/pero_ocr/transcription_io.py +++ b/pero_ocr/transcription_io.py @@ -1,5 +1,5 @@ def save_transcriptions(path, transcriptions): - with open(path, 'w') as f: + with open(path, 'w', encoding='utf-8') as f: for key in transcriptions: f.write('{} {}\n'.format(key, transcriptions[key])) @@ -7,9 +7,10 @@ def save_transcriptions(path, transcriptions): def load_transcriptions(path): transcriptions = {} - with open(path, "r") as f: - for line_no, line in enumerate(f): - if len(line) == 0: + with open(path, "r", encoding='utf-8') as f: + for line_no, line in enumerate(f, start=1): + line = line.strip() + if not line: continue try: @@ -24,8 +25,5 @@ def load_transcriptions(path): def parse_transcription_line(line): image_id, transcription = line.split(" ", maxsplit=1) - - if transcription[-1] == '\n': - transcription = transcription[:-1] - + transcription = transcription.rstrip() return image_id, transcription