Skip to content

Commit

Permalink
Remove unnecessary preprocessing steps of VI-LayoutXLM to improve per…
Browse files Browse the repository at this point in the history
…formance.
  • Loading branch information
Bourn3z committed Apr 18, 2024
1 parent 5120a2a commit 86efafa
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 64 deletions.
4 changes: 2 additions & 2 deletions configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ train:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns: [ 'input_ids', 'bbox','attention_mask','token_type_ids', 'image', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
label_column_index: [ 2, 5 ] # input indices marked as label
label_column_index: [ 5 ] # input indices marked as label

loader:
shuffle: True
Expand Down Expand Up @@ -122,7 +122,7 @@ eval:
# the order of the dataloader list, matching the network input and the labels for evaluation
output_columns: [ 'input_ids', 'bbox', 'attention_mask','token_type_ids','image', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
label_column_index: [ 2, 5 ] # input indices marked as label
label_column_index: [ 5 ] # input indices marked as label

loader:
shuffle: False
Expand Down
34 changes: 5 additions & 29 deletions configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,9 @@ train:
label_file: XFUND/zh_train/train.json
sample_ratio: 1.0
transform_pipeline:
- DecodeImage:
img_mode: RGB
to_float32: False
- VQATokenLabelEncode:
contains_re: True
algorithm: &algorithm LayoutXLM
algorithm: &algorithm VI-LayoutXLM
class_path: *class_path
order_method: tb-yx
- VQATokenPad:
Expand All @@ -75,30 +72,21 @@ train:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
max_relation_len: 5000
- LayoutResize:
size: [224, 224]
- NormalizeImage:
bgr_to_rgb: False
is_hwc: True
mean: imagenet
std: imagenet
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns:
[
"input_ids",
"bbox",
"attention_mask",
"token_type_ids",
"image",
"question",
"question_label",
"answer",
"answer_label",
"relation_label",
]
net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
label_column_index: [9] # input indices marked as label
net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7] # input indices for network forward func in output_columns
label_column_index: [8] # input indices marked as label

loader:
shuffle: True
Expand All @@ -117,9 +105,6 @@ eval:
sample_ratio: 1.0
shuffle: False
transform_pipeline:
- DecodeImage:
img_mode: RGB
to_float32: False
- VQATokenLabelEncode:
contains_re: True
algorithm: *algorithm
Expand All @@ -133,30 +118,21 @@ eval:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
max_relation_len: 5000
- LayoutResize:
size: [224, 224]
- NormalizeImage:
bgr_to_rgb: False
is_hwc: True
mean: imagenet
std: imagenet
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns:
[
"input_ids",
"bbox",
"attention_mask",
"token_type_ids",
"image",
"question",
"question_label",
"answer",
"answer_label",
"relation_label",
]
net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
label_column_index: [9] # input indices marked as label
net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7] # input indices for network forward func in output_columns
label_column_index: [8] # input indices marked as label

loader:
shuffle: False
Expand Down
36 changes: 7 additions & 29 deletions configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,31 +57,20 @@ train:
label_file: XFUND/zh_train/train.json
sample_ratio: 1.0
transform_pipeline:
- DecodeImage:
img_mode: RGB
to_float32: False
- VQATokenLabelEncode:
contains_re: False
algorithm: &algorithm LayoutXLM
algorithm: &algorithm VI-LayoutXLM
class_path: *class_path
order_method: tb-yx
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- LayoutResize:
size: [ 224, 224 ]
- NormalizeImage:
bgr_to_rgb: False
is_hwc: True
mean: imagenet
std: imagenet
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns: [ 'input_ids', 'bbox','attention_mask','token_type_ids', 'image', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
label_column_index: [ 2, 5 ] # input indices marked as label
output_columns: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3 ] # input indices for network forward func in output_columns
label_column_index: [ 4 ] # input indices marked as label

loader:
shuffle: True
Expand All @@ -100,9 +89,6 @@ eval:
sample_ratio: 1.0
shuffle: False
transform_pipeline:
- DecodeImage:
img_mode: RGB
to_float32: False
- VQATokenLabelEncode:
contains_re: False
algorithm: *algorithm
Expand All @@ -113,18 +99,10 @@ eval:
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- LayoutResize:
size: [ 224, 224 ]
- NormalizeImage:
bgr_to_rgb: False
is_hwc: True
mean: imagenet
std: imagenet
- ToCHWImage:
# the order of the dataloader list, matching the network input and the labels for evaluation
output_columns: [ 'input_ids', 'bbox', 'attention_mask','token_type_ids','image', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
label_column_index: [ 2, 5 ] # input indices marked as label
output_columns: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3 ] # input indices for network forward func in output_columns
label_column_index: [ 4 ] # input indices marked as label

loader:
shuffle: False
Expand Down
9 changes: 7 additions & 2 deletions mindocr/data/transforms/layoutlm_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import cv2
import numpy as np
from PIL import Image

from mindspore import nn

Expand Down Expand Up @@ -65,8 +66,10 @@ def __init__(
super(VQATokenLabelEncode, self).__init__()
tokenizer_dict = {
"LayoutXLM": {"class": LayoutXLMTokenizer, "pretrained_model": "layoutxlm-base-uncased"},
"VI-LayoutXLM": {"class": LayoutXLMTokenizer, "pretrained_model": "layoutxlm-base-uncased"},
}
self.contains_re = contains_re
self.algorithm = algorithm
tokenizer_config = tokenizer_dict[algorithm]
self.tokenizer = tokenizer_config["class"].from_pretrained(tokenizer_config["pretrained_model"]) # to replace
self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
Expand Down Expand Up @@ -141,8 +144,10 @@ def __call__(self, data):
train_re = self.contains_re and not self.infer_mode
if train_re:
ocr_info = self.filter_empty_contents(ocr_info)

height, width, _ = data["image"].shape
if self.algorithm == "VI-LayoutXLM":
width, height = Image.open(data["img_path"]).size
elif self.algorithm == "LayoutXLM":
height, width, _ = data["image"].shape

words_list = []
bbox_list = []
Expand Down
4 changes: 2 additions & 2 deletions mindocr/losses/kie_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, **kwargs):
super().__init__()
self.loss_fct = nn.CrossEntropyLoss()

def construct(self, predicts, attention_mask, labels):
def construct(self, predicts, labels):
loss = self.loss_fct(predicts.transpose(0, 2, 1), labels.astype(ms.int32))
return loss

Expand All @@ -28,6 +28,6 @@ def __init__(self, **kwargs):
super().__init__()
self.loss_fct = nn.CrossEntropyLoss()

def construct(self, predicts, attention_mask, labels):
def construct(self, predicts, labels):
loss = self.loss_fct(predicts.transpose(0, 2, 1), labels.astype(ms.int32))
return loss

0 comments on commit 86efafa

Please sign in to comment.