diff --git a/README.md b/README.md
index 59339fb..9cdd5da 100644
--- a/README.md
+++ b/README.md
@@ -18,24 +18,14 @@ For the current shell session, this can be achieved by setting ``PYTHONPATH`` up
 export PYTHONPATH=/path/to/the/repo:$PYTHONPATH
 ```
 
-As a more permanent solution, a very simplistic `setup.py` is prepared:
-```
-python setup.py develop
-```
-Beware that the `setup.py` does not promise to bring all the required stuff, e.g. setting CUDA up is up to you.
-
-Pero can be later removed from your Python distribution by running:
-```
-python setup.py develop --uninstall
-```
-
 ## Available models
 
 General layout analysis (printed and handwritten) with European printed OCR specialized to Czech newspapers can be [downloaded here](https://nextcloud.fit.vutbr.cz/s/NtAbHTNkZFpapdJ). The OCR engine is suitable for most European printed documents. It is specialized for low-quality Czech newspapers digitized from microfilms, but it provides very good results for almost all types of printed documents in most languages. If you are interested in processing printed fraktur fonts, handwritten documents or medieval manuscripts, feel free to contact the authors. The newest OCR engines are available at [pero-ocr.fit.vutbr.cz](https://pero-ocr.fit.vutbr.cz). OCR engines are also available through the API running at [pero-ocr.fit.vutbr.cz/api](https://pero-ocr.fit.vutbr.cz/api), see the [github repository](https://github.com/DCGM/pero-ocr-api).
 
 ## Command line application
-A command line application is ./user_scripts/parse_folder.py. It is able to process images in a directory using an OCR engine. It can render detected lines in an image and provide document content in Page XML and ALTO XML formats. Additionally, it is able to crop all text lines as rectangular regions of normalized size and save them into separate image files.
+A command line application is `./user_scripts/parse_folder.py`. It is able to process images in a directory using an OCR engine. It can render detected lines in an image and provide document content in Page XML and ALTO XML formats. Additionally, it is able to crop all text lines as rectangular regions of normalized size and save them into separate image files.
 
 ## Running command line application in container
+
 A docker container can be built from the source code to run scripts and programs based on pero-ocr. Example of running the `parse_folder.py` script to generate Page XML files for images in the input directory:
 ```shell
 docker run --rm --tty --interactive \
@@ -63,7 +53,7 @@ import os
 import configparser
 import cv2
 import numpy as np
-from pero_ocr.document_ocr.layout import PageLayout
+from pero_ocr.core.layout import PageLayout
 from pero_ocr.document_ocr.page_parser import PageParser
 
 # Read config file.
@@ -117,7 +107,7 @@ Currently, only unittests are provided with the code. Some of the code. So simpl
 ```
 
 #### Simple regression testing
-Regression testing can be done by `test/processing_test.sh`. Script calls containerized `parser_folder.py` to process input images and page-xml files and calls user suplied comparison script to compare outputs to example outputs suplied by user. `PERO-OCR` container have to be built in advance to run the test, see 'Running command line application in container' chapter. Script can be called like this:
+Regression testing can be done by `test/processing_test.sh`. The script calls the containerized `parse_folder.py` to process input images and page-xml files and calls a user-supplied comparison script to compare the outputs to example outputs supplied by the user. The `PERO-OCR` container has to be built in advance to run the test, see [Running command line application in container](#running-command-line-application-in-container) for more information. The script can be run like this:
```shell
sh test/processing_test.sh \
    --input-images path/to/input/image/directory \
diff --git a/pero_ocr/core/confidence_estimation.py b/pero_ocr/core/confidence_estimation.py
index 921da88..c5176cd 100644
--- a/pero_ocr/core/confidence_estimation.py
+++ b/pero_ocr/core/confidence_estimation.py
@@ -70,9 +70,12 @@ def squeeze(sequence):
     return result
 
 
-def get_line_confidence(line, labels, aligned_letters=None, log_probs=None):
+def get_line_confidence(line, labels=None, aligned_letters=None, log_probs=None):
     # There is the same number of outputs as labels (probably transformer model was used) --> each letter has only one
-    # possible frame in logits and thus it is not needed to align them
+    # possible frame in logits, thus it is not needed to align them
+    if labels is None:
+        labels = line.get_labels()
+
     if line.logits.shape[0] == len(labels):
         return get_line_confidence_transformer(line, labels)
 
@@ -100,7 +103,7 @@ def get_line_confidence(line, labels, aligned_letters=None, log_probs=None):
             confidences[i] = max(0, label_prob - other_prob)
         last_border = next_border
 
-    #confidences = confidences / 2 + 0.5
+    # confidences = confidences / 2 + 0.5
 
     return confidences
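With `labels` now optional, a caller no longer has to build the label array before asking for confidences; the function falls back to the new `TextLine.get_labels()` (defined later in this diff). A minimal sketch of the resulting call pattern, assuming `line` is a `TextLine` whose `logits`, `characters` and `transcription` are already populated (e.g. by `PageParser` or `PageLayout.load_logits`):

```python
import numpy as np

from pero_ocr.core.confidence_estimation import get_line_confidence
from pero_ocr.core.layout import TextLine


def median_line_confidence(line: TextLine) -> float:
    """Median per-character confidence of one OCR'd line (sketch)."""
    # Labels may still be passed explicitly, as before:
    #     confidences = get_line_confidence(line, line.get_labels())
    # but since this change they default to None and are derived internally.
    confidences = get_line_confidence(line)
    # Aggregate with the median, mirroring PageOCR.get_line_confidence below.
    return float(np.quantile(confidences, .50))
```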
diff --git a/pero_ocr/core/layout.py b/pero_ocr/core/layout.py
index 7c91594..9817f14 100644
--- a/pero_ocr/core/layout.py
+++ b/pero_ocr/core/layout.py
@@ -6,6 +6,7 @@ from datetime import datetime, timezone
 from enum import Enum
 from typing import Optional, Union, List, Tuple
+import unicodedata
 
 import numpy as np
 import lxml.etree as ET
@@ -28,6 +29,9 @@ class PAGEVersion(Enum):
     PAGE_2019_07_15 = 1
     PAGE_2013_07_15 = 2
 
+class ALTOVersion(Enum):
+    ALTO_v2_x = 1
+    ALTO_v4_4 = 2
 
 def log_softmax(x):
     a = np.logaddexp.reduce(x, axis=1)[:, np.newaxis]
@@ -49,7 +53,8 @@ def __init__(self, id: str = None,
                  characters: Optional[List[str]] = None,
                  logit_coords: Optional[Union[List[Tuple[int]], List[Tuple[None]]]] = None,
                  transcription_confidence: Optional[Num] = None,
-                 index: Optional[int] = None):
+                 index: Optional[int] = None,
+                 category: Optional[str] = None):
         self.id = id
         self.index = index
         self.baseline = baseline
@@ -61,6 +66,7 @@ def __init__(self, id: str = None,
         self.characters = characters
         self.logit_coords = logit_coords
         self.transcription_confidence = transcription_confidence
+        self.category = category
 
     def get_dense_logits(self, zero_logit_value: int = -80):
         dense_logits = self.logits.toarray()
@@ -71,18 +77,365 @@ def get_full_logprobs(self, zero_logit_value: int = -80):
         dense_logits = self.get_dense_logits(zero_logit_value)
         return log_softmax(dense_logits)
 
+    def to_pagexml(self, region_element: ET.SubElement, fallback_id: int, validate_id: bool = False):
+        text_line = ET.SubElement(region_element, "TextLine")
+        text_line.set("id", export_id(self.id, validate_id))
+        if self.index is not None:
+            text_line.set("index", f'{self.index:d}')
+        else:
+            text_line.set("index", f'{fallback_id:d}')
+
+        custom = {}
+        if self.heights is not None:
+            heights_out = [np.float64(x) for x in self.heights]
+            custom['heights'] = list(np.round(heights_out, decimals=1))
+        if self.category is not None:
+            custom['category'] = self.category
+        if len(custom) > 0:
+            text_line.set("custom", json.dumps(custom))
+
+        coords = ET.SubElement(text_line, "Coords")
+
+        if self.polygon is not None:
+            coords.set("points", coords_to_pagexml_points(self.polygon))
+
+        if self.baseline is not None:
+            baseline_element = ET.SubElement(text_line, "Baseline")
+            baseline_element.set("points", coords_to_pagexml_points(self.baseline))
+
+        if self.transcription is not None:
+            text_element = ET.SubElement(text_line, "TextEquiv")
+            if self.transcription_confidence is not None:
+                text_element.set("conf", f"{self.transcription_confidence:.3f}")
+            text_element = ET.SubElement(text_element, "Unicode")
+            text_element.text = self.transcription
+
+    @classmethod
+    def from_pagexml(cls, line_element: ET.SubElement, schema, fallback_index: int):
+        new_textline = cls(id=line_element.attrib['id'])
+        if 'custom' in line_element.attrib:
+            new_textline.from_pagexml_parse_custom(line_element.attrib['custom'])
+
+        if 'index' in line_element.attrib:
+            try:
+                new_textline.index = int(line_element.attrib['index'])
+            except ValueError:
+                pass
+
+        if new_textline.index is None:
+            new_textline.index = fallback_index
+
+        baseline = line_element.find(schema + 'Baseline')
+        if baseline is not None:
+            new_textline.baseline = get_coords_from_pagexml(baseline, schema)
+        else:
+            logger.warning(f'Warning: Baseline is missing in TextLine. '
+                           f'Skipping this line during import. Line ID: {new_textline.id}')
+            return None
+
+        textline = line_element.find(schema + 'Coords')
+        if textline is not None:
+            new_textline.polygon = get_coords_from_pagexml(textline, schema)
+
+        if not new_textline.heights:
+            guess_line_heights_from_polygon(new_textline, use_center=False, n=len(new_textline.baseline))
+
+        transcription = line_element.find(schema + 'TextEquiv')
+        if transcription is not None:
+            t_unicode = transcription.find(schema + 'Unicode').text
+            if t_unicode is None:
+                t_unicode = ''
+            new_textline.transcription = t_unicode
+            conf = transcription.get('conf', None)
+            new_textline.transcription_confidence = float(conf) if conf is not None else None
+        return new_textline
+
+    def from_pagexml_parse_custom(self, custom_str):
+        try:
+            custom = json.loads(custom_str)
+            self.category = custom.get('category', None)
+            self.heights = custom.get('heights', None)
+        except json.decoder.JSONDecodeError:
+            if 'heights_v2' in custom_str:
+                for word in custom_str.split():
+                    if 'heights_v2' in word:
+                        self.heights = json.loads(word.split(":")[1])
+            else:
+                if re.findall("heights", custom_str):
+                    heights = re.findall(r"\d+", custom_str)
+                    heights_array = np.asarray([float(x) for x in heights])
+                    if heights_array.shape[0] == 4:
+                        heights = np.zeros(2, dtype=np.float32)
+                        heights[0] = heights_array[0]
+                        heights[1] = heights_array[2]
+                    elif heights_array.shape[0] == 3:
+                        heights = np.zeros(2, dtype=np.float32)
+                        heights[0] = heights_array[1]
+                        heights[1] = heights_array[2] - heights_array[0]
+                    else:
+                        heights = heights_array
+                    self.heights = heights.tolist()
+
+    def to_altoxml(self, text_block, arabic_helper, min_line_confidence, version: ALTOVersion):
+        if self.transcription_confidence is not None and self.transcription_confidence < min_line_confidence:
+            return
+
+        text_line = ET.SubElement(text_block, "TextLine")
+        text_line.set("ID", f'line_{self.id}')
+        text_line.set("BASELINE", self.to_altoxml_baseline(version))
+
+        text_line_height, text_line_width, text_line_vpos, text_line_hpos = get_hwvh(self.polygon)
+
+        text_line.set("VPOS", str(int(text_line_vpos)))
+        text_line.set("HPOS", str(int(text_line_hpos)))
+        text_line.set("HEIGHT", str(int(text_line_height)))
+        text_line.set("WIDTH", str(int(text_line_width)))
+
+        if self.category == 'text':
+            self.to_altoxml_text(text_line, 
arabic_helper, + text_line_height, text_line_width, text_line_vpos, text_line_hpos) + else: + string = ET.SubElement(text_line, "String") + string.set("CONTENT", self.transcription) + + string.set("HEIGHT", str(int(text_line_height))) + string.set("WIDTH", str(int(text_line_width))) + string.set("VPOS", str(int(text_line_vpos))) + string.set("HPOS", str(int(text_line_hpos))) + + if self.transcription_confidence is not None: + string.set("WC", str(round(self.transcription_confidence, 2))) + + def get_labels(self): + chars = [i for i in range(len(self.characters))] + char_to_num = dict(zip(self.characters, chars)) + + blank_idx = self.logits.shape[1] - 1 + + labels = [] + for item in self.transcription: + if item in char_to_num.keys(): + if char_to_num[item] >= blank_idx: + labels.append(0) + else: + labels.append(char_to_num[item]) + else: + labels.append(0) + return np.array(labels) + + def to_altoxml_text(self, text_line, arabic_helper, + text_line_height, text_line_width, text_line_vpos, text_line_hpos): + arabic_line = False + if arabic_helper.is_arabic_line(self.transcription): + arabic_line = True + + logits = None + logprobs = None + aligned_letters = None + try: + label = self.get_labels() + blank_idx = self.logits.shape[1] - 1 + + logits = self.get_dense_logits()[self.logit_coords[0]:self.logit_coords[1]] + logprobs = self.get_full_logprobs()[self.logit_coords[0]:self.logit_coords[1]] + aligned_letters = align_text(-logprobs, np.array(label), blank_idx) + except (ValueError, IndexError, TypeError) as e: + logger.warning(f'Error: Alto export, unable to align line {self.id} due to exception: {e}.') + + if logits is not None and logits.shape[0] > 0: + max_val = np.max(logits, axis=1) + logits = logits - max_val[:, np.newaxis] + probs = np.exp(logits) + probs = probs / np.sum(probs, axis=1, keepdims=True) + probs = np.max(probs, axis=1) + self.transcription_confidence = np.quantile(probs, .50) + else: + self.transcription_confidence = 0.0 + + average_word_width = (text_line_hpos + text_line_width) / len(self.transcription.split()) + for w, word in enumerate(self.transcription.split()): + string = ET.SubElement(text_line, "String") + string.set("CONTENT", word) + + string.set("HEIGHT", str(int(text_line_height))) + string.set("WIDTH", str(int(average_word_width))) + string.set("VPOS", str(int(text_line_vpos))) + string.set("HPOS", str(int(text_line_hpos + (w * average_word_width)))) + else: + crop_engine = EngineLineCropper(poly=2) + line_coords = crop_engine.get_crop_inputs(self.baseline, self.heights, 16) + space_idxs = [pos for pos, char in enumerate(self.transcription) if char == ' '] + + words = [] + space_idxs = [-1] + space_idxs + [len(aligned_letters)] + for i in range(len(space_idxs[1:])): + if space_idxs[i] != space_idxs[i + 1] - 1: + words.append([aligned_letters[space_idxs[i] + 1], aligned_letters[space_idxs[i + 1] - 1]]) + splitted_transcription = self.transcription.split() + lm_const = line_coords.shape[1] / logits.shape[0] + letter_counter = 0 + confidences = get_line_confidence(self, np.array(label), aligned_letters, logprobs) + # if self.transcription_confidence is None: + self.transcription_confidence = np.quantile(confidences, .50) + for w, word in enumerate(words): + extension = 2 + while line_coords.size > 0 and extension < 40: + all_x = line_coords[:, + max(0, int((words[w][0] - extension) * lm_const)):int((words[w][1] + extension) * lm_const), + 0] + all_y = line_coords[:, + max(0, int((words[w][0] - extension) * lm_const)):int((words[w][1] + extension) * 
lm_const), + 1] + + if all_x.size == 0 or all_y.size == 0: + extension += 1 + else: + break + + if line_coords.size == 0 or all_x.size == 0 or all_y.size == 0: + all_x = self.baseline[:, 0] + all_y = np.concatenate( + [self.baseline[:, 1] - self.heights[0], self.baseline[:, 1] + self.heights[1]]) + + word_confidence = None + if self.transcription_confidence == 1: + word_confidence = 1 + else: + if confidences.size != 0: + word_confidence = np.quantile( + confidences[letter_counter:letter_counter + len(splitted_transcription[w])], .50) + + string = ET.SubElement(text_line, "String") + + if arabic_line: + string.set("CONTENT", arabic_helper.label_form_to_string(splitted_transcription[w])) + else: + string.set("CONTENT", splitted_transcription[w]) + + string.set("HEIGHT", str(int((np.max(all_y) - np.min(all_y))))) + string.set("WIDTH", str(int((np.max(all_x) - np.min(all_x))))) + string.set("VPOS", str(int(np.min(all_y)))) + string.set("HPOS", str(int(np.min(all_x)))) + + if word_confidence is not None: + string.set("WC", str(round(word_confidence, 2))) + + if w != (len(self.transcription.split()) - 1): + space = ET.SubElement(text_line, "SP") + + space.set("WIDTH", str(4)) + space.set("VPOS", str(int(np.min(all_y)))) + space.set("HPOS", str(int(np.max(all_x)))) + letter_counter += len(splitted_transcription[w]) + 1 + + def to_altoxml_baseline(self, version: ALTOVersion) -> str: + if version == ALTOVersion.ALTO_v2_x: + # ALTO 4.1 and older accept baseline only as a single point + baseline = int(np.round(np.average(np.array(self.baseline)[:, 1]))) + return str(baseline) + elif version == ALTOVersion.ALTO_v4_4: + # ALTO 4.2 and newer accept baseline as a string with list of points. Recommended "x1,y1 x2,y2 ..." format. + baseline_points = [f"{x},{y}" for x, y in np.round(self.baseline).astype('int')] + baseline_points = " ".join(baseline_points) + return baseline_points + else: + return "" + + @classmethod + def from_altoxml(cls, line: ET.SubElement, schema): + hpos = int(line.attrib['HPOS']) + vpos = int(line.attrib['VPOS']) + width = int(line.attrib['WIDTH']) + height = int(line.attrib['HEIGHT']) + baseline_str = line.attrib['BASELINE'] + baseline, heights, polygon = cls.from_altoxml_polygon(baseline_str, hpos, vpos, width, height) + + new_textline = cls(id=line.attrib['ID'], baseline=baseline, heights=heights, polygon=polygon) + + word = '' + start = True + for text in line.iter(schema + 'String'): + if start: + start = False + word = word + text.get('CONTENT') + else: + word = word + " " + text.get('CONTENT') + new_textline.transcription = word + return new_textline + + @staticmethod + def from_altoxml_polygon(baseline_str, hpos, vpos, width, height) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + baseline = baseline_str.strip().split(' ') + + if len(baseline) == 1: + # baseline is only one number (probably ALTOversion = 2.x) + try: + baseline = float(baseline[0]) + except ValueError: + baseline = vpos + height # fallback: baseline is at the bottom of the bounding box, heights[1] = 0 + + baseline_arr = np.asarray([[hpos, baseline], [hpos + width, baseline]]) + heights = np.asarray([baseline - vpos, vpos + height - baseline]) + polygon = np.asarray([[hpos, vpos], + [hpos + width, vpos], + [hpos + width, vpos + height], + [hpos, vpos + height]]) + return baseline_arr, heights, polygon + else: + # baseline is list of points (probably ALTOversion = 4.4) + baseline_coords = [t.split(",") for t in baseline] + baseline = np.asarray([[int(round(float(x))), int(round(float(y)))] for x, y in 
baseline_coords]) + + # count heights from the FIRST element of baseline + heights = np.asarray([baseline[0, 1] - vpos, vpos + height - baseline[0, 1]]) + + coords_top = [[x, y - heights[0]] for x, y in baseline] + coords_bottom = [[x, y + heights[1]] for x, y in baseline] + # reverse coords_bottom to create polygon in clockwise order + coords_bottom.reverse() + polygon = np.concatenate([coords_top, coords_bottom, coords_top[:1]]) + + return baseline, heights, polygon + class RegionLayout(object): def __init__(self, id: str, polygon: np.ndarray, - region_type=None): + region_type: Optional[str] = None, + category: Optional[str] = None, + detection_confidence: Optional[float] = None): self.id = id # ID string self.polygon = polygon # bounding polygon self.region_type = region_type + self.category = category self.lines: List[TextLine] = [] self.transcription = None + self.detection_confidence = detection_confidence + + def get_lines_of_category(self, categories: Union[str, list]): + if isinstance(categories, str): + categories = [categories] - def to_page_xml(self, page_element: ET.SubElement, validate_id: bool = False): + return [line for line in self.lines if line.category in categories] + + def replace_id(self, new_id): + """Replace region ID and all IDs in TextLines which has region ID inside them.""" + for line in self.lines: + line.id = line.id.replace(self.id, new_id) + self.id = new_id + + def get_polygon_bounding_box(self) -> Tuple[int, int, int, int]: + """Get bounding box of region polygon which includes all polygon points. + :return: tuple[int, int, int, int]: (x_min, y_min, x_max, y_max) + """ + x_min = min(self.polygon[:, 0]) + x_max = max(self.polygon[:, 0]) + y_min = min(self.polygon[:, 1]) + y_max = max(self.polygon[:, 1]) + + return x_min, y_min, x_max, y_max + + def to_pagexml(self, page_element: ET.SubElement, validate_id: bool = False): region_element = ET.SubElement(page_element, "TextRegion") coords = ET.SubElement(region_element, "Coords") region_element.set("id", export_id(self.id, validate_id)) @@ -90,17 +443,105 @@ def to_page_xml(self, page_element: ET.SubElement, validate_id: bool = False): if self.region_type is not None: region_element.set("type", self.region_type) - points = ["{},{}".format(int(np.round(coord[0])), int(np.round(coord[1]))) for coord in self.polygon] - points = " ".join(points) - coords.set("points", points) + custom = {} + if self.category is not None: + custom['category'] = self.category + if self.detection_confidence is not None: + custom['detection_confidence'] = round(self.detection_confidence, 3) + if len(custom) > 0: + custom = json.dumps(custom) + region_element.set("custom", custom) + + coords.set("points", coords_to_pagexml_points(self.polygon)) + if self.transcription is not None: text_element = ET.SubElement(region_element, "TextEquiv") text_element = ET.SubElement(text_element, "Unicode") text_element.text = self.transcription + + for i, line in enumerate(self.lines): + line.to_pagexml(region_element, fallback_id=i, validate_id=validate_id) + return region_element + @classmethod + def from_pagexml(cls, region_element: ET.SubElement, schema): + coords_element = region_element.find(schema + 'Coords') + region_coords = get_coords_from_pagexml(coords_element, schema) + + region_type = None + if "type" in region_element.attrib: + region_type = region_element.attrib["type"] + + category = None + detection_confidence = None + if "custom" in region_element.attrib: + custom = json.loads(region_element.attrib["custom"]) + category = 
custom.get('category', None) + detection_confidence = custom.get('detection_confidence', None) + + layout_region = cls(region_element.attrib['id'], region_coords, region_type, + category=category, + detection_confidence=detection_confidence) + + transcription = region_element.find(schema + 'TextEquiv') + if transcription is not None: + layout_region.transcription = transcription.find(schema + 'Unicode').text + if layout_region.transcription is None: + layout_region.transcription = '' + + for i, line in enumerate(region_element.iter(schema + 'TextLine')): + new_textline = TextLine.from_pagexml(line, schema, fallback_index=i) + if new_textline is not None: + layout_region.lines.append(new_textline) + + return layout_region + + def to_altoxml(self, print_space, arabic_helper, min_line_confidence, + print_space_coords: Tuple[int, int, int, int], version: ALTOVersion) -> Tuple[int, int, int, int]: + print_space_height, print_space_width, print_space_vpos, print_space_hpos = print_space_coords + + text_block = ET.SubElement(print_space, "TextBlock") + text_block.set("ID", 'block_{}'.format(self.id)) + + text_block_height, text_block_width, text_block_vpos, text_block_hpos = get_hwvh(self.polygon) + text_block.set("HEIGHT", str(int(text_block_height))) + text_block.set("WIDTH", str(int(text_block_width))) + text_block.set("VPOS", str(int(text_block_vpos))) + text_block.set("HPOS", str(int(text_block_hpos))) + + print_space_height = max([print_space_vpos + print_space_height, text_block_vpos + text_block_height]) + print_space_width = max([print_space_hpos + print_space_width, text_block_hpos + text_block_width]) + print_space_vpos = min([print_space_vpos, text_block_vpos]) + print_space_hpos = min([print_space_hpos, text_block_hpos]) + print_space_height = print_space_height - print_space_vpos + print_space_width = print_space_width - print_space_hpos + + for line in self.lines: + if not line.transcription or line.transcription.strip() == "": + continue + line.to_altoxml(text_block, arabic_helper, min_line_confidence, version) + return print_space_height, print_space_width, print_space_vpos, print_space_hpos + + @classmethod + def from_altoxml(cls, text_block: ET.SubElement, schema): + region_coords = list() + region_coords.append([int(text_block.get('HPOS')), int(text_block.get('VPOS'))]) + region_coords.append([int(text_block.get('HPOS')) + int(text_block.get('WIDTH')), int(text_block.get('VPOS'))]) + region_coords.append([int(text_block.get('HPOS')) + int(text_block.get('WIDTH')), + int(text_block.get('VPOS')) + int(text_block.get('HEIGHT'))]) + region_coords.append([int(text_block.get('HPOS')), int(text_block.get('VPOS')) + int(text_block.get('HEIGHT'))]) + + region_layout = cls(text_block.attrib['ID'], np.asarray(region_coords).tolist()) + + for line in text_block.iter(schema + 'TextLine'): + new_textline = TextLine.from_altoxml(line, schema) + region_layout.lines.append(new_textline) + + return region_layout + -def get_coords_form_page_xml(coords_element, schema): +def get_coords_from_pagexml(coords_element, schema): if 'points' in coords_element.attrib: coords = points_string_to_array(coords_element.attrib['points']) else: @@ -112,22 +553,11 @@ def get_coords_form_page_xml(coords_element, schema): return coords -def get_region_from_page_xml(region_element, schema): - coords_element = region_element.find(schema + 'Coords') - region_coords = get_coords_form_page_xml(coords_element, schema) - - region_type = None - if "type" in region_element.attrib: - region_type = 
region_element.attrib["type"] - - layout_region = RegionLayout(region_element.attrib['id'], region_coords, region_type) - - transcription = region_element.find(schema + 'TextEquiv') - if transcription is not None: - layout_region.transcription = transcription.find(schema + 'Unicode').text - if layout_region.transcription is None: - layout_region.transcription = '' - return layout_region +def coords_to_pagexml_points(polygon: np.ndarray) -> str: + polygon = np.round(polygon).astype(np.dtype('int')) + points = [f"{x},{y}" for x, y in np.maximum(polygon, 0)] + points = " ".join(points) + return points def guess_line_heights_from_polygon(text_line: TextLine, use_center: bool = False, n: int = 10, interpolate=False): @@ -258,66 +688,7 @@ def from_pagexml(self, file: Union[str, BytesIO]): self.reading_order = get_reading_order(page, schema) for region in page_tree.iter(schema + 'TextRegion'): - region_layout = get_region_from_page_xml(region, schema) - - for line_i, line in enumerate(region.iter(schema + 'TextLine')): - new_textline = TextLine(id=line.attrib['id']) - if 'custom' in line.attrib: - custom_str = line.attrib['custom'] - if 'heights_v2' in custom_str: - for word in custom_str.split(): - if 'heights_v2' in word: - new_textline.heights = json.loads(word.split(":")[1]) - else: - if re.findall("heights", line.attrib['custom']): - heights = re.findall("\d+", line.attrib['custom']) - heights_array = np.asarray([float(x) for x in heights]) - if heights_array.shape[0] == 4: - heights = np.zeros(2, dtype=np.float32) - heights[0] = heights_array[0] - heights[1] = heights_array[2] - elif heights_array.shape[0] == 3: - heights = np.zeros(2, dtype=np.float32) - heights[0] = heights_array[1] - heights[1] = heights_array[2] - heights_array[0] - else: - heights = heights_array - new_textline.heights = heights.tolist() - - if 'index' in line.attrib: - try: - new_textline.index = int(line.attrib['index']) - except ValueError: - pass - - if new_textline.index is None: - new_textline.index = line_i - - baseline = line.find(schema + 'Baseline') - if baseline is not None: - new_textline.baseline = get_coords_form_page_xml(baseline, schema) - else: - logger.warning(f'Warning: Baseline is missing in TextLine. ' - f'Skipping this line during import. 
Line ID: {new_textline.id} Page ID: {self.id}') - continue - - textline = line.find(schema + 'Coords') - if textline is not None: - new_textline.polygon = get_coords_form_page_xml(textline, schema) - - if not new_textline.heights: - guess_line_heights_from_polygon(new_textline, use_center=False, n=len(new_textline.baseline)) - - transcription = line.find(schema + 'TextEquiv') - if transcription is not None: - t_unicode = transcription.find(schema + 'Unicode').text - if t_unicode is None: - t_unicode = '' - new_textline.transcription = t_unicode - conf = transcription.get('conf', None) - new_textline.transcription_confidence = float(conf) if conf is not None else None - region_layout.lines.append(new_textline) - + region_layout = RegionLayout.from_pagexml(region, schema) self.regions.append(region_layout) def to_pagexml_string(self, creator: str = 'Pero OCR', validate_id: bool = False, @@ -350,44 +721,12 @@ def to_pagexml_string(self, creator: str = 'Pero OCR', validate_id: bool = False page.set("imageWidth", str(self.page_size[1])) page.set("imageHeight", str(self.page_size[0])) - if self.reading_order is not None: + if self.reading_order is not None and self.reading_order != {}: self.sort_regions_by_reading_order() - self.reading_order_to_page_xml(page) + self.reading_order_to_pagexml(page) for region_layout in self.regions: - text_region = region_layout.to_page_xml(page, validate_id=validate_id) - - for i, line in enumerate(region_layout.lines): - text_line = ET.SubElement(text_region, "TextLine") - text_line.set("id", export_id(line.id, validate_id)) - if line.index is not None: - text_line.set("index", f'{line.index:d}') - else: - text_line.set("index", f'{i:d}') - if line.heights is not None: - text_line.set("custom", f"heights_v2:[{line.heights[0]:.1f},{line.heights[1]:.1f}]") - - coords = ET.SubElement(text_line, "Coords") - - if line.polygon is not None: - points = ["{},{}".format(int(np.round(coord[0])), int(np.round(coord[1]))) for coord in - line.polygon] - points = " ".join(points) - coords.set("points", points) - - if line.baseline is not None: - baseline_element = ET.SubElement(text_line, "Baseline") - points = ["{},{}".format(int(np.round(coord[0])), int(np.round(coord[1]))) for coord in - line.baseline] - points = " ".join(points) - baseline_element.set("points", points) - - if line.transcription is not None: - text_element = ET.SubElement(text_line, "TextEquiv") - if line.transcription_confidence is not None: - text_element.set("conf", f"{line.transcription_confidence:.3f}") - text_element = ET.SubElement(text_element, "Unicode") - text_element.text = line.transcription + region_layout.to_pagexml(page, validate_id=validate_id) return ET.tostring(root, pretty_print=True, encoding="utf-8", xml_declaration=True).decode("utf-8") @@ -397,12 +736,17 @@ def to_pagexml(self, file_name: str, creator: str = 'Pero OCR', with open(file_name, 'w', encoding='utf-8') as out_f: out_f.write(xml_string) - def to_altoxml_string(self, ocr_processing_element: ET.SubElement = None, page_uuid: str = None, min_line_confidence: float = 0): + def to_altoxml_string(self, ocr_processing_element: ET.SubElement = None, page_uuid: str = None, + min_line_confidence: float = 0, version: ALTOVersion = ALTOVersion.ALTO_v2_x): arabic_helper = ArabicHelper() NSMAP = {"xlink": 'http://www.w3.org/1999/xlink', "xsi": 'http://www.w3.org/2001/XMLSchema-instance'} root = ET.Element("alto", nsmap=NSMAP) - root.set("xmlns", "http://www.loc.gov/standards/alto/ns-v2#") + + if version == ALTOVersion.ALTO_v4_4: + 
root.set("xmlns", "http://www.loc.gov/standards/alto/ns-v4#") + elif version == ALTOVersion.ALTO_v2_x: + root.set("xmlns", "http://www.loc.gov/standards/alto/ns-v2#") description = ET.SubElement(root, "Description") measurement_unit = ET.SubElement(description, "MeasurementUnit") @@ -435,147 +779,13 @@ def to_altoxml_string(self, ocr_processing_element: ET.SubElement = None, page_u print_space_width = 0 print_space_vpos = self.page_size[0] print_space_hpos = self.page_size[1] + print_space_coords = (print_space_height, print_space_width, print_space_vpos, print_space_hpos) - for b, block in enumerate(self.regions): - text_block = ET.SubElement(print_space, "TextBlock") - text_block.set("ID", 'block_{}' .format(block.id)) - - text_block_height, text_block_width, text_block_vpos, text_block_hpos = get_hwvh(block.polygon) - text_block.set("HEIGHT", str(int(text_block_height))) - text_block.set("WIDTH", str(int(text_block_width))) - text_block.set("VPOS", str(int(text_block_vpos))) - text_block.set("HPOS", str(int(text_block_hpos))) - - print_space_height = max([print_space_vpos + print_space_height, text_block_vpos + text_block_height]) - print_space_width = max([print_space_hpos + print_space_width, text_block_hpos + text_block_width]) - print_space_vpos = min([print_space_vpos, text_block_vpos]) - print_space_hpos = min([print_space_hpos, text_block_hpos]) - print_space_height = print_space_height - print_space_vpos - print_space_width = print_space_width - print_space_hpos - - for l, line in enumerate(block.lines): - if not line.transcription or line.transcription.strip() == "": - continue - arabic_line = False - if arabic_helper.is_arabic_line(line.transcription): - arabic_line = True - text_line = ET.SubElement(text_block, "TextLine") - text_line_baseline = int(np.average(np.array(line.baseline)[:, 1])) - text_line.set("BASELINE", str(text_line_baseline)) - - text_line_height, text_line_width, text_line_vpos, text_line_hpos = get_hwvh(line.polygon) - - text_line.set("VPOS", str(int(text_line_vpos))) - text_line.set("HPOS", str(int(text_line_hpos))) - text_line.set("HEIGHT", str(int(text_line_height))) - text_line.set("WIDTH", str(int(text_line_width))) - - logits = None - logprobs = None - aligned_letters = None - try: - chars = [i for i in range(len(line.characters))] - char_to_num = dict(zip(line.characters, chars)) - - blank_idx = line.logits.shape[1] - 1 - - label = [] - for item in line.transcription: - if item in char_to_num.keys(): - if char_to_num[item] >= blank_idx: - label.append(0) - else: - label.append(char_to_num[item]) - else: - label.append(0) - - logits = line.get_dense_logits()[line.logit_coords[0]:line.logit_coords[1]] - logprobs = line.get_full_logprobs()[line.logit_coords[0]:line.logit_coords[1]] - aligned_letters = align_text(-logprobs, np.array(label), blank_idx) - except (ValueError, IndexError, TypeError) as e: - logger.warning(f'Error: Alto export, unable to align line {line.id} due to exception {e}.') - - if logits is not None: - max_val = np.max(logits, axis=1) - logits = logits - max_val[:, np.newaxis] - probs = np.exp(logits) - probs = probs / np.sum(probs, axis=1, keepdims=True) - probs = np.max(probs, axis=1) - line.transcription_confidence = np.quantile(probs, .50) - else: - line.transcription_confidence = 0 - average_word_width = (text_line_hpos + text_line_width) / len(line.transcription.split()) - for w, word in enumerate(line.transcription.split()): - string = ET.SubElement(text_line, "String") - string.set("CONTENT", word) - - string.set("HEIGHT", 
str(int(text_line_height))) - string.set("WIDTH", str(int(average_word_width))) - string.set("VPOS", str(int(text_line_vpos))) - string.set("HPOS", str(int(text_line_hpos + (w * average_word_width)))) - else: - crop_engine = EngineLineCropper(poly=2) - line_coords = crop_engine.get_crop_inputs(line.baseline, line.heights, 16) - space_idxs = [pos for pos, char in enumerate(line.transcription) if char == ' '] - - words = [] - space_idxs = [-1] + space_idxs + [len(aligned_letters)] - for i in range(len(space_idxs[1:])): - if space_idxs[i] != space_idxs[i+1]-1: - words.append([aligned_letters[space_idxs[i]+1], aligned_letters[space_idxs[i+1]-1]]) - splitted_transcription = line.transcription.split() - lm_const = line_coords.shape[1] / logits.shape[0] - letter_counter = 0 - confidences = get_line_confidence(line, np.array(label), aligned_letters, logprobs) - #if line.transcription_confidence is None: - line.transcription_confidence = np.quantile(confidences, .50) - for w, word in enumerate(words): - extension = 2 - while line_coords.size > 0 and extension < 40: - all_x = line_coords[:, max(0, int((words[w][0]-extension) * lm_const)):int((words[w][1]+extension) * lm_const), 0] - all_y = line_coords[:, max(0, int((words[w][0]-extension) * lm_const)):int((words[w][1]+extension) * lm_const), 1] + for block in self.regions: + print_space_coords = block.to_altoxml(print_space, arabic_helper, min_line_confidence, print_space_coords, version) - if all_x.size == 0 or all_y.size == 0: - extension += 1 - else: - break - - if line_coords.size == 0 or all_x.size == 0 or all_y.size == 0: - all_x = line.baseline[:, 0] - all_y = np.concatenate([line.baseline[:, 1] - line.heights[0], line.baseline[:, 1] + line.heights[1]]) - - word_confidence = None - if line.transcription_confidence == 1: - word_confidence = 1 - else: - if confidences.size != 0: - word_confidence = np.quantile(confidences[letter_counter:letter_counter+len(splitted_transcription[w])], .50) + print_space_height, print_space_width, print_space_vpos, print_space_hpos = print_space_coords - string = ET.SubElement(text_line, "String") - - if arabic_line: - string.set("CONTENT", arabic_helper.label_form_to_string(splitted_transcription[w])) - else: - string.set("CONTENT", splitted_transcription[w]) - - string.set("HEIGHT", str(int((np.max(all_y) - np.min(all_y))))) - string.set("WIDTH", str(int((np.max(all_x) - np.min(all_x))))) - string.set("VPOS", str(int(np.min(all_y)))) - string.set("HPOS", str(int(np.min(all_x)))) - - if word_confidence is not None: - string.set("WC", str(round(word_confidence, 2))) - - if w != (len(line.transcription.split())-1): - space = ET.SubElement(text_line, "SP") - - space.set("WIDTH", str(4)) - space.set("VPOS", str(int(np.min(all_y)))) - space.set("HPOS", str(int(np.max(all_x)))) - letter_counter += len(splitted_transcription[w])+1 - if line.transcription_confidence is not None: - if line.transcription_confidence < min_line_confidence: - text_block.remove(text_line) top_margin.set("HEIGHT", "{}" .format(int(print_space_vpos))) top_margin.set("WIDTH", "{}" .format(int(self.page_size[1]))) top_margin.set("VPOS", "0") @@ -603,8 +813,9 @@ def to_altoxml_string(self, ocr_processing_element: ET.SubElement = None, page_u return ET.tostring(root, pretty_print=True, encoding="utf-8", xml_declaration=True).decode("utf-8") - def to_altoxml(self, file_name: str, ocr_processing_element: ET.SubElement = None, page_uuid: str = None): - alto_string = self.to_altoxml_string(ocr_processing_element=ocr_processing_element, 
page_uuid=page_uuid) + def to_altoxml(self, file_name: str, ocr_processing_element: ET.SubElement = None, page_uuid: str = None, + version: ALTOVersion = ALTOVersion.ALTO_v2_x): + alto_string = self.to_altoxml_string(ocr_processing_element=ocr_processing_element, page_uuid=page_uuid, version=version) with open(file_name, 'w', encoding='utf-8') as out_f: out_f.write(alto_string) @@ -624,48 +835,13 @@ def from_altoxml(self, file: Union[str, BytesIO]): print_space = page.findall(schema + 'PrintSpace')[0] for region in print_space.iter(schema + 'TextBlock'): - region_coords = list() - region_coords.append([int(region.get('HPOS')), int(region.get('VPOS'))]) - region_coords.append([int(region.get('HPOS')) + int(region.get('WIDTH')), int(region.get('VPOS'))]) - region_coords.append([int(region.get('HPOS')) + int(region.get('WIDTH')), - int(region.get('VPOS')) + int(region.get('HEIGHT'))]) - region_coords.append([int(region.get('HPOS')), int(region.get('VPOS')) + int(region.get('HEIGHT'))]) - - region_layout = RegionLayout(region.attrib['ID'], np.asarray(region_coords).tolist()) - - for line in region.iter(schema + 'TextLine'): - new_textline = TextLine(baseline=np.asarray( - [[int(line.attrib['HPOS']), int(line.attrib['BASELINE'])], - [int(line.attrib['HPOS']) + int(line.attrib['WIDTH']), int(line.attrib['BASELINE'])]])) - polygon = [] - new_textline.heights = np.asarray([ - int(line.attrib['HEIGHT']) + int(line.attrib['VPOS']) - int(line.attrib['BASELINE']), - int(line.attrib['BASELINE']) - int(line.attrib['VPOS'])]) - polygon.append([int(line.attrib['HPOS']), int(line.attrib['VPOS'])]) - polygon.append( - [int(line.attrib['HPOS']) + int(line.attrib['WIDTH']), int(line.attrib['VPOS'])]) - polygon.append([int(line.attrib['HPOS']) + int(line.attrib['WIDTH']), - int(line.attrib['VPOS']) + int(line.attrib['HEIGHT'])]) - polygon.append( - [int(line.attrib['HPOS']), int(line.attrib['VPOS']) + int(line.attrib['HEIGHT'])]) - new_textline.polygon = np.asarray(polygon) - word = '' - start = True - for text in line.iter(schema + 'String'): - if start: - start = False - word = word + text.get('CONTENT') - else: - word = word + " " + text.get('CONTENT') - new_textline.transcription = word - region_layout.lines.append(new_textline) - + region_layout = RegionLayout.from_altoxml(region, schema) self.regions.append(region_layout) def sort_regions_by_reading_order(self): self.regions = sorted(self.regions, key=lambda k: self.reading_order[k] if k in self.reading_order else float("inf")) - def reading_order_to_page_xml(self, page_element: ET.SubElement): + def reading_order_to_pagexml(self, page_element: ET.SubElement): reading_order_element = ET.SubElement(page_element, "ReadingOrder") ordered_group_element = ET.SubElement(reading_order_element, "OrderedGroup") ordered_group_element.set("id", "reading_order") @@ -746,9 +922,12 @@ def load_logits(self, file: str): line.characters = characters[line.id] line.logit_coords = logit_coords[line.id] - def render_to_image(self, image, thickness: int = 2, circles: bool = True, render_order: bool = False): + def render_to_image(self, image, thickness: int = 2, circles: bool = True, + render_order: bool = False, render_category: bool = False): """Render layout into image. 
:param image: image to render layout into
+        :param render_order: render the region order number (given by enumerate(regions)) in the middle of each region
+        :param render_category: render the region category (other than 'text') in the upper left corner of each region
         """
         for region_layout in self.regions:
             image = draw_lines(
@@ -764,28 +943,38 @@ def render_to_image(self, image, thickness: int = 2, circles: bool = True, rende
                 [region_layout.polygon], color=(255, 0, 0), circles=(circles, circles, circles), close=True,
                 thickness=thickness)
 
-        if render_order:
+        if render_order or render_category:
             font = cv2.FONT_HERSHEY_DUPLEX
-            font_scale = 4
-            font_thickness = 5
+            font_scale = 1
+            font_thickness = 1
 
             for idx, region in enumerate(self.regions):
-                min = region.polygon.min(axis=0)
-                max = region.polygon.max(axis=0)
-
-                text_w, text_h = cv2.getTextSize(f"{idx}", font, font_scale, font_thickness)[0]
-
-                mid_coords = (int((min[0] + max[0]) // 2 - text_w // 2), int((min[1] + max[1]) // 2 + text_h // 2))
-
-                cv2.putText(image, f"{idx}", mid_coords, font, font_scale,
-                            (0, 0, 0), thickness=font_thickness, lineType=cv2.LINE_AA)
+                min_p = region.polygon.min(axis=0)
+                max_p = region.polygon.max(axis=0)
+
+                if render_order:
+                    text = f"{idx}"
+                    text_w, text_h = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
+                    mid_x = int((min_p[0] + max_p[0]) // 2 - text_w // 2)
+                    mid_y = int((min_p[1] + max_p[1]) // 2 + text_h // 2)
+                    cv2.putText(image, text, (mid_x, mid_y), font, font_scale,
+                                color=(0, 0, 0), thickness=font_thickness, lineType=cv2.LINE_AA)
+                if render_category and region.category not in [None, 'text']:
+                    text = f"{normalize_text(region.category)}"
+                    text_w, text_h = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
+                    start_point = (int(min_p[0]), int(min_p[1]))
+                    end_point = (int(min_p[0]) + text_w, int(min_p[1]) - text_h)
+                    cv2.rectangle(image, start_point, end_point, color=(255, 0, 0), thickness=-1)
+                    cv2.putText(image, text, start_point, font, font_scale,
+                                color=(255, 255, 255), thickness=font_thickness, lineType=cv2.LINE_AA)
 
         return image
 
-    def lines_iterator(self):
+    def lines_iterator(self, categories: list = None):
         for region in self.regions:
             for line in region.lines:
-                yield line
+                if not categories or line.category in categories:
+                    yield line
 
     def get_quality(self, x: int = None, y: int = None, width: int = None, height: int = None, power: int = 6):
         bbox_confidences = []
@@ -858,6 +1047,14 @@ def get_quality(self, x: int = None, y: int = None, width: int = None, height: i
         else:
             return -1
 
+    def rename_region_id(self, old_id, new_id):
+        for region in self.regions:
+            if region.id == old_id:
+                region.replace_id(new_id)
+                break
+        else:
+            raise ValueError(f'Region with id {old_id} not found.')
+
 
 def draw_lines(img, lines, color=(255, 0, 0), circles=(False, False, False), close=False, thickness=2):
     """Draw a line into image.
@@ -947,3 +1144,7 @@ def create_ocr_processing_element(id: str = "IdOcr",
 
     return ocr_processing
+
+
+def normalize_text(text: str) -> str:
+    """Normalize text to ASCII characters. (e.g. Obrázek -> Obrazek)"""
+    return unicodedata.normalize('NFD', text).encode('ascii', 'ignore').decode('ascii')
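Taken together, the new `category` plumbing in `layout.py` can be exercised end to end. A sketch under assumed inputs: the file names, region ids and category values below are hypothetical, and category names depend on the detection model used:

```python
import cv2

from pero_ocr.core.layout import PageLayout

layout = PageLayout()
layout.from_pagexml('page.xml')  # hypothetical Page XML with category/confidence in "custom"

# Iterate only over lines of selected categories.
text_lines = [line.transcription for line in layout.lines_iterator(['text'])]

# Inspect non-text regions, e.g. those produced by the YOLO layout engine below.
for region in layout.regions:
    if region.category not in [None, 'text']:
        print(region.id, region.category, region.detection_confidence)

# Rename a region; ids of TextLines that embed the old region id are updated too.
layout.rename_region_id('r001', 'r042')

# Render region order and non-text categories into a preview image.
img = cv2.imread('page.jpg')
img = layout.render_to_image(img, render_order=True, render_category=True)
cv2.imwrite('page_rendered.jpg', img)
```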
diff --git a/pero_ocr/document_ocr/page_parser.py b/pero_ocr/document_ocr/page_parser.py
index 6f2acc2..b4327ba 100644
--- a/pero_ocr/document_ocr/page_parser.py
+++ b/pero_ocr/document_ocr/page_parser.py
@@ -4,17 +4,20 @@ from multiprocessing import Pool
 import math
 import time
+import re
+from typing import Union, Tuple, List
 
 import torch.cuda
 
-from pero_ocr.utils import compose_path
+from pero_ocr.utils import compose_path, config_get_list
 from pero_ocr.core.layout import PageLayout, RegionLayout, TextLine
 import pero_ocr.core.crop_engine as cropper
+from pero_ocr.core.confidence_estimation import get_line_confidence
 from pero_ocr.ocr_engine.pytorch_ocr_engine import PytorchEngineLineOCR
 from pero_ocr.ocr_engine.transformer_ocr_engine import TransformerEngineLineOCR
 from pero_ocr.layout_engines.simple_region_engine import SimpleThresholdRegion
 from pero_ocr.layout_engines.simple_baseline_engine import EngineLineDetectorSimple
-from pero_ocr.layout_engines.cnn_layout_engine import LayoutEngine, LineFilterEngine
+from pero_ocr.layout_engines.cnn_layout_engine import LayoutEngine, LineFilterEngine, LayoutEngineYolo
 from pero_ocr.layout_engines.line_postprocessing_engine import PostprocessingEngine
 from pero_ocr.layout_engines.naive_sorter import NaiveRegionSorter
 from pero_ocr.layout_engines.smart_sorter import SmartRegionSorter
@@ -26,14 +29,15 @@ logger = logging.getLogger(__name__)
 
 
-def layout_parser_factory(config, device, config_path='', order=1):
-    config = config['LAYOUT_PARSER_{}'.format(order)]
+def layout_parser_factory(config, device, config_path=''):
     if config['METHOD'] == 'REGION_WHOLE_PAGE':
         layout_parser = WholePageRegion(config, config_path=config_path)
     elif config['METHOD'] == 'REGION_SIMPLE_THRESHOLD':
         layout_parser = SimpleThresholdRegion(config, config_path=config_path)
     elif config['METHOD'] == 'LAYOUT_CNN':
         layout_parser = LayoutExtractor(config, device, config_path=config_path)
+    elif config['METHOD'] == 'LAYOUT_YOLO':
+        layout_parser = LayoutExtractorYolo(config, device, config_path=config_path)
     elif config['METHOD'] == 'LINES_SIMPLE_THRESHOLD':
         layout_parser = TextlineExtractorSimple(config, config_path=config_path)
     elif config['METHOD'] == 'LINE_FILTER':
@@ -51,13 +55,11 @@ def layout_parser_factory(config, device, config_path='', order=1):
     return layout_parser
 
 
-def line_cropper_factory(config, config_path=''):
-    config = config['LINE_CROPPER']
+def line_cropper_factory(config, config_path='', device=None):
     return LineCropper(config, config_path=config_path)
 
 
 def ocr_factory(config, device, config_path=''):
-    config = config['OCR']
     return PageOCR(config, device, config_path=config_path)
 
 
@@ -71,7 +73,9 @@ def page_decoder_factory(config, device, config_path=''):
     decoder = decoding_itf.decoder_factory(config['DECODER'], ocr_chars, device, allow_no_decoder=False, config_path=config_path)
     confidence_threshold = config['DECODER'].getfloat('CONFIDENCE_THRESHOLD', fallback=math.inf)
     carry_h_over = config['DECODER'].getboolean('CARRY_H_OVER')
-    return PageDecoder(decoder, line_confidence_threshold=confidence_threshold, carry_h_over=carry_h_over)
+    categories = config_get_list(config['DECODER'], key='CATEGORIES', fallback=[])
+    return PageDecoder(decoder, line_confidence_threshold=confidence_threshold, carry_h_over=carry_h_over,
+                       categories=categories)
 
 
 class MissingLogits(Exception):
@@ -94,24 +98,26 @@ def prepare_dense_logits(line):
 
 
 class PageDecoder:
-    def 
__init__(self, decoder, line_confidence_threshold=None, carry_h_over=False): + def __init__(self, decoder, line_confidence_threshold=None, carry_h_over=False, categories=None): self.decoder = decoder self.line_confidence_threshold = line_confidence_threshold self.lines_examined = 0 self.lines_decoded = 0 self.seconds_decoding = 0.0 self.continue_lines = carry_h_over + self.categories = categories if categories else ['text'] self.last_h = None self.last_line = None def process_page(self, page_layout: PageLayout): self.last_h = None - for line in page_layout.lines_iterator(): + for line in page_layout.lines_iterator(self.categories): try: line.transcription = self.decode_line(line) except Exception: - logger.error(f'Failed to process line {line.id} of page {page_layout.id}. The page has been processed no further.', exc_info=True) + logger.error(f'Failed to process line {line.id} of page {page_layout.id}. ' + f'The page has been processed no further.', exc_info=True) return page_layout @@ -193,7 +199,8 @@ def process_page(self, img, page_layout: PageLayout): id='{}-l{:03d}'.format(region.id, line_num+1), baseline=baseline, polygon=textline, - heights=heights + heights=heights, + category='text' ) region.lines.append(new_textline) return page_layout @@ -208,6 +215,7 @@ def __init__(self, config, device, config_path=''): self.adjust_heights = config.getboolean('ADJUST_HEIGHTS') self.multi_orientation = config.getboolean('MULTI_ORIENTATION') self.adjust_baselines = config.getboolean('ADJUST_BASELINES') + self.categories = config_get_list(config, key='CATEGORIES', fallback=[]) use_cpu = config.getboolean('USE_CPU') self.device = device if not use_cpu else torch.device("cpu") @@ -227,6 +235,8 @@ def __init__(self, config, device, config_path=''): self.pool = Pool(1) def process_page(self, img, page_layout: PageLayout): + page_layout, page_layout_no_text = helpers.split_page_layout(page_layout) + if self.detect_regions or self.detect_lines: if self.detect_regions: page_layout.regions = [] @@ -248,7 +258,7 @@ def process_page(self, img, page_layout: PageLayout): id = 'r{:03d}_{}'.format(id, rot) else: id = 'r{:03d}'.format(id) - region = RegionLayout(id, polygon) + region = RegionLayout(id, polygon, category='text') regions.append(region) if self.detect_lines: if not self.detect_regions: @@ -283,7 +293,7 @@ def process_page(self, img, page_layout: PageLayout): region = helpers.assign_lines_to_regions(pb_list, ph_list, pt_list, [region])[0] if self.adjust_heights: - for line in page_layout.lines_iterator(): + for line in page_layout.lines_iterator(self.categories): sample_points = helpers.resample_baselines( [line.baseline], num_points=40)[0] line.heights = self.engine.get_heights(maps, ds, sample_points) @@ -293,11 +303,109 @@ def process_page(self, img, page_layout: PageLayout): if self.adjust_baselines: crop_engine = cropper.EngineLineCropper( line_height=32, poly=0, scale=1) - for line in page_layout.lines_iterator(): + for line in page_layout.lines_iterator(self.categories): line.baseline = refine_baseline(line.baseline, line.heights, maps, ds, crop_engine) line.polygon = helpers.baseline_to_textline(line.baseline, line.heights) + page_layout = helpers.merge_page_layouts(page_layout, page_layout_no_text) + return page_layout + + +class LayoutExtractorYolo(object): + def __init__(self, config, device, config_path=''): + try: + import ultralytics # check if ultralytics library is installed + # (ultralytics need different numpy version than some specific version installed on pero-ocr machines) + 
except ImportError:
+            raise ImportError("To use LayoutExtractorYolo, you need to install the ultralytics library. "
+                              "You can do it by running 'pip install ultralytics'.")
+
+        use_cpu = config.getboolean('USE_CPU')
+        self.device = device if not use_cpu else torch.device("cpu")
+        self.categories = config_get_list(config, key='CATEGORIES', fallback=[])
+        self.line_categories = config_get_list(config, key='LINE_CATEGORIES', fallback=[])
+        self.image_size = self.get_image_size(config)
+
+        self.engine = LayoutEngineYolo(
+            model_path=compose_path(config['MODEL_PATH'], config_path),
+            device=self.device,
+            detection_threshold=config.getfloat('DETECTION_THRESHOLD'),
+            image_size=self.image_size
+        )
+
+    def process_page(self, img, page_layout: PageLayout):
+        page_layout_text, page_layout = helpers.split_page_layout(page_layout)
+        page_layout.regions = []
+
+        result = self.engine.detect(img)
+        start_id = self.get_start_id([region.id for region in page_layout_text.regions])
+
+        boxes = result.boxes.data.cpu()
+        for box_id, box in enumerate(boxes):
+            id_str = 'r{:03d}'.format(start_id + box_id)
+
+            x_min, y_min, x_max, y_max, conf, class_id = box.tolist()
+            polygon = np.array([[x_min, y_min], [x_min, y_max], [x_max, y_max], [x_max, y_min], [x_min, y_min]])
+            baseline_y = y_min + (y_max - y_min) / 2
+            baseline = np.array([[x_min, baseline_y], [x_max, baseline_y]])
+            height = np.floor(np.array([baseline_y - y_min, y_max - baseline_y]))
+
+            category = result.names[class_id]
+            if self.categories and category not in self.categories:
+                continue
+
+            region = RegionLayout(id_str, polygon, category=category, detection_confidence=conf)
+
+            if category in self.line_categories:
+                line = TextLine(
+                    id=f'{id_str}-l000',
+                    index=0,
+                    polygon=polygon,
+                    baseline=baseline,
+                    heights=height,
+                    category=category
+                )
+                region.lines.append(line)
+            page_layout.regions.append(region)
+
+        page_layout = helpers.merge_page_layouts(page_layout_text, page_layout)
+        return page_layout
+
+    @staticmethod
+    def get_image_size(config) -> Union[int, Tuple[int, int], None]:
+        if 'IMAGE_SIZE' not in config:
+            return None
+
+        try:
+            image_size = config.getint('IMAGE_SIZE')
+        except ValueError:
+            image_size = config_get_list(config, key='IMAGE_SIZE')
+            if len(image_size) != 2:
+                raise ValueError(f'Invalid image size. Expected int or list of two ints, but got: '
+                                 f'{image_size} of type {type(image_size)}')
+            image_size = image_size[0], image_size[1]
+        return image_size
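`get_image_size` accepts the optional `IMAGE_SIZE` key either as a single int or as a height/width pair. A hedged sketch of both spellings; the section layout follows the `LAYOUT_PARSER` convention used by `PageParser`, the values are illustrative, and the pair form is assumed to use the comma-separated list syntax understood by `config_get_list`:

```python
import configparser

config = configparser.ConfigParser()
config.read_string("""
[LAYOUT_PARSER_1]
METHOD = LAYOUT_YOLO
MODEL_PATH = layout_yolo.pt
DETECTION_THRESHOLD = 0.2
USE_CPU = no
CATEGORIES = text, logo, image
LINE_CATEGORIES = logo
IMAGE_SIZE = 640
""")

section = config['LAYOUT_PARSER_1']
# A single int is returned as-is...
assert section.getint('IMAGE_SIZE') == 640
# ...while a pair such as "IMAGE_SIZE = 640, 480" makes getint() raise
# ValueError and is then parsed into (height, width) via config_get_list.
```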
+    @staticmethod
+    def get_start_id(used_ids: list) -> int:
+        """Get the int from which to start id naming for new regions.
+
+        Expected region id is in format rXXX, where XXX is a number.
+        """
+        used_region_ids = sorted(used_ids)
+        if not used_region_ids:
+            return 0
+
+        ids = []
+        for id in used_region_ids:
+            match = re.match(r'r(\d+)', id)
+            if match is not None:
+                ids.append(int(match.group(1)))
+
+        if not ids:
+            return 0
+
+        last_used_id = sorted(ids)[-1]
+        return last_used_id + 1
+
 
 class LineFilter(object):
     def __init__(self, config, device, config_path):
@@ -305,6 +413,7 @@ def __init__(self, config, device, config_path):
         self.filter_incomplete_pages = config.getboolean('FILTER_INCOMPLETE_PAGES')
         self.filter_pages_with_short_lines = config.getboolean('FILTER_PAGES_WITH_SHORT_LINES')
         self.length_threshold = config.getint('LENGTH_THRESHOLD')
+        self.categories = config_get_list(config, key='CATEGORIES', fallback=[])
 
         use_cpu = config.getboolean('USE_CPU')
         self.device = device if not use_cpu else torch.device("cpu")
@@ -326,7 +435,7 @@ def process_page(self, img, page_layout: PageLayout):
             region.lines = [line for line in region.lines if helpers.check_line_position(line.baseline, page_layout.page_size)]
 
         if self.filter_pages_with_short_lines:
-            b_list = [line.baseline for line in page_layout.lines_iterator()]
+            b_list = [line.baseline for line in page_layout.lines_iterator(self.categories)]
             if helpers.get_max_line_length(b_list) < self.length_threshold:
                 page_layout.regions = []
 
@@ -378,18 +487,20 @@ def __init__(self, config, config_path=''):
         poly = config.getint('INTERP')
         line_scale = config.getfloat('LINE_SCALE')
         line_height = config.getint('LINE_HEIGHT')
+        self.categories = config_get_list(config, key='CATEGORIES', fallback=[])
         self.crop_engine = cropper.EngineLineCropper(
             line_height=line_height, poly=poly, scale=line_scale)
 
     def process_page(self, img, page_layout: PageLayout):
-        for line in page_layout.lines_iterator():
+        for line in page_layout.lines_iterator(self.categories):
             try:
                 line.crop = self.crop_engine.crop(
                     img, line.baseline, line.heights)
             except ValueError:
                 line.crop = np.zeros(
                     (self.crop_engine.line_height, self.crop_engine.line_height, 3))
-                print(f"WARNING: Failed to crop line {line.id} in page {page_layout.id}. Probably contain vertical line. Contanct Olda Kodym to fix this bug!")
+                print(f"WARNING: Failed to crop line {line.id} in page {page_layout.id}. "
+                      f"It probably contains a vertical line. Contact Olda Kodym to fix this bug!")
         return page_layout
 
     def crop_lines(self, img, lines: list):
@@ -400,39 +511,99 @@ def crop_lines(self, img, lines: list):
             except ValueError:
                 line.crop = np.zeros(
                     (self.crop_engine.line_height, self.crop_engine.line_height, 3))
-                print(f"WARNING: Failed to crop line {line.id}. Probably contain vertical line. Contanct Olda Kodym to fix this bug!")
+                print(f"WARNING: Failed to crop line {line.id}. It probably contains a vertical line. "
+                      f"Contact Olda Kodym to fix this bug!")
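The per-category processing above is driven from the pipeline config. Below is a sketch of how paired cropper/OCR sections might be declared under the numeric-suffix naming convention resolved by `PageParser.init_config_sections` further down in this diff; the section keys are real, all values are illustrative:

```python
import configparser

# PageParser sorts the suffixes and runs LINE_CROPPER_i before OCR_i for each i,
# so every OCR engine only sees crops of the line categories it is meant for.
config = configparser.ConfigParser()
config.read_string("""
[PAGE_PARSER]
RUN_LAYOUT_PARSER = yes
RUN_LINE_CROPPER = yes
RUN_OCR = yes

[LINE_CROPPER_1]
INTERP = 2
LINE_SCALE = 1
LINE_HEIGHT = 40
CATEGORIES = text

[OCR_1]
OCR_JSON = ocr_engine_printed.json
USE_CPU = no
CATEGORIES = text

[LINE_CROPPER_2]
INTERP = 2
LINE_SCALE = 1
LINE_HEIGHT = 40
CATEGORIES = header

[OCR_2]
OCR_JSON = ocr_engine_header.json
USE_CPU = no
CATEGORIES = header
""")
```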
" + f"Contanct Olda Kodym to fix this bug!") class PageOCR: + default_confidence = 0.0 + def __init__(self, config, device, config_path=''): json_file = compose_path(config['OCR_JSON'], config_path) use_cpu = config.getboolean('USE_CPU') self.device = device if not use_cpu else torch.device("cpu") + self.categories = config_get_list(config, key='CATEGORIES', fallback=[]) + self.substitute_output = config.getboolean('SUBSTITUTE_OUTPUT', fallback=True) + self.substitute_output_atomic = config.getboolean('SUBSTITUTE_OUTPUT_ATOMIC', fallback=True) + self.update_transcription_by_confidence = config.getboolean( + 'UPDATE_TRANSCRIPTION_BY_CONFIDENCE', fallback=False) if 'METHOD' in config and config['METHOD'] == "pytorch_ocr-transformer": - self.ocr_engine = TransformerEngineLineOCR(json_file, self.device) + self.ocr_engine = TransformerEngineLineOCR(json_file, self.device, + substitute_output_atomic=self.substitute_output_atomic) else: - self.ocr_engine = PytorchEngineLineOCR(json_file, self.device) + self.ocr_engine = PytorchEngineLineOCR(json_file, self.device, + substitute_output_atomic=self.substitute_output_atomic) def process_page(self, img, page_layout: PageLayout): - for line in page_layout.lines_iterator(): + lines_to_process = [] + for line in page_layout.lines_iterator(self.categories): if line.crop is None: raise Exception(f'Missing crop in line {line.id}.') + lines_to_process.append(line) + + transcriptions, logits, logit_coords = self.ocr_engine.process_lines([line.crop for line in lines_to_process]) - transcriptions, logits, logit_coords = self.ocr_engine.process_lines([line.crop for line in page_layout.lines_iterator()]) + for line, line_transcription, line_logits, line_logit_coords in zip(lines_to_process, transcriptions, + logits, logit_coords): + new_line = TextLine(id=line.id, + transcription=line_transcription, + logits=line_logits, + characters=self.ocr_engine.characters, + logit_coords=line_logit_coords) + new_line.transcription_confidence = self.get_line_confidence(new_line) + + if not self.update_transcription_by_confidence: + self.update_line(line, new_line) + else: + if (line.transcription_confidence in [None, self.default_confidence] or + line.transcription_confidence < new_line.transcription_confidence): + self.update_line(line, new_line) + + if self.substitute_output and self.ocr_engine.output_substitution is not None: + self.substitute_transcriptions(lines_to_process) - for line, line_transcription, line_logits, line_logit_coords in zip(page_layout.lines_iterator(), transcriptions, logits, logit_coords): - line.transcription = line_transcription - line.logits = line_logits - line.characters = self.ocr_engine.characters - line.logit_coords = line_logit_coords return page_layout + def substitute_transcriptions(self, lines_to_process: List[TextLine]): + transcriptions_substituted = [] + + for line in lines_to_process: + transcriptions_substituted.append(self.ocr_engine.output_substitution(line.transcription)) + + if transcriptions_substituted[-1] is None: + if self.substitute_output_atomic: + return # scratch everything if the last line couldn't be substituted atomically + else: + transcriptions_substituted[-1] = line.transcription # keep the original transcription + + for line, transcription_substituted in zip(lines_to_process, transcriptions_substituted): + line.transcription = transcription_substituted + + def get_line_confidence(self, line): + if line.transcription: + try: + log_probs = line.get_full_logprobs()[line.logit_coords[0]:line.logit_coords[1]] + 
confidences = get_line_confidence(line, log_probs=log_probs) + return np.quantile(confidences, .50) + except (ValueError, IndexError) as e: + logger.warning(f'PageOCR is unable to get confidence of line {line.id} due to exception: {e}.') + return self.default_confidence + return self.default_confidence + @property def provides_ctc_logits(self): return isinstance(self.ocr_engine, PytorchEngineLineOCR) or isinstance(self.ocr_engine, TransformerEngineLineOCR) + @staticmethod + def update_line(line, new_line): + line.transcription = new_line.transcription + line.logits = new_line.logits + line.characters = new_line.characters + line.logit_coords = new_line.logit_coords + line.transcription_confidence = new_line.transcription_confidence + def get_prob(best_ids, best_probs): last_id = -1 @@ -456,6 +627,9 @@ def get_default_device(): class PageParser(object): def __init__(self, config, device=None, config_path='', ): + if not config.sections(): + raise ValueError('Config file is empty or does not exist.') + self.run_layout_parser = config['PAGE_PARSER'].getboolean('RUN_LAYOUT_PARSER', fallback=False) self.run_line_cropper = config['PAGE_PARSER'].getboolean('RUN_LINE_CROPPER', fallback=False) self.run_ocr = config['PAGE_PARSER'].getboolean('RUN_OCR', fallback=False) @@ -463,22 +637,19 @@ def __init__(self, config, device=None, config_path='', ): self.filter_confident_lines_threshold = config['PAGE_PARSER'].getfloat('FILTER_CONFIDENT_LINES_THRESHOLD', fallback=-1) - self.layout_parser = None - self.line_cropper = None - self.ocr = None - self.decoder = None - self.device = device if device is not None else get_default_device() + self.layout_parsers = {} + self.line_croppers = {} + self.ocrs = {} + self.decoder = None + if self.run_layout_parser: - self.layout_parsers = [] - for i in range(1, 10): - if config.has_section('LAYOUT_PARSER_{}'.format(i)): - self.layout_parsers.append(layout_parser_factory(config, self.device, config_path=config_path, order=i)) + self.layout_parsers = self.init_config_sections(config, config_path, 'LAYOUT_PARSER', layout_parser_factory) if self.run_line_cropper: - self.line_cropper = line_cropper_factory(config, config_path=config_path) + self.line_croppers = self.init_config_sections(config, config_path, 'LINE_CROPPER', line_cropper_factory) if self.run_ocr: - self.ocr = ocr_factory(config, self.device, config_path=config_path) + self.ocrs = self.init_config_sections(config, config_path, 'OCR', ocr_factory) if self.run_decoder: self.decoder = page_decoder_factory(config, self.device, config_path=config_path) @@ -497,10 +668,10 @@ def compute_line_confidence(line, threshold=None): @property def provides_ctc_logits(self): - if not self.ocr: + if not self.ocrs: return False - return self.ocr.provides_ctc_logits + return any(ocr.provides_ctc_logits for ocr in self.ocrs.values()) def update_confidences(self, page_layout): for line in page_layout.lines_iterator(): @@ -514,12 +685,15 @@ def filter_confident_lines(self, page_layout): def process_page(self, image, page_layout): if self.run_layout_parser: - for layout_parser in self.layout_parsers: + for _, layout_parser in sorted(self.layout_parsers.items()): page_layout = layout_parser.process_page(image, page_layout) - if self.run_line_cropper: - page_layout = self.line_cropper.process_page(image, page_layout) - if self.run_ocr: - page_layout = self.ocr.process_page(image, page_layout) + + merged_keys = set(self.line_croppers.keys()) | set(self.ocrs.keys()) + for key in sorted(merged_keys): + if self.run_line_cropper and key 
in self.line_croppers:
+ page_layout = self.line_croppers[key].process_page(image, page_layout)
+ if self.run_ocr and key in self.ocrs:
+ page_layout = self.ocrs[key].process_page(image, page_layout)
 if self.run_decoder:
 page_layout = self.decoder.process_page(page_layout)
@@ -529,3 +703,37 @@ def process_page(self, image, page_layout):
 page_layout = self.filter_confident_lines(page_layout)
 return page_layout
+
+ def init_config_sections(self, config, config_path, section_name, section_factory) -> dict:
+ """Return dict of sections.
+
+ Naming convention: section_name_[0-9]+.
+ Other section names are accepted too, but a warning is logged.
+ e.g. for OCR section: OCR, OCR_0, OCR_42_asdf, OCR_99_last_one..."""
+ sections = {}
+ if section_name in config.sections():
+ sections['-1'] = section_name
+
+ section_names = [config_section for config_section in config.sections()
+ if re.match(rf'{section_name}_(\d+)', config_section)]
+ section_names = sorted(section_names)
+
+ for config_section in section_names:
+ section_id = config_section.replace(section_name + '_', '')
+ try:
+ int(section_id)
+ except ValueError:
+ logger.warning(
+ f'Warning: section name {config_section} does not follow naming convention. '
+ f'Use only {section_name}_[0-9]+.')
+ sections[section_id] = config_section
+
+ if '0' in sections and '-1' in sections:
+ logger.warning(f'Warning: sections {sections["0"]} and {sections["-1"]} are both present. '
+ f'Use only names following {section_name}_[0-9]+ convention.')
+
+ for section_id, section_full_name in sections.items():
+ sections[section_id] = section_factory(config[section_full_name],
+ config_path=config_path, device=self.device)
+
+ return sections
diff --git a/pero_ocr/layout_engines/cnn_layout_engine.py b/pero_ocr/layout_engines/cnn_layout_engine.py
index 51c2515..9455fc4 100644
--- a/pero_ocr/layout_engines/cnn_layout_engine.py
+++ b/pero_ocr/layout_engines/cnn_layout_engine.py
@@ -1,6 +1,7 @@
 import numpy as np
 from copy import deepcopy
 import time
+from typing import Union, Tuple
 import cv2
 from scipy import ndimage
@@ -8,6 +9,7 @@ from scipy.sparse.csgraph import connected_components
 import skimage.draw
 import shapely.geometry as sg
+import torch
 from pero_ocr.layout_engines import layout_helpers as helpers
 from pero_ocr.layout_engines.torch_parsenet import TorchParseNet, TorchOrientationNet
@@ -371,6 +373,37 @@ def make_clusters(self, b_list, h_list, t_list, layout_separator_map, ds):
 else:
 return [0]
+
+class LayoutEngineYolo(object):
+ def __init__(self, model_path, device,
+ image_size: Union[int, Tuple[int, int], None] = None,
+ detection_threshold=0.2):
+ from ultralytics import YOLO # import here, only if needed
+ # (ultralytics needs a different numpy version than the one installed on some pero-ocr machines)
+
+ self.yolo_net = YOLO(model_path).to(device)
+ self.detection_threshold = detection_threshold
+ self.image_size = image_size # height or (height, width)
+
+ def detect(self, image):
+ """Uses yolo_net to find bounding boxes.
+ :param image: input image
+ """
+ if self.image_size is not None:
+ results = self.yolo_net(image,
+ conf=self.detection_threshold,
+ imgsz=self.image_size,
+ verbose=False)
+ else:
+ results = self.yolo_net(image,
+ conf=self.detection_threshold,
+ verbose=False)
+
+ if results is None:
+ raise Exception('Yolo inference returned None.')
+ return results[0]
+
+
 def nonmaxima_suppression(input, element_size=(7, 1)):
 """Vertical non-maxima suppression. 
:param input: input array diff --git a/pero_ocr/layout_engines/layout_helpers.py b/pero_ocr/layout_engines/layout_helpers.py index 0efe182..27f2ed7 100644 --- a/pero_ocr/layout_engines/layout_helpers.py +++ b/pero_ocr/layout_engines/layout_helpers.py @@ -1,6 +1,8 @@ import math import random import warnings +from copy import deepcopy +from typing import Tuple, Optional import numpy as np import cv2 @@ -10,7 +12,7 @@ import shapely.geometry as sg from shapely.ops import unary_union, polygonize -from pero_ocr.core.layout import TextLine +from pero_ocr.core.layout import PageLayout, RegionLayout, TextLine def check_line_position(baseline, page_size, margin=20, min_ratio=0.125): @@ -68,7 +70,8 @@ def assign_lines_to_regions(baseline_list, heights_list, textline_list, regions) id='{}-l{:03d}'.format(region.id, line_id+1), baseline=baseline_intersection, polygon=textline_intersection, - heights=heights + heights=heights, + category='text' ) region.lines.append(new_textline) @@ -408,3 +411,81 @@ def adjust_baselines_to_intensity(baselines, img, tolerance=5): baseline_pts[:, 1] += best_offset new_baselines.append(resample_baselines([baseline_pts], num_points=len(baseline))[0]) return new_baselines + + +def split_page_layout(page_layout: PageLayout) -> Tuple[PageLayout, PageLayout]: + """Split page layout to text and non-text regions.""" + return split_page_layout_by_categories(page_layout, ['text']) + + +def split_page_layout_by_categories(page_layout: PageLayout, categories: list) -> Tuple[PageLayout, PageLayout]: + """Split page_layout into two by region category. Return one page_layout with regions of given categories and one with + regions of other categories. No region category is treated as 'text' for backwards compatibility. + If no categories, return original page_layout and empty page_layout. + ! TextLine categories are ignored here ! + + Example: + split_page_layout_by_categories(page_layout, ['text']) + IN: PageLayout(regions=[ + RegionLayout(id='r001', category='text', lines=[TextLine(id='r001-l001', category='text'), + TextLine(id='r001-l002', category='logo')]), + RegionLayout(id='r002', category='image', lines=[TextLine(id='r002-l001', category='text')])]) + OUT: PageLayout(regions=[ + RegionLayout(id='r001', category='text', lines=[TextLine(id='r001-l001', category='text'), + TextLine(id='r001-l002', category='logo')])]) + PageLayout(regions=[ + RegionLayout(id='r002', category='image', lines=[TextLine(id='r002-l001', category='text')])]) + """ + if not categories: + # if no categories, return original page_layout and empty page_layout + page_layout_no_regions = deepcopy(page_layout) + page_layout_no_regions.regions = [] + return page_layout, page_layout_no_regions + + regions = page_layout.regions + page_layout.regions = [] + + page_layout_positive = page_layout + page_layout_negative = deepcopy(page_layout) + + for region in regions: + region_category = region.category if region.category is not None else 'text' + if region_category in categories: + page_layout_positive.regions.append(region) + else: + page_layout_negative.regions.append(region) + return page_layout_positive, page_layout_negative + + +def merge_page_layouts(page_layout_positive: PageLayout, page_layout_negative: PageLayout) -> PageLayout: + """Merge two page_layouts into one by regions. If same region ID, create new ID (rename line IDs also). 
+ + Example: + IN: PageLayout(regions=[ + RegionLayout(id='r001', lines=[TextLine(id='r001-l001', category='text')])]), + PageLayout(regions=[ + RegionLayout(id='r001', lines=[TextLine(id='r001-l002', category='logo')])]) + OUT: PageLayout(regions=[ + RegionLayout(id='r001', lines=[TextLine(id='r001-l001', category='text')]), + RegionLayout(id='r001-1', lines=[TextLine(id='r001-1-l002', category='logo')])]) + """ + used_region_ids = set(region.id for region in page_layout_positive.regions) + + for region in page_layout_negative.regions: + if region.id not in used_region_ids: + used_region_ids.add(region.id) + page_layout_positive.regions.append(region) + else: + new_region_id = region.id + id_offset = 1 + + # find new unique region ID by adding offset + while new_region_id in used_region_ids: + new_region_id = region.id + '-' + str(id_offset) + id_offset += 1 + + region.replace_id(new_region_id) + used_region_ids.add(new_region_id) + page_layout_positive.regions.append(region) + + return page_layout_positive diff --git a/pero_ocr/layout_engines/naive_sorter.py b/pero_ocr/layout_engines/naive_sorter.py index 36d2357..3cd5932 100644 --- a/pero_ocr/layout_engines/naive_sorter.py +++ b/pero_ocr/layout_engines/naive_sorter.py @@ -6,7 +6,9 @@ from sklearn.cluster import DBSCAN from typing import List +from pero_ocr.utils import config_get_list from pero_ocr.core.layout import PageLayout, RegionLayout +from pero_ocr.layout_engines import layout_helpers as helpers class Region: @@ -42,10 +44,11 @@ class NaiveRegionSorter: def __init__(self, config: SectionProxy, config_path=""): # minimal distance between clusters = page_width / width_denom self.width_denom = config.getint('ImageWidthDenominator', fallback=10) + self.categories = config_get_list(config, key='CATEGORIES', fallback=[]) def process_page(self, image, page_layout: PageLayout): + page_layout, page_layout_ignore = helpers.split_page_layout_by_categories(page_layout, self.categories) regions = [] - for region in page_layout.regions: regions.append(Region(region)) @@ -54,6 +57,7 @@ def process_page(self, image, page_layout: PageLayout): page_layout.regions = [page_layout.regions[idx] for idx in order] + page_layout = helpers.merge_page_layouts(page_layout_ignore, page_layout) return page_layout @staticmethod diff --git a/pero_ocr/layout_engines/smart_sorter.py b/pero_ocr/layout_engines/smart_sorter.py index a99c995..2b671ad 100644 --- a/pero_ocr/layout_engines/smart_sorter.py +++ b/pero_ocr/layout_engines/smart_sorter.py @@ -11,6 +11,8 @@ from typing import List, Dict, Union, Optional from pero_ocr.core.layout import PageLayout, RegionLayout +from pero_ocr.utils import config_get_list +from pero_ocr.layout_engines import layout_helpers as helpers def pairwise(iterable): @@ -275,11 +277,14 @@ class SmartRegionSorter: def __init__(self, config: SectionProxy, config_path=""): # if intersection of two regions is less than given parameter w.r.t. 
both regions, intersection doesn't count
 self.intersect_param = config.getfloat('FakeIntersectionParameter', fallback=0.1)
+ self.categories = config_get_list(config, key='CATEGORIES', fallback=[])
 def process_page(self, image, page_layout: PageLayout):
+ page_layout, page_layout_ignore = helpers.split_page_layout_by_categories(page_layout, self.categories)
 regions = []
 if len(page_layout.regions) < 2:
+ page_layout = helpers.merge_page_layouts(page_layout_ignore, page_layout)
 return page_layout
 rotation = SmartRegionSorter.get_rotation(max(*page_layout.regions, key=lambda reg: len(reg.lines)).lines)
@@ -300,6 +305,7 @@ def process_page(self, image, page_layout: PageLayout):
 page_layout.regions = [page_layout.regions[idx] for idx in region_idxs]
 page_layout = SmartRegionSorter.rotate_page_layout(page_layout, rotation)
+ page_layout = helpers.merge_page_layouts(page_layout_ignore, page_layout)
 return page_layout
 @staticmethod
diff --git a/pero_ocr/music/README.md b/pero_ocr/music/README.md
new file mode 100644
index 0000000..e790b2c
--- /dev/null
+++ b/pero_ocr/music/README.md
@@ -0,0 +1,7 @@
+# README.md
+
+This folder contains scripts for exporting transcribed musical pages.
+The CLI tool with documentation is in `user_scripts/export_music.py`.
+Main functionality is in `music_exporter.py` (class `MusicPageExporter`).
+
+For older versions of these files, see [github.com/vlachvojta/polyphonic-omr-by-sachindae](https://github.com/vlachvojta/polyphonic-omr-by-sachindae/tree/main/reverse_converter)
diff --git a/pero_ocr/music/__init__.py b/pero_ocr/music/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pero_ocr/music/music_exporter.py b/pero_ocr/music/music_exporter.py
new file mode 100644
index 0000000..0b4438c
--- /dev/null
+++ b/pero_ocr/music/music_exporter.py
@@ -0,0 +1,335 @@
+from __future__ import annotations
+import sys
+import os
+import re
+import logging
+from typing import List
+
+import music21 as music
+
+from pero_ocr.core.layout import PageLayout, RegionLayout, TextLine
+from pero_ocr.layout_engines.layout_helpers import split_page_layout_by_categories
+from pero_ocr.music.music_structures import Measure
+from pero_ocr.music.output_translator import OutputTranslator as Translator
+
+
+class MusicPageExporter:
+ """Takes PageLayout XML exported from pero-ocr with transcriptions and reconstructs a page of musical notation. 
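+
+ A minimal illustrative use: MusicPageExporter(export_musicxml=True)(page_layout=page_layout)
+ writes one MusicXML file for the page into output_folder.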
+
+ For CLI usage see user_scripts/export_music.py
+ """
+
+ def __init__(self, input_xml_path: str = '', input_transcription_files: List[str] = None,
+ translator_path: str = None, output_folder: str = 'output_page', export_midi: bool = False,
+ export_musicxml: bool = False, categories: List = None, verbose: bool = False):
+ self.translator_path = translator_path
+ if verbose:
+ logging.basicConfig(level=logging.DEBUG, format='[%(levelname)-s] \t- %(message)s')
+ else:
+ logging.basicConfig(level=logging.INFO, format='[%(levelname)-s]\t- %(message)s')
+ self.verbose = verbose
+
+ if input_xml_path and not os.path.isfile(input_xml_path):
+ logging.error('No input file was found at the given path.')
+ self.input_xml_path = input_xml_path
+
+ self.input_transcription_files = input_transcription_files if input_transcription_files else []
+
+ if not os.path.exists(output_folder):
+ os.makedirs(output_folder)
+ self.output_folder = output_folder
+ self.export_midi = export_midi
+ self.export_musicxml = export_musicxml
+
+ self.translator = Translator(filename=self.translator_path) if translator_path else None
+ self.categories = categories if categories else ['Notový zápis']
+
+ def __call__(self, page_layout=None) -> None:
+ if self.input_transcription_files:
+ MusicLinesExporter(input_files=self.input_transcription_files, output_folder=self.output_folder,
+ translator=self.translator, verbose=self.verbose)()
+ if page_layout:
+ self.process_page(page_layout)
+
+ if self.input_xml_path:
+ input_page_layout = PageLayout(file=self.input_xml_path)
+ self.export_page_layout(input_page_layout)
+
+ def process_page(self, page_layout: PageLayout) -> None:
+ self.export_page_layout(page_layout, page_layout.id)
+
+ def export_page_layout(self, page_layout: PageLayout, file_id: str = None) -> None:
+ if self.export_musicxml or self.export_midi:
+ page_layout, _ = split_page_layout_by_categories(page_layout, self.categories)
+ parts = self.regions_to_parts(page_layout.regions)
+ if not parts:
+ return
+
+ music_parts = []
+ for part in parts:
+ music_parts.append(part.encode_to_music21())
+
+ # Finalize score creation
+ metadata = music.metadata.Metadata()
+ metadata.title = metadata.composer = ''
+ score = music.stream.Score([metadata] + music_parts)
+
+ if self.export_musicxml:
+ output_file = self.get_output_file(file_id, extension='musicxml')
+ xml = music21_to_musicxml(score)
+ write_to_file(output_file, xml)
+
+ if self.export_midi:
+ self.export_to_midi(score, parts, file_id)
+
+ def get_output_file(self, file_id: str = None, extension: str = 'musicxml') -> str:
+ base = self.get_output_file_base(file_id)
+ return f'{base}.{extension}'
+
+ def get_output_file_base(self, file_id: str = None) -> str:
+ if not file_id:
+ file_id = os.path.basename(self.input_xml_path)
+ if not file_id:
+ file_id = 'output'
+ name, *_ = re.split(r'\.', file_id)
+ return os.path.join(self.output_folder, f'{name}')
+
+ def export_to_midi(self, score, parts, file_id: str = None):
+ # Export whole score to midi
+ output_file = self.get_output_file(file_id, extension='mid')
+ score.write("midi", output_file)
+
+ for part in parts:
+ base = self.get_output_file_base(file_id)
+ part.export_to_midi(base)
+
+ def regions_to_parts(self, regions: List[RegionLayout]) -> List[Part]:
+ """Takes a list of regions and splits them to parts."""
+ max_parts = max(
+ [len(region.get_lines_of_category(self.categories)) for region in regions],
+ default=0
+ )
+ if max_parts == 0:
+ logging.warning('No music lines found in page.')
+ return 
[] + + parts = [Part(self.translator) for _ in range(max_parts)] + for region in regions: + for part, line in zip(parts, region.get_lines_of_category(self.categories)): + part.add_textline(line) + + return parts + + +class MusicLinesExporter: + """Takes text files with transcriptions as individual lines and exports musicxml file for each one""" + def __init__(self, input_files: List[str] = None, output_folder: str = 'output_musicxml', + translator: Translator = None, verbose: bool = False): + self.translator = translator + self.output_folder = output_folder + + if verbose: + logging.basicConfig(level=logging.DEBUG, format='[%(levelname)-s] \t- %(message)s') + else: + logging.basicConfig(level=logging.INFO, format='[%(levelname)-s]\t- %(message)s') + + logging.debug('Hello World! (from ReverseConverter)') + + self.input_files = MusicLinesExporter.get_input_files(input_files) + MusicLinesExporter.prepare_output_folder(output_folder) + + def __call__(self): + if not self.input_files: + logging.error('No input files provided. Exiting...') + sys.exit(1) + + # For every file, convert it to MusicXML + for input_file_name in self.input_files: + logging.info(f'Reading file {input_file_name}') + lines = MusicLinesExporter.read_file_lines(input_file_name) + + for i, line in enumerate(lines): + match = re.fullmatch(r'([a-zA-Z0-9_\-]+)[a-zA-Z0-9_\.]+\s+([0-9]+\s+)?\"([\S\s]+)\"', line) + + if not match: + logging.debug(f'NOT MATCHING PATTERN. Skipping line {i} in file {input_file_name}: ' + f'({line[:min(50, len(line))]}...)') + continue + + stave_id = match.group(1) + labels = match.group(3) + if self.translator is not None: + labels = self.translator.translate_line(labels) + output_file_name = os.path.join(self.output_folder, f'{stave_id}.musicxml') + + parsed_labels = semantic_line_to_music21_score(labels) + if not isinstance(parsed_labels, music.stream.Stream): + logging.error(f'Labels could not be parsed. 
Skipping line {i} in file {input_file_name}: ' + f'({line[:min(50, len(line))]}...)') + continue + + logging.info(f'Parsing successfully completed.') + # parsed_labels.show() # Show parsed labels in some visual program (MuseScore by default) + + xml = music21_to_musicxml(parsed_labels) + write_to_file(output_file_name, xml) + + @staticmethod + def prepare_output_folder(output_folder: str): + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + @staticmethod + def get_input_files(input_files: List[str] = None): + existing_files = [] + + if not input_files: + return [] + + for input_file in input_files: + if os.path.isfile(input_file): + existing_files.append(input_file) + + return existing_files + + @staticmethod + def read_file_lines(input_file: str) -> List[str]: + with open(input_file, 'r', encoding='utf-8') as f: + lines = f.read().splitlines() + + if not lines: + logging.warning(f'File {input_file} is empty!') + + return [line for line in lines if line] + + +class Part: + """Represent musical part (part of notation for one instrument/section)""" + + def __init__(self, translator: Translator = None): + self.translator = translator + + self.repr_music21 = music.stream.Part([music.instrument.Piano()]) # Default instrument is piano + self.labels: List[str] = [] + self.textlines: List[TextLineWrapper] = [] + self.measures: List[Measure] = [] # List of measures in internal representation, NOT music21 + + def add_textline(self, line: TextLine) -> None: + labels = line.transcription + if self.translator is not None: + labels = self.translator.translate_line(labels) + self.labels.append(labels) + + new_measures = parse_semantic_to_measures(labels) + + # Delete first clef symbol of first measure in line if same as last clef in previous line + if len(self.measures) and new_measures[0].get_start_clef() == self.measures[-1].last_clef: + new_measures[0].delete_clef_symbol() + + new_measures_encoded = encode_measures(new_measures, len(self.measures) + 1) + new_measures_encoded_without_measure_ids = encode_measures(new_measures) + + self.measures += new_measures + self.repr_music21.append(new_measures_encoded) + + self.textlines.append(TextLineWrapper(line, new_measures_encoded_without_measure_ids)) + + def encode_to_music21(self) -> music.stream.Part: + if self.repr_music21 is None: + logging.info('Part empty') + + return self.repr_music21 + + def export_to_midi(self, file_base: str): + for text_line in self.textlines: + text_line.export_midi(file_base) + + +class TextLineWrapper: + """Class to wrap one TextLine for easier export etc.""" + def __init__(self, text_line: TextLine, measures: List[music.stream.Measure]): + self.text_line = text_line + self.repr_music21 = music.stream.Part([music.instrument.Piano()] + measures) + + def export_midi(self, file_base: str = 'out'): + # do not export line, if it has no notes + if not any([note for note in self.repr_music21.flatten().notes]): + return + + filename = f'{file_base}_{self.text_line.id}.mid' + + xml = music21_to_musicxml(self.repr_music21) + parsed_xml = music.converter.parse(xml) + parsed_xml.write('mid', filename) + + +def parse_semantic_to_measures(labels: str) -> List[Measure]: + """Convert line of semantic labels to list of measures. + + Args: + labels (str): one line of labels in semantic format without any prefixes. 
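+
+ Illustrative input (token shapes follow music_symbols.py, the exact vocabulary is model-dependent):
+ 'clef-G2 + keySignature-DM + note-D4_quarter note-F4_quarter + barline + rest-half + barline'
+ splits on 'barline' into two measures: one with a clef, a key and a two-note chord group,
+ and one with a half rest.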
+ """ + labels = labels.strip('"') + + measures_labels = re.split(r'barline', labels) + + stripped_measures_labels = [] + for measure_label in measures_labels: + stripped = measure_label.strip().strip('+').strip() + if stripped: + stripped_measures_labels.append(stripped) + + measures = [Measure(measure_label) for measure_label in stripped_measures_labels if measure_label] + + previous_measure_key = music.key.Key() # C Major as a default key (without accidentals) + for measure in measures: + previous_measure_key = measure.get_key(previous_measure_key) + + measures[0].new_system = True + + previous_measure_last_clef = measures[0].get_last_clef() + for measure in measures[1:]: + previous_measure_last_clef = measure.get_last_clef(previous_measure_last_clef) + + return measures + + +def encode_measures(measures: List, measure_id_start_from: int = 1) -> List[Measure]: + """Get list of measures and encode them to music21 encoded measures.""" + logging.debug('-------------------------------- -------------- --------------------------------') + logging.debug('-------------------------------- START ENCODING --------------------------------') + logging.debug('-------------------------------- -------------- --------------------------------') + + measures_encoded = [] + for measure_id, measure in enumerate(measures): + measures_encoded.append(measure.encode_to_music21()) + measures_encoded[-1].number = measure_id_start_from + measure_id + + return measures_encoded + + +def semantic_line_to_music21_score(labels: str) -> music.stream.Score: + """Get semantic line of labels, Return stream encoded in music21 score format.""" + measures = parse_semantic_to_measures(labels) + measures_encoded = encode_measures(measures) + + # stream = music.stream.Score(music.stream.Part([music.instrument.Piano()] + measures_encoded)) + metadata = music.metadata.Metadata() + metadata.title = metadata.composer = '' + stream = music.stream.Score([metadata, music.stream.Part([music.instrument.Piano()] + measures_encoded)]) + + return stream + + +def music21_to_musicxml(music_object): + out_bytes = music.musicxml.m21ToXml.GeneralObjectExporter(music_object).parse() + out_str = out_bytes.decode('utf-8') + return out_str.strip() + + +def write_to_file(output_file_name, xml): + with open(output_file_name, 'w', encoding='utf-8') as f: + f.write(xml) + + logging.info(f'File {output_file_name} successfully written.') diff --git a/pero_ocr/music/music_structures.py b/pero_ocr/music/music_structures.py new file mode 100644 index 0000000..0146a20 --- /dev/null +++ b/pero_ocr/music/music_structures.py @@ -0,0 +1,516 @@ +#!/usr/bin/python3.10 +"""Script for converting semantic sequential representation of music labels (produced by the model) to music21 stream +usable by music21 library to export to other formats. 
+ +Author: Vojtěch Vlach +Contact: xvlach22@vutbr.cz +""" + +from __future__ import annotations +import re +from enum import Enum +import logging +from typing import Optional, List + +import numpy as np +import music21 as music +from pero_ocr.music.music_symbols import Symbol, SymbolType, AlteredPitches, LENGTH_TO_SYMBOL + + +class Measure: + _is_polyphonic = None + keysignature = None + repr = None + new_system = False + start_clef = None + last_clef = None + + def __init__(self, labels: str): + """Takes labels corresponding to a single measure.""" + self.labels = labels + + label_groups = re.split(r'\+', self.labels) + stripped_label_groups = [] + for measure_label in label_groups: + stripped = measure_label.strip().strip('+').strip() + if stripped: + stripped_label_groups.append(stripped) + + self.symbol_groups = [SymbolGroup(label_group) for label_group in stripped_label_groups] + self._is_polyphonic = self.is_polyphonic + + def __str__(self): + label_groups_str = '\n'.join([str(group) for group in self.symbol_groups]) + poly = 'polyphonic' if self.is_polyphonic else 'monophonic' + return (f'MEASURE: ({poly}) \n' + f'key signature: {self.keysignature}\n' + f'labels: {self.labels}\n' + f'{label_groups_str}') + + @property + def is_polyphonic(self) -> bool: + """Returns True if there are more than 1 notes in the same label_group with different lengths.""" + if self._is_polyphonic is not None: + return self._is_polyphonic + + self._is_polyphonic = any(group.type == SymbolGroupType.TUPLE for group in self.symbol_groups) + return self._is_polyphonic + + def get_key(self, previous_measure_key: music.key.Key) -> music.key.Key: + """Returns the key of the measure. + + Args: + previous_measure_key (music.key.Key): key of the previous measure. + + Returns: + music.key.Key: key of the current measure. + """ + if self.keysignature is not None: + return self.keysignature + + for symbol_group in self.symbol_groups: + key = symbol_group.get_key() + if key is not None: + self.set_key(key) + break + else: + self.set_key(previous_measure_key) + return self.keysignature + + def get_start_clef(self) -> Optional[music.clef.Clef]: + if self.start_clef is not None: + return self.start_clef + else: + return self.symbol_groups[0].get_clef() + + def get_last_clef(self, previous_measure_last_clef: music.clef.Clef = music.clef.TrebleClef + ) -> Optional[music.clef.Clef]: + if self.last_clef is not None: + return self.last_clef + + self.last_clef = previous_measure_last_clef + + for group in self.symbol_groups: + new_clef = group.get_clef() + if new_clef is not None: + self.last_clef = new_clef + + return self.last_clef + + def delete_clef_symbol(self, position: int = 0) -> None: + self.symbol_groups.pop(position) + + def set_key(self, key: music.key.Key): + """Sets the key of the measure. Send key to all symbols groups to represent notes in real height. + + Args: + key (music.key.Key): key of the current measure. + """ + self.keysignature = key + + altered_pitches = AlteredPitches(key) + for symbol_group in self.symbol_groups: + symbol_group.set_key(altered_pitches) + + def encode_to_music21(self) -> music.stream.Measure: + """Encodes the measure to music21 format. + + Returns: + music.stream.Measure: music21 representation of the measure. 
+ """ + if self.repr is not None: + return self.repr + + self.repr = music.stream.Measure() + if not self._is_polyphonic: + for symbol_group in self.symbol_groups: + encoded_group = symbol_group.encode_to_music21_monophonic() + if encoded_group is not None: + self.repr.append(encoded_group) + else: + self.repr = self.encode_to_music21_polyphonic() + + if self.new_system: + self.repr.insert(0, music.layout.SystemLayout(isNew=True)) + + logging.debug('Current measure:') + logging.debug(str(self)) + + logging.debug('Current measure ENCODED:') + return self.repr + + def encode_to_music21_polyphonic(self) -> music.stream.Measure: + """Encodes POLYPHONIC MEASURE to music21 format. + + Returns: + music.stream.Measure: music21 representation of the measure. + """ + voice_count = max([symbol_group.get_voice_count() for symbol_group in self.symbol_groups]) + voices = [Voice() for _ in range(voice_count)] + logging.debug('-------------------------------- NEW MEASURE --------------------------------') + logging.debug(f'voice_count: {voice_count}') + + zero_length_symbol_groups = Measure.find_zero_length_symbol_groups(self.symbol_groups) + remaining_symbol_groups = self.symbol_groups[len(zero_length_symbol_groups):] + + mono_start_symbol_groups = Measure.get_mono_start_symbol_groups(remaining_symbol_groups) + voices = Measure.create_mono_start(voices, mono_start_symbol_groups) + remaining_symbol_groups = remaining_symbol_groups[len(mono_start_symbol_groups):] + + # Groups to voices + for symbol_group in remaining_symbol_groups: + logging.debug('------------ NEW symbol_group ------------------------') + groups_to_add = symbol_group.get_groups_to_add() + shortest_voice_ids = Measure.pad_voices_to_n_shortest(voices, len(groups_to_add)) + + logging.debug( + f'Zipping {len(groups_to_add)} symbol groups to shortest voices ({len(shortest_voice_ids)}): {shortest_voice_ids}') + for voice_id, group in zip(shortest_voice_ids, groups_to_add): + logging.debug(f'Voice ({voice_id}) adding: {group}') + voices[voice_id].add_symbol_group(group) + + for voice_id, voice in enumerate(voices): + logging.debug(f'voice ({voice_id}) len: {voice.length}') + + zero_length_encoded = [group.encode_to_music21_monophonic() for group in zero_length_symbol_groups] + zero_length_encoded = [group for group in zero_length_encoded if group is not None] + voices_repr = [voice.encode_to_music21_monophonic() for voice in voices] + voices_repr = [voice for voice in voices_repr if voice is not None] + + if len(voices_repr) == 0: + logging.warning('No voices in the measure, returning empty measure.') + return music.stream.Measure(zero_length_encoded) + + return music.stream.Measure(zero_length_encoded + voices_repr) + + @staticmethod + def find_shortest_voices(voices: List, ignore: List = None) -> List[int]: + """Go through all voices and find the one with the current shortest duration. + + Args: + voices (List): List of voices. + ignore (List): indexes of voices to ignore. + + Returns: + List: indexes of voices with the current shortest duration. 
+ """ + if ignore is None: + ignore = [] + + shortest_duration = 1_000_000 + shortest_voice_ids = [0] + for voice_id, voice in enumerate(voices): + if voice_id in ignore: + continue + if voice.length < shortest_duration: + shortest_duration = voice.length + shortest_voice_ids = [voice_id] + elif voice.length == shortest_duration: + shortest_voice_ids.append(voice_id) + + return shortest_voice_ids + + @staticmethod + def find_zero_length_symbol_groups(symbol_groups: List[SymbolGroup]) -> List[SymbolGroup]: + """Returns a List of zero-length symbol groups AT THE BEGGING OF THE MEASURE.""" + zero_length_symbol_groups = [] + for symbol_group in symbol_groups: + if symbol_group.type == SymbolGroupType.TUPLE or symbol_group.length > 0: + break + zero_length_symbol_groups.append(symbol_group) + return zero_length_symbol_groups + + @staticmethod + def pad_voices_to_n_shortest(voices: List[Voice], n: int = 1) -> List[int]: + """Pads voices (starting from the shortest) so there is n shortest voices with same length. + + Args: + voices (List): List of voices. + n (int): number of desired shortest voices. + + Returns: + List: List of voice IDS with the current shortest duration. + """ + shortest_voice_ids = Measure.find_shortest_voices(voices) + + while n > len(shortest_voice_ids): + logging.debug(f'Found {len(shortest_voice_ids)} shortest voices, desired voices: {n}.') + second_shortest_voice_ids = Measure.find_shortest_voices(voices, ignore=shortest_voice_ids) + second_shortest_len = voices[second_shortest_voice_ids[0]].length + for voice_id in shortest_voice_ids: + desired_padding_length = second_shortest_len - voices[voice_id].length + voices[voice_id].add_padding(desired_padding_length) + + shortest_voice_ids = Measure.find_shortest_voices(voices) + + return shortest_voice_ids + + @staticmethod + def get_mono_start_symbol_groups(symbol_groups: List[SymbolGroup]) -> List[SymbolGroup]: + """Get a List of monophonic symbol groups AT THE BEGINNING OF THE MEASURE. + + Returns: + List: List of monophonic symbol groups AT THE BEGINNING OF THE MEASURE + """ + mono_start_symbol_groups = [] + for symbol_group in symbol_groups: + if symbol_group.type == SymbolGroupType.TUPLE: + break + mono_start_symbol_groups.append(symbol_group) + return mono_start_symbol_groups + + @staticmethod + def create_mono_start(voices: List[Voice], mono_start_symbol_groups: List[SymbolGroup]) -> List[Voice]: + """Create monophonic start of measure in the first voice and add padding to the others. + + Args: + voices (List[Voices]): List of voices + mono_start_symbol_groups: List of monophonic symbol groups AT THE BEGINNING OF MEASURE. + + Returns: + List[Voice]: List of voices + """ + padding_length = 0 + for symbol_group in mono_start_symbol_groups: + voices[0].add_symbol_group(symbol_group) + padding_length += symbol_group.length + + for voice in voices[1:]: + voice.add_padding(padding_length) + + return voices + + +class SymbolGroupType(Enum): + SYMBOL = 0 + CHORD = 1 + TUPLE = 2 + EMPTY = 3 + UNKNOWN = 99 + + +class SymbolGroup: + """Represents one label group in a measure. Consisting of 1 to n labels/symbols.""" + tuple_data: List = None # Tuple data consists of a List of symbol groups where symbols have same lengths. + length: float = None # Length of the symbol group in quarter notes. 
+ + def __init__(self, labels: str): + self.labels = labels + self.type = SymbolGroupType.UNKNOWN + + label_group_parsed = re.split(r'\s', self.labels.strip()) + self.symbols = [Symbol(label_group) for label_group in label_group_parsed if label_group] + + self.type = self.get_type() + if self.type == SymbolGroupType.TUPLE: + self.create_tuple_data() + + def __str__(self): + if not self.type == SymbolGroupType.TUPLE: + symbols_str = '\n'.join([str(symbol) for symbol in self.symbols]) + return (f'\t({self.type}) {self.labels} (len: {self.length}) =>\n' + f'{symbols_str}') + out = [] + for group in self.tuple_data: + out.append(str(group)) + out = '\n'.join(out) + + return (f'\tTUPLE BEGIN:\n' + f'{out}\n' + f'\tTUPLE END') + + def get_type(self): + if len(self.symbols) == 0: + logging.warning(f'No symbols found in label group: {self.labels}') + return SymbolGroupType.UNKNOWN + elif len(self.symbols) == 1: + self.length = self.symbols[0].get_length() + return SymbolGroupType.SYMBOL + else: + same_length_notes = all((symbol.get_length() == self.symbols[0].get_length() and + symbol.type in [SymbolType.NOTE, SymbolType.GRACENOTE]) + for symbol in self.symbols) + if same_length_notes: + self.length = self.symbols[0].get_length() + return SymbolGroupType.CHORD + else: + return SymbolGroupType.TUPLE + + def get_key(self) -> Optional[music.key.Key]: + """Go through all labels and find key signature or return None. + + Returns: + music.key.Key: key signature of the label group or None. + """ + if not self.type == SymbolGroupType.SYMBOL: + return None + + for symbol in self.symbols: + if symbol.type == SymbolType.KEY_SIGNATURE: + return symbol.repr + + return None + + def set_key(self, altered_pitches: AlteredPitches): + if self.type == SymbolGroupType.TUPLE: + for group in self.tuple_data: + group.set_key(altered_pitches) + else: + for symbol in self.symbols: + symbol.set_key(altered_pitches) + + def encode_to_music21_monophonic(self) -> Optional[music.stream.Stream]: + """Encodes the label group to music21 format. + + Returns: + music.object: music21 representation of the label group. + """ + if self.type == SymbolGroupType.SYMBOL: + return self.symbols[0].repr + elif self.type == SymbolGroupType.CHORD: + notes = [symbol.repr for symbol in self.symbols] + logging.debug(f'notes: {notes}') + return music.chord.Chord(notes) + # return music.stream.Stream(music.chord.Chord(notes)) + elif self.type == SymbolGroupType.EMPTY: + return music.stream.Stream() + elif self.type == SymbolGroupType.TUPLE: + logging.info(f'Tuple label group not supported yet, returning empty stream.') + return music.stream.Stream() + else: + return music.stream.Stream() + + def create_tuple_data(self): + """Create tuple data for the label group. + + Tuple data consists of a List of symbol groups where symbols have same lengths. 
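+
+ For example, 'note-C4_half note-E4_quarter note-G4_quarter' groups into [C4_half] and [E4_quarter, G4_quarter].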
+ """ + # logging.debug(f'Creating tuple data for label group: {self.labels}') + list_of_groups = [[self.symbols[0]]] + for symbol in self.symbols[1:]: + if symbol.type == SymbolType.REST: + list_of_groups.append([symbol]) + continue + symbol_length = symbol.get_length() + for group in list_of_groups: + # if symbol_length == group[0].get_length() and symbol.type in [SymbolType.NOTE, SymbolType.GRACENOTE]: + if group[0].type in [SymbolType.NOTE, SymbolType.GRACENOTE] and symbol_length == group[0].get_length(): + group.append(symbol) + break + else: + list_of_groups.append([symbol]) + + # logging.debug(list_of_groups) + + self.tuple_data = [] + for group in list_of_groups: + labels = [symbol.label for symbol in group] + labels = ' '.join(labels) + self.tuple_data.append(SymbolGroup(labels)) + + def get_voice_count(self): + """Returns the number of voices in the label group (count of groups in tuple group) + + Returns: + int: number of voices in the label group. + """ + if not self.type == SymbolGroupType.TUPLE: + return 1 + return len(self.tuple_data) + + def get_groups_to_add(self): + """Returns list of symbol groups. Either self in list of symbol groups in tuple data.""" + if self.type == SymbolGroupType.TUPLE: + groups_to_add = self.tuple_data.copy() + groups_to_add.reverse() + # return groups_to_add.reverse() + else: + groups_to_add = [self] + # return [self] + + logging.debug(f'groups_to_add:') + for group in groups_to_add: + logging.debug(f'{group}') + + return groups_to_add + + def get_clef(self) -> Optional[music.clef.Clef]: + if self.type == SymbolGroupType.SYMBOL and self.symbols[0].type == SymbolType.CLEF: + return self.symbols[0].repr + return None + + +class Voice: + """Internal representation of voice (list of symbol groups symbolizing one musical line).""" + length: float = 0.0 # Accumulated length of symbol groups (in quarter notes). + symbol_groups: List = [] + repr = None + + def __init__(self): + self.length = 0.0 + self.symbol_groups = [] + self.repr = None + + def __str__(self): + out = [] + for group in self.symbol_groups: + out.append(str(group)) + out = '\n'.join(out) + + return (f'\tVOICE BEGIN:\n' + f'{out}\n' + f'\tVOICE END') + + def add_symbol_group(self, symbol_group: SymbolGroup) -> None: + if symbol_group.type == SymbolGroupType.TUPLE: + logging.warning(f'Can NOT add symbol group of type TUPLE to a voice.') + return + self.symbol_groups.append(symbol_group) + self.length += symbol_group.length + self.repr = None + + def encode_to_music21_monophonic(self) -> Optional[music.stream.Voice]: + """Encodes the voice to music21 format. + + Returns: + music.Voice: music21 representation of the voice. + """ + if self.repr is not None: + return self.repr + + if len(self.symbol_groups) == 0: + return music.stream.Voice() + + self.repr = music.stream.Voice() + for group in self.symbol_groups: + encoded_group = group.encode_to_music21_monophonic() + if encoded_group is not None: + self.repr.append(encoded_group) + return self.repr + + def add_padding(self, padding_length: float) -> None: + """Add padding symbols (rests) to the voice until it reaches the desired length. + + Args: + padding_length (float): desired length of the padding in quarter notes. 
+ """ + lengths = np.array(list(LENGTH_TO_SYMBOL.keys())) + min_length = lengths.min() + + while padding_length > 0: + if padding_length in LENGTH_TO_SYMBOL: + length_label = LENGTH_TO_SYMBOL[padding_length] + logging.debug(f'Completing padding with padding length {padding_length} to the voice.') + self.add_symbol_group(SymbolGroup(f'rest-{length_label}')) + padding_length -= padding_length + elif padding_length < min_length: + logging.error(f'Padding length {padding_length} is smaller than the minimum length {min}, breaking.') + break + else: + # Step is the biggest number lower than desired padding length. + step = lengths[lengths < padding_length].max() + logging.debug(f'Adding padding STEP {step} to the voice.') + + length_label = LENGTH_TO_SYMBOL[step] + self.add_symbol_group(SymbolGroup(f'rest-{length_label}')) + padding_length -= step diff --git a/pero_ocr/music/music_symbols.py b/pero_ocr/music/music_symbols.py new file mode 100644 index 0000000..84118d5 --- /dev/null +++ b/pero_ocr/music/music_symbols.py @@ -0,0 +1,389 @@ +#!/usr/bin/python3.10 +"""Script containing classes for internal representation of Semantic labels. +Symbol class is default symbol for parsing and returning music21 representation. +Other classes are internal representations of different symbols. + +Primarily used for compatibility with music21 library. + +Author: Vojtěch Vlach +Contact: xvlach22@vutbr.cz +""" + +from __future__ import annotations +import logging +from enum import Enum +import re +from typing import Optional + +import music21 as music + + +class SymbolType(Enum): + CLEF = 0 + GRACENOTE = 1 + KEY_SIGNATURE = 2 + MULTI_REST = 3 + NOTE = 4 + REST = 5 + TIE = 6 + TIME_SIGNATURE = 7 + UNKNOWN = 99 + + +class Symbol: + """Represents one label in a label group.""" + + def __init__(self, label: str): + self.label = label + self.type, self.repr = Symbol.label_to_symbol(label) + self.length = self.get_length() + + def __str__(self): + return f'\t\t\t({self.type}) {self.repr}' + + def get_length(self) -> float: + """Returns the length of the symbol in quarter notes. + + (half note: 2 quarter notes, eighth note: 0.5 quarter notes, ...) + If the symbol does not have musical length, returns 0. + """ + if self.type in [SymbolType.REST, SymbolType.NOTE, SymbolType.GRACENOTE]: + return self.repr.duration.quarterLength + else: + return 0 + + def set_key(self, altered_pitches: AlteredPitches): + if self.type in [SymbolType.NOTE, SymbolType.GRACENOTE]: + self.repr = self.repr.get_real_height(altered_pitches) + + @staticmethod + def label_to_symbol(label: str): # -> (SymbolType, music.object): + """Converts one label to music21 format. 
+
+ Args:
+ label (str): one symbol in semantic format as string
+ """
+ if label.startswith("clef-"):
+ label = label[len('clef-'):]
+ return SymbolType.CLEF, Symbol.clef_to_symbol(label)
+ elif label.startswith("gracenote-"):
+ label = label[len('gracenote-'):]
+ return SymbolType.GRACENOTE, Symbol.note_to_symbol(label, gracenote=True)
+ elif label.startswith("keySignature-"):
+ label = label[len('keySignature-'):]
+ return SymbolType.KEY_SIGNATURE, Symbol.keysignature_to_symbol(label)
+ elif label.startswith("multirest-"):
+ label = label[len('multirest-'):]
+ return SymbolType.MULTI_REST, Symbol.multirest_to_symbol(label)
+ elif label.startswith("note-"):
+ label = label[len('note-'):]
+ return SymbolType.NOTE, Symbol.note_to_symbol(label)
+ elif label.startswith("rest-"):
+ label = label[len('rest-'):]
+ return SymbolType.REST, Symbol.rest_to_symbol(label)
+ elif label.startswith("tie"):
+ label = label[len('tie'):]
+ return SymbolType.TIE, Symbol.tie_to_symbol(label)
+ elif label.startswith("timeSignature-"):
+ label = label[len('timeSignature-'):]
+ return SymbolType.TIME_SIGNATURE, Symbol.timesignature_to_symbol(label)
+
+ logging.info(f'Unknown label: {label}, returning None.')
+ return SymbolType.UNKNOWN, None
+
+ @staticmethod
+ def clef_to_symbol(clef) -> music.clef:
+ """Converts one clef label to music21 format.
+
+ Args:
+ clef (str): one symbol in semantic format as string
+
+ Returns:
+ music.clef: one clef in music21 format
+ """
+ if len(clef) != 2:
+ logging.info(f'Unknown clef label: {clef}, returning default clef.')
+ return music.clef.Clef()
+
+ return music.clef.clefFromString(clef)
+
+ @staticmethod
+ def keysignature_to_symbol(keysignature) -> music.key.Key:
+ """Converts one key signature label to music21 format.
+
+ Args:
+ keysignature (str): one symbol in semantic format as string
+
+ Returns:
+ music.key.Key: one key in music21 format
+ """
+ if not keysignature:
+ logging.info(f'Unknown key signature label: {keysignature}, returning default key.')
+ return music.key.Key()
+
+ return music.key.Key(keysignature)
+
+ @staticmethod
+ def multirest_to_symbol(multirest: str) -> music.note.Rest:
+ """Converts one multi rest label to internal MultiRest format.
+
+ Args:
+ multirest (str): one symbol in semantic format as string
+
+ Returns:
+ music.note.Rest: one rest in music21 format
+ """
+ rest = music.note.Rest()
+ rest.duration = label_to_length('whole') # default duration, because multirest is not implemented in music21
+ return rest
+
+ @staticmethod
+ def note_to_symbol(note, gracenote: bool = False) -> Note:
+ """Converts one note label to internal note format.
+
+ Args:
+ note (str): one symbol in semantic format as string
+ gracenote (bool, optional): if True, returns grace note. Defaults to False.
+
+ Returns:
+ Note: one note in internal format (resolved to music21 once the key is known)
+ """
+ def return_default_note() -> Note:
+ logging.info(f'Unknown note label: {note}, returning default note.')
+ return Note(music.duration.Duration(1), 'C4', gracenote=gracenote)
+
+ if not note:
+ return return_default_note()
+
+ note, fermata = Symbol.check_fermata(note)
+
+ if '_' not in note:
+ return return_default_note()
+
+ note_height, note_length = re.split('_', note, maxsplit=1)
+
+ if not note_length or not note_height:
+ return return_default_note()
+
+ return Note(label_to_length(note_length),
+ note_height, fermata=fermata, gracenote=gracenote)
+
+ @staticmethod
+ def rest_to_symbol(rest) -> music.note.Rest:
+ """Converts one rest label to music21 format. 
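+
+ For example, 'quarter_fermata' (after the 'rest-' prefix is stripped) yields a quarter rest with an upright fermata.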
+ + Args: + rest (str): one symbol in semantic format as string + + Returns: + music.note.Rest: one rest in music21 format + """ + if not rest: + logging.info(f'Unknown rest label: {rest}, returning default rest.') + return music.note.Rest() + + rest, fermata = Symbol.check_fermata(rest) + + duration = label_to_length(rest) + + rest = music.note.Rest() + rest.duration = duration + if fermata is not None: + rest.expressions.append(fermata) + + return rest + + @staticmethod + def tie_to_symbol(label): + return Tie + + @staticmethod + def timesignature_to_symbol(timesignature) -> music.meter.TimeSignature: + """Converts one time signature label to music21 format. + + Args: + timesignature (str): one symbol in semantic format as string + + Returns: + music.meter.TimeSignature: one time signature in music21 format + """ + if not timesignature: + logging.info(f'Unknown time signature label: {timesignature}, returning default time signature.') + return music.meter.TimeSignature() + + if timesignature == 'C/': + return music.meter.TimeSignature('cut') + else: + return music.meter.TimeSignature(timesignature) + + @staticmethod + def check_fermata(label: str) -> (str, music.expressions.Fermata): + """Check if note has fermata. + + Args: + label (str): one symbol in semantic format as string + + Returns: + str: note without fermata + music.expressions.Fermata: fermata + """ + fermata = None + if label.endswith('_fermata'): + label = label[:-len('_fermata')] + fermata = music.expressions.Fermata() + fermata.type = 'upright' + + return label, fermata + + +class Note: + """Represents one note in a label group. + + In the order which semantic labels are represented, the real height of note depends on key signature + of current measure. This class is used as an internal representation of a note before knowing its real height. + Real height is then stored directly in `self.note` as music.note.Note object. + """ + + def __init__(self, duration: music.duration.Duration, height: str, + fermata: music.expressions.Fermata = None, gracenote: bool = False): + self.duration = duration + self.height = height + self.fermata = fermata + self.note = None + self.gracenote = gracenote + self.note_ready = False + + def get_real_height(self, altered_pitches: AlteredPitches) -> Optional[music.note.Note]: + """Returns the real height of the note. + + Args: + key signature of current measure + Returns: + Final music.note.Note object representing the real height and other info. + """ + if self.note_ready: + return self.note + + # pitches = [pitch.name[0] for pitch in key.alteredPitches] + + if not self.height[1:-1]: + # Note has no accidental on its own and takes accidental of the altered pitches. + note_str = self.height[0] + altered_pitches[self.height[0]] + self.height[-1] + self.note = music.note.Note(note_str, duration=self.duration) + else: + # Note has accidental which directly tells real note height. + note_str = self.height[0] + self.height[1:-1].replace('b', '-') + self.height[-1] + self.note = music.note.Note(note_str, duration=self.duration) + # Note sets new altered pitch for future notes. 
+ altered_pitches[self.height[0]] = note_str[1:-1] + + if self.gracenote: + self.note = self.note.getGrace() + self.note_ready = True + return self.note + + def __str__(self): + return f'note {self.height} {self.duration}' + + +class MultiRest: + """Represents one multi rest in a label group.""" + + def __init__(self, duration: int = 0): + self.duration = duration + + +class Tie: + """Represents one tie in a label group.""" + + def __str__(self): + return 'tie' + + +class AlteredPitches: + def __init__(self, key: music.key.Key): + self.key = key + self.alteredPitches = {} + for pitch in self.key.alteredPitches: + self.alteredPitches[pitch.name[0]] = pitch.name[1] + + def __repr__(self): + return str(self.alteredPitches) + + def __str__(self): + return str(self.alteredPitches) + + def __getitem__(self, pitch_name: str): + """Gets name of pitch (e.g. 'C', 'G', ...) and returns its alternation.""" + if pitch_name not in self.alteredPitches: + return '' + return self.alteredPitches[pitch_name] + + def __setitem__(self, pitch_name: str, direction: str): + """Sets item. + + Args: + pitch_name (str): name of pitch (e.g. 'C', 'G',...) + direction (str): pitch alternation sign (#, ##, b, bb, 0, N) + """ + if not direction: + return + elif direction in ['0', 'N']: + if pitch_name in self.alteredPitches: + del self.alteredPitches[pitch_name] + # del self.alteredPitches[pitch_name] + return + else: + self.alteredPitches[pitch_name] = direction + + +SYMBOL_TO_LENGTH = { + 'hundred_twenty_eighth': 0.03125, + 'hundred_twenty_eighth.': 0.046875, + 'hundred_twenty_eighth..': 0.0546875, + 'sixty_fourth': 0.0625, + 'sixty_fourth.': 0.09375, + 'sixty_fourth..': 0.109375, + 'thirty_second': 0.125, + 'thirty_second.': 0.1875, + 'thirty_second..': 0.21875, + 'sixteenth': 0.25, + 'sixteenth.': 0.375, + 'sixteenth..': 0.4375, + 'eighth': 0.5, + 'eighth.': 0.75, + 'eighth..': 0.875, + 'quarter': 1.0, + 'quarter.': 1.5, + 'quarter..': 1.75, + 'half': 2.0, + 'half.': 3.0, + 'half..': 3.5, + 'whole': 4.0, + 'whole.': 6.0, + 'whole..': 7.0, + 'breve': 8.0, + 'breve.': 10.0, + 'breve..': 11.0, + 'double_whole': 8.0, + 'double_whole.': 12.0, + 'double_whole..': 14.0, + 'quadruple_whole': 16.0, + 'quadruple_whole.': 24.0, + 'quadruple_whole..': 28.0 +} + +LENGTH_TO_SYMBOL = {v: k for k, v in SYMBOL_TO_LENGTH.items()} # reverse dictionary + + +def label_to_length(length: str) -> music.duration.Duration: + """Return length of label as music21 duration. + + Args: + length (str): only length part of one label in semantic format as string + + Returns: + music.duration.Duration: one duration in music21 format + """ + if length in SYMBOL_TO_LENGTH: + return music.duration.Duration(SYMBOL_TO_LENGTH[length]) + else: + logging.info(f'Unknown duration label: {length}, returning default duration.') + return music.duration.Duration(1) diff --git a/pero_ocr/music/output_translator.py b/pero_ocr/music/output_translator.py new file mode 100644 index 0000000..f96a37f --- /dev/null +++ b/pero_ocr/music/output_translator.py @@ -0,0 +1,83 @@ + +from typing import Union +import re +import logging +import json +import os +from typing import Optional + +logger = logging.getLogger(__name__) + + +class OutputTranslator: + """Class for translating output from shorter form to longer form using simple dictionary. 
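+
+ With dictionary={'q': 'note-C4_quarter'} (illustrative), translate_line('q q') returns
+ 'note-C4_quarter note-C4_quarter'; reverse=True applies the reversed mapping.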
+ + Used for example in Optical Music Recognition to translate shorter SSemantic encoding to Semantic encoding.""" + def __init__(self, dictionary: dict = None, filename: str = None, atomic: bool = False): + self.dictionary = self.load_dictionary(dictionary, filename) + self.dictionary_reversed = {v: k for k, v in self.dictionary.items()} + self.n_existing_labels = set() + + # ensures atomicity on line level (if one symbol is not found, return None and let caller handle it) + self.atomic = atomic + + def __call__(self, inputs: Union[str, list], reverse: bool = False) -> Union[str, list, None]: + if isinstance(inputs, list): + if len(inputs[0]) > 1: # list of strings (lines) + return self.translate_lines(inputs, reverse) + else: # list of chars (one line total) + return self.translate_line(''.join(inputs), reverse) + elif isinstance(inputs, str): # one line + return self.translate_line(inputs, reverse) + else: + raise ValueError(f'OutputTranslator: Unsupported input type: {type(inputs)}') + + def translate_lines(self, lines: list, reverse: bool = False) -> list: + return [self.translate_line(line, reverse) for line in lines] + + def translate_line(self, line, reverse: bool = False): + line_stripped = line.replace('"', ' ').strip() + symbols = re.split(r'\s+', line_stripped) + + converted_symbols = [] + for symbol in symbols: + translation = self.translate_symbol(symbol, reverse) + if translation is None: + if self.atomic: + return None # return None and let caller handle it (e.g. by storing the original line or breaking) + converted_symbols.append(symbol) + else: + converted_symbols.append(translation) + + return ' '.join(converted_symbols) + + def translate_symbol(self, symbol: str, reverse: bool = False) -> Optional[str]: + dictionary = self.dictionary_reversed if reverse else self.dictionary + + translation = dictionary.get(symbol, None) + if translation is not None: + return translation + + if symbol not in self.n_existing_labels: + logger.debug(f'Not existing label: ({symbol})') + self.n_existing_labels.add(symbol) + + return None + + @staticmethod + def load_dictionary(dictionary: dict = None, filename: str = None) -> dict: + if dictionary is not None: + return dictionary + elif filename is not None: + return OutputTranslator.read_json(filename) + else: + raise ValueError('OutputTranslator: Either dictionary or filename must be provided.') + + @staticmethod + def read_json(filename) -> dict: + if not os.path.isfile(filename): + raise FileNotFoundError(f'Translator file ({filename}) not found. 
diff --git a/pero_ocr/ocr_engine/line_ocr_engine.py b/pero_ocr/ocr_engine/line_ocr_engine.py
index 7cb393b..6e03d0c 100644
--- a/pero_ocr/ocr_engine/line_ocr_engine.py
+++ b/pero_ocr/ocr_engine/line_ocr_engine.py
@@ -1,9 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import print_function
-import argparse
 import json
-import cv2
 import numpy as np
 from os.path import isabs, realpath, join, dirname
 from scipy import sparse
@@ -12,10 +10,11 @@ from .softmax import softmax
 
 from pero_ocr.sequence_alignment import levenshtein_distance
+from pero_ocr.music.output_translator import OutputTranslator
 
 
 class BaseEngineLineOCR(object):
-    def __init__(self, json_def, device, batch_size=32, model_type="ctc"):
+    def __init__(self, json_def, device, batch_size=32, model_type="ctc", substitute_output_atomic: bool = True):
         with open(json_def, 'r', encoding='utf8') as f:
             self.config = json.load(f)
@@ -28,6 +27,12 @@ def __init__(self, json_def, device, batch_size=32, model_type="ctc"):
         self.checkpoint = realpath(join(dirname(json_def), self.config['checkpoint']))
         self.characters = tuple(self.config['characters'])
+
+        self.output_substitution = None
+        if 'output_substitution_table' in self.config:
+            self.output_substitution = OutputTranslator(dictionary=self.config['output_substitution_table'],
+                                                        atomic=substitute_output_atomic)
+
         self.net_name = self.config['net_name']
         if "embed_num" in self.config:
             self.embed_num = int(self.config["embed_num"])
diff --git a/pero_ocr/ocr_engine/pytorch_ocr_engine.py b/pero_ocr/ocr_engine/pytorch_ocr_engine.py
index bffaea9..fc37bf1 100644
--- a/pero_ocr/ocr_engine/pytorch_ocr_engine.py
+++ b/pero_ocr/ocr_engine/pytorch_ocr_engine.py
@@ -35,8 +35,9 @@ def greedy_decode_ctc(scores_probs, chars):
 
 
 class PytorchEngineLineOCR(BaseEngineLineOCR):
-    def __init__(self, json_def, device, batch_size=8):
-        super(PytorchEngineLineOCR, self).__init__(json_def, device, batch_size=batch_size)
+    def __init__(self, json_def, device, batch_size=8, substitute_output_atomic: bool = True):
+        super(PytorchEngineLineOCR, self).__init__(json_def, device, batch_size=batch_size,
+                                                   substitute_output_atomic=substitute_output_atomic)
 
         self.net_subsampling = 4
         self.characters = list(self.characters) + [u'\u200B']
diff --git a/pero_ocr/ocr_engine/transformer_ocr_engine.py b/pero_ocr/ocr_engine/transformer_ocr_engine.py
index 3c9d0a4..86ad1ff 100644
--- a/pero_ocr/ocr_engine/transformer_ocr_engine.py
+++ b/pero_ocr/ocr_engine/transformer_ocr_engine.py
@@ -6,12 +6,12 @@ from .line_ocr_engine import BaseEngineLineOCR
 
 from pero_ocr.ocr_engine import transformer
 
-import sys
-
 
 class TransformerEngineLineOCR(BaseEngineLineOCR):
-    def __init__(self, json_def, device, batch_size=16):
-        super(TransformerEngineLineOCR, self).__init__(json_def, device, batch_size=batch_size, model_type="transformer")
+    def __init__(self, json_def, device, batch_size=16, substitute_output_atomic: bool = True):
+        super(TransformerEngineLineOCR, self).__init__(json_def, device, batch_size=batch_size,
+                                                       model_type="transformer",
+                                                       substitute_output_atomic=substitute_output_atomic)
 
         self.characters = list(self.characters) + [u'\u200B', '']
 
@@ -25,7 +25,7 @@ def __init__(self, json_def, device, batch_size=16):
 
         print(self.net)
 
-        self.net.load_state_dict(torch.load(self.checkpoint))
+        self.net.load_state_dict(torch.load(self.checkpoint, map_location=device))
         self.net.eval()
         self.net = self.net.to(device)
         self.max_decoded_seq_length = 210
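For context, output substitution is switched on purely by the engine's JSON definition: when an `output_substitution_table` key is present, `BaseEngineLineOCR` wraps it in an `OutputTranslator` with the requested atomicity. A hypothetical, trimmed-down config could look like the sketch below; only `output_substitution_table` is relevant here, and every value is illustrative rather than taken from a real engine.

```python
# Sketch of an engine JSON definition (shown as a Python dict) enabling output substitution.
engine_config = {
    "net_name": "example_net",        # illustrative value
    "checkpoint": "ocr_model.pth",    # illustrative value
    "characters": [">", "2", "k", "G", "M", "B", "3", "z", "|", " "],
    # Short symbols emitted by the model -> long symbols stored in Page XML.
    "output_substitution_table": {
        ">2": "clef-G2",
        "kGM": "keySignature-GM",
        "B3z": "note-B3_eighth",
        "|": "barline"
    }
}
```

With `substitute_output_atomic=True` (the default in the engines above), the translator returns `None` for any line containing a symbol missing from the table, so the caller can fall back to the raw model output instead of emitting a half-translated line.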
diff --git a/pero_ocr/utils.py b/pero_ocr/utils.py
index f25c700..47b856d 100644
--- a/pero_ocr/utils.py
+++ b/pero_ocr/utils.py
@@ -2,6 +2,9 @@
 import sys
 import logging
 import subprocess
+import json
+
+logger = logging.getLogger(__name__)
 
 try:
     subprocess.check_output(
@@ -22,3 +25,19 @@ def compose_path(file_path, reference_path):
     if reference_path and not isabs(file_path):
         file_path = join(reference_path, file_path)
     return file_path
+
+
+def config_get_list(config, key, fallback=None):
+    """Get a list from the config (parsed as JSON)."""
+    fallback = fallback if fallback is not None else []
+
+    if key not in config:
+        return fallback
+
+    try:
+        value = json.loads(config[key])
+    except json.decoder.JSONDecodeError as e:
+        logger.info(f'Failed to parse list from config key "{key}", returning fallback {fallback}:\n{e}')
+        return fallback
+    else:
+        return value
diff --git a/pyproject.toml b/pyproject.toml
index f025360..c0dd373 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,8 @@ dependencies = [
     "pyamg",
     "imgaug",
     "arabic_reshaper",
+    "ultralytics",
+    "music21"
 ]
diff --git a/test/processing_test.sh b/test/processing_test.sh
old mode 100644
new mode 100755
index f9712e8..e04df78
--- a/test/processing_test.sh
+++ b/test/processing_test.sh
@@ -16,12 +16,12 @@ print_help() {
     echo "$ sh processing_test.sh -i in_dir -o out_dir -c engine_dir/config.ini"
     echo "Options:"
     echo "  -i|--input-images   Input directory with test images."
-    echo "  -x|--input-xmls     Input directory with xml files."
     echo "  -o|--output-dir     Output directory for results."
    echo "  -c|--configuration  Configuration file for ocr."
-    echo "  -e|--example        Example outputs for comparison."
-    echo "  -u|--test-utility   Path to test utility."
-    echo "  -t|--test-output    Test utility output folder."
+    echo "  -x|--input-xmls     Input directory with xml files. (optional)"
+    echo "  -e|--example        Example outputs for comparison. (optional)"
+    echo "  -u|--test-utility   Path to test utility. (optional)"
+    echo "  -t|--test-output    Test utility output folder. (optional)"
     echo "  -g|--gpu-ids        Ids of GPU to use for ocr processing. (default=all)"
     echo "  -h|--help           Shows this help message."
 }
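A small sketch of how the new `config_get_list` helper behaves with a `configparser` section; the section name and keys below are made up for the example.

```python
import configparser

from pero_ocr.utils import config_get_list

config = configparser.ConfigParser()
config.read_string("""
[PARSE_FOLDER]
CATEGORIES = ["text", "image"]
BROKEN = not json
""")
section = config['PARSE_FOLDER']

print(config_get_list(section, 'CATEGORIES'))           # ['text', 'image']
print(config_get_list(section, 'BROKEN', fallback=[]))  # []  (parse failure is logged, fallback returned)
print(config_get_list(section, 'MISSING'))              # []  (key absent, default fallback)
```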
diff --git a/user_scripts/export_music.py b/user_scripts/export_music.py
new file mode 100644
index 0000000..5a2f4a4
--- /dev/null
+++ b/user_scripts/export_music.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3.8
+# -*- coding: utf-8 -*-
+"""Script to take pero-ocr output with musical transcriptions and export it to MusicXML and MIDI formats.
+
+INPUTS:
+- PageLayout
+    - INPUT options:
+        - PageLayout object using the `MusicPageExporter.__call__()` method
+        - XML PageLayout (exported directly from the pero-ocr engine) using the `--input-xml-path` argument
+    - Represents one whole page of musical notation transcribed by the pero-ocr engine
+    - OUTPUT options:
+        - One MusicXML file for the page
+        - MIDI file for the page and for individual lines (named according to IDs in the PageLayout)
+- Text files with individual transcriptions and their IDs on each line, using the `--input-transcription-files` argument.
+    - e.g.: 2370961.png ">2 + kGM + E2W E3q. + |"
+            1300435.png "=4 + kDM + G3z + F3z + |"
+            ...
+    - OUTPUTS one MusicXML file for each line, with names corresponding to the IDs on each line
+
+Author: Vojtěch Vlach
+Contact: xvlach22@vutbr.cz
+"""
+
+import sys
+import argparse
+import time
+
+from pero_ocr.music.music_exporter import MusicPageExporter
+
+
+def parseargs():
+    print(' '.join(sys.argv))
+    print('----------------------------------------------------------------------')
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-i", "--input-xml-path", type=str, default='',
+        help="Path to an input XML file with an exported PageLayout.")
+    parser.add_argument(
+        '-f', '--input-transcription-files', nargs='*', default=None,
+        help='Input files with sequences as lines, with an ID at the beginning of each line.')
+    parser.add_argument(
+        "-t", "--translator-path", type=str, default=None,
+        help="JSON file containing the translation dictionary from the shorter encoding (exported by the model) "
+             "to the longer one (stored in the XML). Check whether it is needed by looking at the start of any "
+             "line in the transcription, e.g. SSemantic (model output): >2 + kGM + B3z + C4z + |... "
+             "vs. Semantic (stored in XML): clef-G2 + keySignature-GM + note-B3_eighth + note-C4_eighth + barline...")
+    parser.add_argument(
+        "-o", "--output-folder", default='output_page',
+        help="Output folder to write exported files to.")
+    parser.add_argument(
+        "-m", "--export-midi", action='store_true',
+        help="Enable exporting MIDI files to the output folder. "
+             "Exports the whole page and individual lines with names corresponding to their TextLine IDs.")
+    parser.add_argument(
+        "-M", "--export-musicxml", action='store_true',
+        help="Enable exporting a MusicXML file to the output folder. "
+             "Exports the whole page as one MusicXML file.")
+    parser.add_argument(
+        '-v', "--verbose", action='store_true', default=False,
+        help="Enable verbose logging.")
+
+    return parser.parse_args()
+
+
+def main():
+    """Main function for simple testing."""
+    args = parseargs()
+
+    start = time.time()
+    MusicPageExporter(
+        input_xml_path=args.input_xml_path,
+        input_transcription_files=args.input_transcription_files,
+        translator_path=args.translator_path,
+        output_folder=args.output_folder,
+        export_midi=args.export_midi,
+        export_musicxml=args.export_musicxml,
+        verbose=args.verbose)()
+
+    end = time.time()
+    print(f'Total time: {end - start:.2f} s')
+
+
+if __name__ == "__main__":
+    main()
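The same export can also be driven from Python, mirroring what `main()` does above; the paths below are hypothetical placeholders.

```python
from pero_ocr.music.music_exporter import MusicPageExporter

# Equivalent to:
#   python user_scripts/export_music.py -i page.xml -t translator.json -o out_dir -M -m
MusicPageExporter(
    input_xml_path='page.xml',          # hypothetical PageLayout XML exported by pero-ocr
    input_transcription_files=None,
    translator_path='translator.json',  # hypothetical SSemantic -> Semantic dictionary
    output_folder='out_dir',
    export_midi=True,
    export_musicxml=True,
    verbose=False)()
```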
diff --git a/user_scripts/parse_folder.py b/user_scripts/parse_folder.py
index 194c8bc..36cca8b 100644
--- a/user_scripts/parse_folder.py
+++ b/user_scripts/parse_folder.py
@@ -33,6 +33,8 @@ def parse_arguments():
     parser.add_argument('--input-logit-path', help='')
     parser.add_argument('--output-xml-path', help='')
     parser.add_argument('--output-render-path', help='')
+    parser.add_argument('--output-render-category', default=False, action='store_true',
+                        help='Render category tags for every non-text region.')
     parser.add_argument('--output-line-path', help='')
     parser.add_argument('--output-logit-path', help='')
     parser.add_argument('--output-alto-path', help='')
@@ -57,9 +59,12 @@ def setup_logging(config):
     logger.setLevel(level)
 
 
-def get_value_or_none(config, section, key):
+def get_value_or_none(config, section, key, getboolean: bool = False):
     if config.has_option(section, key):
-        value = config[section][key]
+        if getboolean:
+            value = config.getboolean(section, key)
+        else:
+            value = config[section][key]
     else:
         value = None
     return value
@@ -139,12 +144,13 @@ def __call__(self, page_layout: PageLayout, file_id):
 class Computator:
     def __init__(self, page_parser, input_image_path, input_xml_path, input_logit_path, output_render_path,
-                 output_logit_path, output_alto_path, output_xml_path, output_line_path):
+                 output_render_category, output_logit_path, output_alto_path, output_xml_path, output_line_path):
         self.page_parser = page_parser
         self.input_image_path = input_image_path
         self.input_xml_path = input_xml_path
         self.input_logit_path = input_logit_path
         self.output_render_path = output_render_path
+        self.output_render_category = output_render_category
         self.output_logit_path = output_logit_path
         self.output_alto_path = output_alto_path
         self.output_xml_path = output_xml_path
@@ -177,8 +183,9 @@ def __call__(self, image_file_name, file_id, index, ids_count):
                                       os.path.join(self.output_xml_path, file_id + '.xml'))
 
         if self.output_render_path is not None:
-            page_layout.render_to_image(image)
-            cv2.imwrite(os.path.join(self.output_render_path, file_id + '.jpg'), image, [int(cv2.IMWRITE_JPEG_QUALITY), 70])
+            page_layout.render_to_image(image, render_category=self.output_render_category)
+            render_file = str(os.path.join(self.output_render_path, file_id + '.jpg'))
+            cv2.imwrite(render_file, image, [int(cv2.IMWRITE_JPEG_QUALITY), 70])
 
         if self.output_logit_path is not None:
             page_layout.save_logits(os.path.join(self.output_logit_path, file_id + '.logits'))
@@ -247,6 +254,8 @@ def main():
         config['PARSE_FOLDER']['OUTPUT_XML_PATH'] = args.output_xml_path
     if args.output_render_path is not None:
         config['PARSE_FOLDER']['OUTPUT_RENDER_PATH'] = args.output_render_path
+    if args.output_render_category is not None:
+        config['PARSE_FOLDER']['OUTPUT_RENDER_CATEGORY'] = 'yes' if args.output_render_category else 'no'
     if args.output_line_path is not None:
         config['PARSE_FOLDER']['OUTPUT_LINE_PATH'] = args.output_line_path
     if args.output_logit_path is not None:
@@ -266,6 +275,7 @@
 
     input_logit_path = get_value_or_none(config, 'PARSE_FOLDER', 'INPUT_LOGIT_PATH')
     output_render_path = get_value_or_none(config, 'PARSE_FOLDER', 'OUTPUT_RENDER_PATH')
+    output_render_category = get_value_or_none(config, 'PARSE_FOLDER', 'OUTPUT_RENDER_CATEGORY', getboolean=True)
     output_line_path = get_value_or_none(config, 'PARSE_FOLDER', 'OUTPUT_LINE_PATH')
     output_xml_path = get_value_or_none(config, 'PARSE_FOLDER', 'OUTPUT_XML_PATH')
     output_logit_path = get_value_or_none(config, 'PARSE_FOLDER', 'OUTPUT_LOGIT_PATH')
@@ -334,7 +344,8 @@ def main():
         images_to_process = filtered_images_to_process
 
     computator = Computator(page_parser, input_image_path, input_xml_path, input_logit_path, output_render_path,
-                            output_logit_path, output_alto_path, output_xml_path, output_line_path)
+                            output_render_category, output_logit_path, output_alto_path, output_xml_path,
+                            output_line_path)
 
     t_start = time.time()
     results = []