Remove unnecessary preprocessing steps of VI-LayoutXLM to improve performance. #657

Open · wants to merge 1 commit into main

configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yaml (2 additions, 2 deletions)

@@ -80,7 +80,7 @@ train:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns: [ 'input_ids', 'bbox','attention_mask','token_type_ids', 'image', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
-label_column_index: [ 2, 5 ] # input indices marked as label
+label_column_index: [ 5 ] # input indices marked as label

loader:
shuffle: True
@@ -122,7 +122,7 @@ eval:
# the order of the dataloader list, matching the network input and the labels for evaluation
output_columns: [ 'input_ids', 'bbox', 'attention_mask','token_type_ids','image', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
-label_column_index: [ 2, 5 ] # input indices marked as label
+label_column_index: [ 5 ] # input indices marked as label

loader:
shuffle: False
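
Note: net_input_column_index and label_column_index pick columns out of output_columns by position; the first group is unpacked into the network's forward call, the second is handed to the loss. Since the SER KIE loss no longer accepts attention_mask in construct (see mindocr/losses/kie_loss.py below), index 2 is dropped here. A minimal sketch of that wiring, with illustrative names rather than MindOCR's actual train-step code:

```python
# Illustrative sketch only, not MindOCR's actual train-step implementation.
output_columns = ["input_ids", "bbox", "attention_mask", "token_type_ids", "image", "labels"]
net_input_column_index = [0, 1, 2, 3, 4]
label_column_index = [5]  # was [2, 5]: attention_mask is no longer forwarded to the loss

def split_batch(batch):
    """batch is the tuple yielded by the dataloader, ordered as in output_columns."""
    net_inputs = tuple(batch[i] for i in net_input_column_index)
    loss_args = tuple(batch[i] for i in label_column_index)
    return net_inputs, loss_args

# The train step then looks roughly like:
#   logits = network(*net_inputs)
#   loss = loss_fn(logits, *loss_args)  # matches the new construct(predicts, labels)
```
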
configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml (5 additions, 29 deletions)

@@ -59,12 +59,9 @@ train:
label_file: XFUND/zh_train/train.json
sample_ratio: 1.0
transform_pipeline:
-- DecodeImage:
-img_mode: RGB
-to_float32: False
- VQATokenLabelEncode:
contains_re: True
-algorithm: &algorithm LayoutXLM
+algorithm: &algorithm VI-LayoutXLM
class_path: *class_path
order_method: tb-yx
- VQATokenPad:
@@ -75,30 +72,21 @@
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
max_relation_len: 5000
-- LayoutResize:
-size: [224, 224]
-- NormalizeImage:
-bgr_to_rgb: False
-is_hwc: True
-mean: imagenet
-std: imagenet
-- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns:
[
"input_ids",
"bbox",
"attention_mask",
"token_type_ids",
-"image",
"question",
"question_label",
"answer",
"answer_label",
"relation_label",
]
-net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
-label_column_index: [9] # input indices marked as label
+net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7] # input indices for network forward func in output_columns
+label_column_index: [8] # input indices marked as label

loader:
shuffle: True
@@ -117,9 +105,6 @@ eval:
sample_ratio: 1.0
shuffle: False
transform_pipeline:
-- DecodeImage:
-img_mode: RGB
-to_float32: False
- VQATokenLabelEncode:
contains_re: True
algorithm: *algorithm
@@ -133,30 +118,21 @@ eval:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
max_relation_len: 5000
-- LayoutResize:
-size: [224, 224]
-- NormalizeImage:
-bgr_to_rgb: False
-is_hwc: True
-mean: imagenet
-std: imagenet
-- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns:
[
"input_ids",
"bbox",
"attention_mask",
"token_type_ids",
-"image",
"question",
"question_label",
"answer",
"answer_label",
"relation_label",
]
-net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
-label_column_index: [9] # input indices marked as label
+net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7] # input indices for network forward func in output_columns
+label_column_index: [8] # input indices marked as label

loader:
shuffle: False
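
Note: with the image branch gone, "image" also disappears from output_columns for the RE task, so every later column shifts down by one: the network inputs move from [0..8] to [0..7] and relation_label moves from index 9 to index 8. A small self-contained check of the new mapping (column names copied from the config above):

```python
# Sanity-check that the new RE indices still point at the intended columns.
output_columns = [
    "input_ids", "bbox", "attention_mask", "token_type_ids",
    "question", "question_label", "answer", "answer_label", "relation_label",
]
net_input_column_index = [0, 1, 2, 3, 4, 5, 6, 7]
label_column_index = [8]

assert [output_columns[i] for i in net_input_column_index] == output_columns[:8]
assert [output_columns[i] for i in label_column_index] == ["relation_label"]
print("RE column wiring OK")
```
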
configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml (7 additions, 29 deletions)

@@ -57,31 +57,20 @@ train:
label_file: XFUND/zh_train/train.json
sample_ratio: 1.0
transform_pipeline:
-- DecodeImage:
-img_mode: RGB
-to_float32: False
- VQATokenLabelEncode:
contains_re: False
-algorithm: &algorithm LayoutXLM
+algorithm: &algorithm VI-LayoutXLM
class_path: *class_path
order_method: tb-yx
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
-- LayoutResize:
-size: [ 224, 224 ]
-- NormalizeImage:
-bgr_to_rgb: False
-is_hwc: True
-mean: imagenet
-std: imagenet
-- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
-output_columns: [ 'input_ids', 'bbox','attention_mask','token_type_ids', 'image', 'labels' ]
-net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
-label_column_index: [ 2, 5 ] # input indices marked as label
+output_columns: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'labels' ]
+net_input_column_index: [ 0, 1, 2, 3 ] # input indices for network forward func in output_columns
+label_column_index: [ 4 ] # input indices marked as label

loader:
shuffle: True
@@ -100,9 +89,6 @@ eval:
sample_ratio: 1.0
shuffle: False
transform_pipeline:
-- DecodeImage:
-img_mode: RGB
-to_float32: False
- VQATokenLabelEncode:
contains_re: False
algorithm: *algorithm
@@ -113,18 +99,10 @@ eval:
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
-- LayoutResize:
-size: [ 224, 224 ]
-- NormalizeImage:
-bgr_to_rgb: False
-is_hwc: True
-mean: imagenet
-std: imagenet
-- ToCHWImage:
# the order of the dataloader list, matching the network input and the labels for evaluation
-output_columns: [ 'input_ids', 'bbox', 'attention_mask','token_type_ids','image', 'labels' ]
-net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
-label_column_index: [ 2, 5 ] # input indices marked as label
+output_columns: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'labels' ]
+net_input_column_index: [ 0, 1, 2, 3 ] # input indices for network forward func in output_columns
+label_column_index: [ 4 ] # input indices marked as label

loader:
shuffle: False
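
Note: the deleted transforms are exactly the image branch of the pipeline: decode the file, resize to 224x224, normalize with ImageNet statistics, convert HWC to CHW. VI-LayoutXLM consumes only token ids, bounding boxes and masks, so that per-sample work, plus carrying a 3x224x224 float tensor through the dataloader, buys nothing. A rough approximation of what each sample no longer pays for; the real MindOCR transforms differ in details such as color order, and the constants below are the standard ImageNet mean/std rather than values copied from this repo:

```python
import cv2
import numpy as np

IMAGENET_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
IMAGENET_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)

def legacy_image_branch(img_bytes: bytes) -> np.ndarray:
    """Approximation of DecodeImage -> LayoutResize -> NormalizeImage -> ToCHWImage."""
    img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)  # decode
    img = cv2.resize(img, (224, 224))                                         # resize to 224x224
    img = (img.astype(np.float32) - IMAGENET_MEAN) / IMAGENET_STD             # normalize
    return img.transpose(2, 0, 1)                                             # HWC -> CHW
```
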
mindocr/data/transforms/layoutlm_transforms.py (7 additions, 2 deletions)

@@ -4,6 +4,7 @@

import cv2
import numpy as np
+from PIL import Image

from mindspore import nn

@@ -65,8 +66,10 @@ def __init__(
super(VQATokenLabelEncode, self).__init__()
tokenizer_dict = {
"LayoutXLM": {"class": LayoutXLMTokenizer, "pretrained_model": "layoutxlm-base-uncased"},
+"VI-LayoutXLM": {"class": LayoutXLMTokenizer, "pretrained_model": "layoutxlm-base-uncased"},
}
self.contains_re = contains_re
+self.algorithm = algorithm
tokenizer_config = tokenizer_dict[algorithm]
self.tokenizer = tokenizer_config["class"].from_pretrained(tokenizer_config["pretrained_model"]) # to replace
self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
@@ -141,8 +144,10 @@ def __call__(self, data):
train_re = self.contains_re and not self.infer_mode
if train_re:
ocr_info = self.filter_empty_contents(ocr_info)
-
-height, width, _ = data["image"].shape
+if self.algorithm == "VI-LayoutXLM":
+    width, height = Image.open(data["img_path"]).size
+elif self.algorithm == "LayoutXLM":
+    height, width, _ = data["image"].shape

words_list = []
bbox_list = []
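
Note: since the image is no longer decoded earlier in the pipeline, data["image"] does not exist for VI-LayoutXLM, so VQATokenLabelEncode now takes the page size straight from the file for bbox normalization. PIL's Image.open only parses the header until pixel data is actually requested, so reading .size is far cheaper than a full decode, and it returns (width, height), which is why the unpacking order flips relative to the numpy .shape branch. A self-contained illustration (the file name is just for the example):

```python
import numpy as np
from PIL import Image

# Write a small stand-in image so the example runs anywhere.
Image.new("RGB", (750, 1000)).save("page.jpg")

# VI-LayoutXLM branch: header-only read, no pixel decode needed to get the size.
width, height = Image.open("page.jpg").size  # (750, 1000): PIL gives (width, height)

# LayoutXLM branch: the image was already decoded to an HWC array upstream.
image = np.asarray(Image.open("page.jpg"))   # stand-in for data["image"]
height2, width2, _ = image.shape             # (1000, 750, 3): numpy gives (height, width, channels)
```
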
mindocr/losses/kie_loss.py (2 additions, 2 deletions)

@@ -13,7 +13,7 @@ def __init__(self, **kwargs):
super().__init__()
self.loss_fct = nn.CrossEntropyLoss()

-def construct(self, predicts, attention_mask, labels):
+def construct(self, predicts, labels):
loss = self.loss_fct(predicts.transpose(0, 2, 1), labels.astype(ms.int32))
return loss

@@ -28,6 +28,6 @@ def __init__(self, **kwargs):
super().__init__()
self.loss_fct = nn.CrossEntropyLoss()

-def construct(self, predicts, attention_mask, labels):
+def construct(self, predicts, labels):
loss = self.loss_fct(predicts.transpose(0, 2, 1), labels.astype(ms.int32))
return loss
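
Note: the attention_mask argument the old construct accepted was never referenced in the loss body, so dropping it, together with index 2 in the SER configs' label_column_index, changes no numbers; it only stops an unused tensor from being passed around. A minimal shape check of the new signature, with illustrative sizes (batch=2, seq_len=8, 7 classes):

```python
import numpy as np
import mindspore as ms
from mindspore import nn

loss_fct = nn.CrossEntropyLoss()
predicts = ms.Tensor(np.random.randn(2, 8, 7), ms.float32)  # (batch, seq_len, num_classes) logits
labels = ms.Tensor(np.zeros((2, 8)), ms.int32)              # token-level class ids

# CrossEntropyLoss expects the class axis second, hence the transpose to (batch, classes, seq_len).
loss = loss_fct(predicts.transpose(0, 2, 1), labels.astype(ms.int32))
print(loss)
```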