From 86efafa9853c63a7610bbba71f9ac333a685d549 Mon Sep 17 00:00:00 2001 From: Bourn3z Date: Mon, 22 Jan 2024 11:23:07 +0800 Subject: [PATCH] Remove unnecessary preprocessing steps of VI-LayoutXLM to improve performance. --- .../ser_layoutxlm_xfund_zh.yaml | 4 +-- .../re_vi_layoutxlm_xfund_zh.yaml | 34 +++--------------- .../ser_vi_layoutxlm_xfund_zh.yaml | 36 ++++--------------- .../data/transforms/layoutlm_transforms.py | 9 +++-- mindocr/losses/kie_loss.py | 4 +-- 5 files changed, 23 insertions(+), 64 deletions(-) diff --git a/configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yaml b/configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yaml index e371bcf49..ba0624363 100644 --- a/configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yaml +++ b/configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yaml @@ -80,7 +80,7 @@ train: # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize output_columns: [ 'input_ids', 'bbox','attention_mask','token_type_ids', 'image', 'labels' ] net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns - label_column_index: [ 2, 5 ] # input indices marked as label + label_column_index: [ 5 ] # input indices marked as label loader: shuffle: True @@ -122,7 +122,7 @@ eval: # the order of the dataloader list, matching the network input and the labels for evaluation output_columns: [ 'input_ids', 'bbox', 'attention_mask','token_type_ids','image', 'labels' ] net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns - label_column_index: [ 2, 5 ] # input indices marked as label + label_column_index: [ 5 ] # input indices marked as label loader: shuffle: False diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml index 2fae9ec62..fe6b187f8 100644 --- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml +++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml @@ -59,12 +59,9 @@ train: label_file: XFUND/zh_train/train.json sample_ratio: 1.0 transform_pipeline: - - DecodeImage: - img_mode: RGB - to_float32: False - VQATokenLabelEncode: contains_re: True - algorithm: &algorithm LayoutXLM + algorithm: &algorithm VI-LayoutXLM class_path: *class_path order_method: tb-yx - VQATokenPad: @@ -75,14 +72,6 @@ train: max_seq_len: *max_seq_len - TensorizeEntitiesRelations: max_relation_len: 5000 - - LayoutResize: - size: [224, 224] - - NormalizeImage: - bgr_to_rgb: False - is_hwc: True - mean: imagenet - std: imagenet - - ToCHWImage: # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize output_columns: [ @@ -90,15 +79,14 @@ train: "bbox", "attention_mask", "token_type_ids", - "image", "question", "question_label", "answer", "answer_label", "relation_label", ] - net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns - label_column_index: [9] # input indices marked as label + net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7] # input indices for network forward func in output_columns + label_column_index: [8] # input indices marked as label loader: shuffle: True @@ -117,9 +105,6 @@ eval: sample_ratio: 1.0 shuffle: False transform_pipeline: - - DecodeImage: - img_mode: RGB - to_float32: False - VQATokenLabelEncode: contains_re: True algorithm: *algorithm @@ -133,14 +118,6 @@ eval: max_seq_len: *max_seq_len - TensorizeEntitiesRelations: max_relation_len: 5000 - - LayoutResize: - size: [224, 224] - - NormalizeImage: - bgr_to_rgb: False - is_hwc: True - mean: imagenet - std: imagenet - - ToCHWImage: # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize output_columns: [ @@ -148,15 +125,14 @@ eval: "bbox", "attention_mask", "token_type_ids", - "image", "question", "question_label", "answer", "answer_label", "relation_label", ] - net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns - label_column_index: [9] # input indices marked as label + net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7] # input indices for network forward func in output_columns + label_column_index: [8] # input indices marked as label loader: shuffle: False diff --git a/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml index 3682fc45c..214f70667 100644 --- a/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml +++ b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml @@ -57,12 +57,9 @@ train: label_file: XFUND/zh_train/train.json sample_ratio: 1.0 transform_pipeline: - - DecodeImage: - img_mode: RGB - to_float32: False - VQATokenLabelEncode: contains_re: False - algorithm: &algorithm LayoutXLM + algorithm: &algorithm VI-LayoutXLM class_path: *class_path order_method: tb-yx - VQATokenPad: @@ -70,18 +67,10 @@ train: return_attention_mask: True - VQASerTokenChunk: max_seq_len: *max_seq_len - - LayoutResize: - size: [ 224, 224 ] - - NormalizeImage: - bgr_to_rgb: False - is_hwc: True - mean: imagenet - std: imagenet - - ToCHWImage: # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize - output_columns: [ 'input_ids', 'bbox','attention_mask','token_type_ids', 'image', 'labels' ] - net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns - label_column_index: [ 2, 5 ] # input indices marked as label + output_columns: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'labels' ] + net_input_column_index: [ 0, 1, 2, 3 ] # input indices for network forward func in output_columns + label_column_index: [ 4 ] # input indices marked as label loader: shuffle: True @@ -100,9 +89,6 @@ eval: sample_ratio: 1.0 shuffle: False transform_pipeline: - - DecodeImage: - img_mode: RGB - to_float32: False - VQATokenLabelEncode: contains_re: False algorithm: *algorithm @@ -113,18 +99,10 @@ eval: return_attention_mask: True - VQASerTokenChunk: max_seq_len: *max_seq_len - - LayoutResize: - size: [ 224, 224 ] - - NormalizeImage: - bgr_to_rgb: False - is_hwc: True - mean: imagenet - std: imagenet - - ToCHWImage: # the order of the dataloader list, matching the network input and the labels for evaluation - output_columns: [ 'input_ids', 'bbox', 'attention_mask','token_type_ids','image', 'labels' ] - net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns - label_column_index: [ 2, 5 ] # input indices marked as label + output_columns: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'labels' ] + net_input_column_index: [ 0, 1, 2, 3 ] # input indices for network forward func in output_columns + label_column_index: [ 4 ] # input indices marked as label loader: shuffle: False diff --git a/mindocr/data/transforms/layoutlm_transforms.py b/mindocr/data/transforms/layoutlm_transforms.py index 2766e05ad..f5e6636b5 100644 --- a/mindocr/data/transforms/layoutlm_transforms.py +++ b/mindocr/data/transforms/layoutlm_transforms.py @@ -4,6 +4,7 @@ import cv2 import numpy as np +from PIL import Image from mindspore import nn @@ -65,8 +66,10 @@ def __init__( super(VQATokenLabelEncode, self).__init__() tokenizer_dict = { "LayoutXLM": {"class": LayoutXLMTokenizer, "pretrained_model": "layoutxlm-base-uncased"}, + "VI-LayoutXLM": {"class": LayoutXLMTokenizer, "pretrained_model": "layoutxlm-base-uncased"}, } self.contains_re = contains_re + self.algorithm = algorithm tokenizer_config = tokenizer_dict[algorithm] self.tokenizer = tokenizer_config["class"].from_pretrained(tokenizer_config["pretrained_model"]) # to replace self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path) @@ -141,8 +144,10 @@ def __call__(self, data): train_re = self.contains_re and not self.infer_mode if train_re: ocr_info = self.filter_empty_contents(ocr_info) - - height, width, _ = data["image"].shape + if self.algorithm == "VI-LayoutXLM": + width, height = Image.open(data["img_path"]).size + elif self.algorithm == "LayoutXLM": + height, width, _ = data["image"].shape words_list = [] bbox_list = [] diff --git a/mindocr/losses/kie_loss.py b/mindocr/losses/kie_loss.py index 60fd7630f..2b8547516 100644 --- a/mindocr/losses/kie_loss.py +++ b/mindocr/losses/kie_loss.py @@ -13,7 +13,7 @@ def __init__(self, **kwargs): super().__init__() self.loss_fct = nn.CrossEntropyLoss() - def construct(self, predicts, attention_mask, labels): + def construct(self, predicts, labels): loss = self.loss_fct(predicts.transpose(0, 2, 1), labels.astype(ms.int32)) return loss @@ -28,6 +28,6 @@ def __init__(self, **kwargs): super().__init__() self.loss_fct = nn.CrossEntropyLoss() - def construct(self, predicts, attention_mask, labels): + def construct(self, predicts, labels): loss = self.loss_fct(predicts.transpose(0, 2, 1), labels.astype(ms.int32)) return loss