Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
remove padding warning (#916)
Browse files (browse the repository at this point in the history)
  • Loading branch information
eric-haibin-lin authored and sxjscience committed Sep 3, 2019
1 parent fbd7527 commit a563293
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions scripts/bert/finetune_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,12 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, pad=Fa
data_train_len = data_train.transform(
lambda input_id, length, segment_id, label_id: length, lazy=False)
# bucket sampler for training
+ pad_val = vocabulary[vocabulary.padding_token]
batchify_fn = nlp.data.batchify.Tuple(
- nlp.data.batchify.Pad(axis=0), nlp.data.batchify.Stack(),
- nlp.data.batchify.Pad(axis=0), nlp.data.batchify.Stack(label_dtype))
+ nlp.data.batchify.Pad(axis=0, pad_val=pad_val), # input
+ nlp.data.batchify.Stack(), # length
+ nlp.data.batchify.Pad(axis=0, pad_val=0), # segment
+ nlp.data.batchify.Stack(label_dtype)) # label
batch_sampler = nlp.data.sampler.FixedBucketSampler(
data_train_len,
batch_size=batch_size,
Expand Down Expand Up @@ -327,8 +330,8 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, pad=Fa

# batchify for data test
test_batchify_fn = nlp.data.batchify.Tuple(
- nlp.data.batchify.Pad(axis=0), nlp.data.batchify.Stack(),
- nlp.data.batchify.Pad(axis=0))
+ nlp.data.batchify.Pad(axis=0, pad_val=pad_val), nlp.data.batchify.Stack(),
+ nlp.data.batchify.Pad(axis=0, pad_val=0))
# transform for data test
test_trans = BERTDatasetTransform(tokenizer, max_len,
class_labels=None,
Expand Down

0 comments on commit a563293

Please sign in to comment.