Skip to content

Commit

Permalink
Remove unnecessary preprocessing steps of VI-LayoutXLM to improve per…
Browse files Browse the repository at this point in the history
…formance.
  • Loading branch information
Bourn3z committed Jan 22, 2024
1 parent 5120a2a commit 0780699
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 64 deletions.
34 changes: 5 additions & 29 deletions configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,9 @@ train:
label_file: XFUND/zh_train/train.json
sample_ratio: 1.0
transform_pipeline:
- DecodeImage:
img_mode: RGB
to_float32: False
- VQATokenLabelEncode:
contains_re: True
algorithm: &algorithm LayoutXLM
algorithm: &algorithm VI-LayoutXLM
class_path: *class_path
order_method: tb-yx
- VQATokenPad:
Expand All @@ -75,30 +72,21 @@ train:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
max_relation_len: 5000
- LayoutResize:
size: [224, 224]
- NormalizeImage:
bgr_to_rgb: False
is_hwc: True
mean: imagenet
std: imagenet
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns:
[
"input_ids",
"bbox",
"attention_mask",
"token_type_ids",
"image",
"question",
"question_label",
"answer",
"answer_label",
"relation_label",
]
net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
label_column_index: [9] # input indices marked as label
net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7] # input indices for network forward func in output_columns
label_column_index: [8] # input indices marked as label

loader:
shuffle: True
Expand All @@ -117,9 +105,6 @@ eval:
sample_ratio: 1.0
shuffle: False
transform_pipeline:
- DecodeImage:
img_mode: RGB
to_float32: False
- VQATokenLabelEncode:
contains_re: True
algorithm: *algorithm
Expand All @@ -133,30 +118,21 @@ eval:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
max_relation_len: 5000
- LayoutResize:
size: [224, 224]
- NormalizeImage:
bgr_to_rgb: False
is_hwc: True
mean: imagenet
std: imagenet
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns:
[
"input_ids",
"bbox",
"attention_mask",
"token_type_ids",
"image",
"question",
"question_label",
"answer",
"answer_label",
"relation_label",
]
net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
label_column_index: [9] # input indices marked as label
net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7] # input indices for network forward func in output_columns
label_column_index: [8] # input indices marked as label

loader:
shuffle: False
Expand Down
36 changes: 7 additions & 29 deletions configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,31 +57,20 @@ train:
label_file: XFUND/zh_train/train.json
sample_ratio: 1.0
transform_pipeline:
- DecodeImage:
img_mode: RGB
to_float32: False
- VQATokenLabelEncode:
contains_re: False
algorithm: &algorithm LayoutXLM
algorithm: &algorithm VI-LayoutXLM
class_path: *class_path
order_method: tb-yx
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- LayoutResize:
size: [ 224, 224 ]
- NormalizeImage:
bgr_to_rgb: False
is_hwc: True
mean: imagenet
std: imagenet
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns: [ 'input_ids', 'bbox','attention_mask','token_type_ids', 'image', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
label_column_index: [ 2, 5 ] # input indices marked as label
output_columns: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3 ] # input indices for network forward func in output_columns
label_column_index: [ 2, 4 ] # input indices marked as label

loader:
shuffle: True
Expand All @@ -100,9 +89,6 @@ eval:
sample_ratio: 1.0
shuffle: False
transform_pipeline:
- DecodeImage:
img_mode: RGB
to_float32: False
- VQATokenLabelEncode:
contains_re: False
algorithm: *algorithm
Expand All @@ -113,18 +99,10 @@ eval:
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- LayoutResize:
size: [ 224, 224 ]
- NormalizeImage:
bgr_to_rgb: False
is_hwc: True
mean: imagenet
std: imagenet
- ToCHWImage:
# the order of the dataloader list, matching the network input and the labels for evaluation
output_columns: [ 'input_ids', 'bbox', 'attention_mask','token_type_ids','image', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
label_column_index: [ 2, 5 ] # input indices marked as label
output_columns: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'labels' ]
net_input_column_index: [ 0, 1, 2, 3 ] # input indices for network forward func in output_columns
label_column_index: [ 2, 4 ] # input indices marked as label

loader:
shuffle: False
Expand Down
30 changes: 24 additions & 6 deletions mindocr/data/transforms/layoutlm_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import defaultdict

import cv2
from PIL import Image
import numpy as np

from mindspore import nn
Expand Down Expand Up @@ -65,8 +66,10 @@ def __init__(
super(VQATokenLabelEncode, self).__init__()
tokenizer_dict = {
"LayoutXLM": {"class": LayoutXLMTokenizer, "pretrained_model": "layoutxlm-base-uncased"},
"VI-LayoutXLM": {"class": LayoutXLMTokenizer, "pretrained_model": "layoutxlm-base-uncased"},
}
self.contains_re = contains_re
self.algorithm = algorithm
tokenizer_config = tokenizer_dict[algorithm]
self.tokenizer = tokenizer_config["class"].from_pretrained(tokenizer_config["pretrained_model"]) # to replace
self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
Expand Down Expand Up @@ -141,8 +144,10 @@ def __call__(self, data):
train_re = self.contains_re and not self.infer_mode
if train_re:
ocr_info = self.filter_empty_contents(ocr_info)

height, width, _ = data["image"].shape
if self.algorithm == "VI-LayoutXLM":
width, height = Image.open(data["img_path"]).size
elif self.algorithm == "LayoutXLM":
height, width, _ = data["image"].shape

words_list = []
bbox_list = []
Expand Down Expand Up @@ -258,10 +263,23 @@ def trans_poly_to_bbox(poly):
return [x1, y1, x2, y2]

def _load_ocr_info(self, data):
"""read text info from 'label' data"""
info = data["label"]
info_dict = json.loads(info)
return info_dict
if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data["image"], cls=False)[0]
ocr_info = []
for res in ocr_result:
ocr_info.append(
{
"transcription": res[1][0],
"bbox": self.trans_poly_to_bbox(res[0]),
"points": res[0],
}
)
return ocr_info
else:
info = data["label"]
# read text info
info_dict = json.loads(info)
return info_dict

@staticmethod
def _smooth_box(bboxes, height, width):
Expand Down

0 comments on commit 0780699

Please sign in to comment.