Skip to content

Commit

Permalink
[tokenizers] Add int32 option to encoding (#3571)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Jan 13, 2025
1 parent df77982 commit f0ee0ad
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.huggingface.tokenizers;

import ai.djl.huggingface.tokenizers.jni.CharSpan;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;

Expand Down Expand Up @@ -55,23 +56,85 @@ protected Encoding(
this.overflowing = overflowing;
}

/**
 * Builds the {@link NDList} representation of a batch of encodings.
 *
 * <p>The list holds the token ids first, then the attention mask, and finally the token type
 * ids when {@code withTokenType} is true. Each array has one row per encoding.
 *
 * @param encodings the {@code Encoding} batch
 * @param manager the {@link NDManager} to create the NDList
 * @param withTokenType true to include the token type id
 * @param int32 true to use int32 datatype
 * @return the {@link NDList}
 */
public static NDList toNDList(
        Encoding[] encodings, NDManager manager, boolean withTokenType, boolean int32) {
    int batchSize = encodings.length;
    NDList result = new NDList();

    if (int32) {
        // Narrow every row from long to int before handing it to the manager.
        int[][] idRows = new int[batchSize][];
        int[][] maskRows = new int[batchSize][];
        int[][] typeRows = new int[batchSize][];
        for (int row = 0; row < batchSize; row++) {
            Encoding encoding = encodings[row];
            idRows[row] = narrowToInt(encoding.getIds());
            maskRows[row] = narrowToInt(encoding.getAttentionMask());
            if (withTokenType) {
                typeRows[row] = narrowToInt(encoding.getTypeIds());
            }
        }
        result.add(manager.create(idRows));
        result.add(manager.create(maskRows));
        if (withTokenType) {
            result.add(manager.create(typeRows));
        }
        return result;
    }

    // Otherwise pass the long arrays through unchanged.
    long[][] idRows = new long[batchSize][];
    long[][] maskRows = new long[batchSize][];
    long[][] typeRows = new long[batchSize][];
    for (int row = 0; row < batchSize; row++) {
        Encoding encoding = encodings[row];
        idRows[row] = encoding.getIds();
        maskRows[row] = encoding.getAttentionMask();
        if (withTokenType) {
            typeRows[row] = encoding.getTypeIds();
        }
    }
    result.add(manager.create(idRows));
    result.add(manager.create(maskRows));
    if (withTokenType) {
        result.add(manager.create(typeRows));
    }
    return result;
}

/** Casts each {@code long} element to {@code int} (plain narrowing cast, no range check). */
private static int[] narrowToInt(long[] values) {
    return Arrays.stream(values).mapToInt(value -> (int) value).toArray();
}

/**
* Returns the {@link NDList} representation of the encoding.
*
* @param manager the {@link NDManager} to create the NDList
* @param withTokenType true to include the token type id
* @param int32 true to use int32 datatype
* @return the {@link NDList}
*/
public NDList toNDList(NDManager manager, boolean withTokenType) {
public NDList toNDList(NDManager manager, boolean withTokenType, boolean int32) {
// Use int32 when requested because candle can't convert int64 to fp16 in cuda
NDList list = new NDList(withTokenType ? 3 : 2);
int[] intIds = Arrays.stream(ids).mapToInt(i -> (int) i).toArray();
int[] intAttentionMask = Arrays.stream(attentionMask).mapToInt(i -> (int) i).toArray();
list.add(manager.create(intIds));
list.add(manager.create(intAttentionMask));
if (withTokenType) {
int[] intTypeIds = Arrays.stream(typeIds).mapToInt(i -> (int) i).toArray();
list.add(manager.create(intTypeIds));
if (int32) {
int[] intIds = Arrays.stream(ids).mapToInt(i -> (int) i).toArray();
int[] intAttentionMask = Arrays.stream(attentionMask).mapToInt(i -> (int) i).toArray();
list.add(manager.create(intIds));
list.add(manager.create(intAttentionMask));
if (withTokenType) {
int[] intTypeIds = Arrays.stream(typeIds).mapToInt(i -> (int) i).toArray();
list.add(manager.create(intTypeIds));
}
} else {
list.add(manager.create(ids));
list.add(manager.create(attentionMask));
if (withTokenType) {
list.add(manager.create(typeIds));
}
}
return list;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,19 @@ public class CrossEncoderTranslator implements Translator<StringPair, float[]> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private boolean int32;
private boolean sigmoid;
private Batchifier batchifier;

/**
 * Constructs a {@code CrossEncoderTranslator} from the builder's settings.
 *
 * @param tokenizer the tokenizer used to encode the input pairs
 * @param includeTokenTypes true to include the token type ids in the model input
 * @param int32 true to create the model input with the int32 datatype
 * @param sigmoid true to apply sigmoid
 * @param batchifier the {@link Batchifier} used to batchify the inputs
 */
CrossEncoderTranslator(
        HuggingFaceTokenizer tokenizer,
        boolean includeTokenTypes,
        boolean int32,
        boolean sigmoid,
        Batchifier batchifier) {
    // Collaborators first, then the boolean switches.
    this.tokenizer = tokenizer;
    this.batchifier = batchifier;

    this.includeTokenTypes = includeTokenTypes;
    this.int32 = int32;
    this.sigmoid = sigmoid;
}
Expand All @@ -60,7 +63,7 @@ public Batchifier getBatchifier() {
public NDList processInput(TranslatorContext ctx, StringPair input) {
Encoding encoding = tokenizer.encode(input.getKey(), input.getValue());
ctx.setAttachment("encoding", encoding);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes, int32);
}

/** {@inheritDoc} */
Expand All @@ -71,7 +74,7 @@ public NDList batchProcessInput(TranslatorContext ctx, List<StringPair> inputs)
Encoding[] encodings = tokenizer.batchEncode(list);
NDList[] batch = new NDList[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
batch[i] = encodings[i].toNDList(manager, includeTokenTypes, int32);
}
return batchifier.batchify(batch);
}
Expand Down Expand Up @@ -145,6 +148,7 @@ public static final class Builder {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private boolean int32;
private boolean sigmoid = true;
private Batchifier batchifier = Batchifier.STACK;

Expand All @@ -163,6 +167,17 @@ public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
return this;
}

/**
* Sets whether to use the int32 datatype for the {@link Translator}.
*
* @param int32 true to use int32 datatype
* @return this builder
*/
public Builder optInt32(boolean int32) {
this.int32 = int32;
return this;
}

/**
* Sets if apply sigmoid for the {@link Translator}.
*
Expand Down Expand Up @@ -192,6 +207,7 @@ public Builder optBatchifier(Batchifier batchifier) {
*/
public void configure(Map<String, ?> arguments) {
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
optInt32(ArgumentsUtil.booleanValue(arguments, "int32"));
optSigmoid(ArgumentsUtil.booleanValue(arguments, "sigmoid", true));
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optBatchifier(Batchifier.fromString(batchifierStr));
Expand All @@ -204,7 +220,8 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public CrossEncoderTranslator build() throws IOException {
return new CrossEncoderTranslator(tokenizer, includeTokenTypes, sigmoid, batchifier);
return new CrossEncoderTranslator(
tokenizer, includeTokenTypes, int32, sigmoid, batchifier);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,21 @@ public class FillMaskTranslator implements Translator<String, Classifications> {
private long maskTokenId;
private int topK;
private boolean includeTokenTypes;
private boolean int32;
private Batchifier batchifier;

FillMaskTranslator(
HuggingFaceTokenizer tokenizer,
String maskToken,
int topK,
boolean includeTokenTypes,
boolean int32,
Batchifier batchifier) {
this.tokenizer = tokenizer;
this.maskToken = maskToken;
this.topK = topK;
this.includeTokenTypes = includeTokenTypes;
this.int32 = int32;
this.batchifier = batchifier;
Encoding encoding = tokenizer.encode(maskToken, false, false);
maskTokenId = encoding.getIds()[0];
Expand All @@ -68,7 +71,7 @@ public NDList processInput(TranslatorContext ctx, String input) throws Translate
long[] indices = encoding.getIds();
int maskIndex = getMaskIndex(indices);
ctx.setAttachment("maskIndex", maskIndex);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes, int32);
}

/** {@inheritDoc} */
Expand All @@ -83,7 +86,7 @@ public NDList batchProcessInput(TranslatorContext ctx, List<String> inputs)
for (int i = 0; i < batch.length; ++i) {
long[] indices = encodings[i].getIds();
maskIndices[i] = getMaskIndex(indices);
batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
batch[i] = encodings[i].toNDList(manager, includeTokenTypes, int32);
}
return batchifier.batchify(batch);
}
Expand Down Expand Up @@ -167,6 +170,7 @@ public static final class Builder {
private String maskedToken = "[MASK]";
private int topK = 5;
private boolean includeTokenTypes;
private boolean int32;
private Batchifier batchifier = Batchifier.STACK;

Builder(HuggingFaceTokenizer tokenizer) {
Expand Down Expand Up @@ -206,6 +210,17 @@ public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
return this;
}

/**
* Sets whether to use the int32 datatype for the {@link Translator}.
*
* @param int32 true to use int32 datatype
* @return this builder
*/
public Builder optInt32(boolean int32) {
this.int32 = int32;
return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
Expand All @@ -224,6 +239,7 @@ public Builder optBatchifier(Batchifier batchifier) {
*/
public void configure(Map<String, ?> arguments) {
optMaskToken(ArgumentsUtil.stringValue(arguments, "maskToken", "[MASK]"));
optInt32(ArgumentsUtil.booleanValue(arguments, "int32"));
optTopK(ArgumentsUtil.intValue(arguments, "topK", 5));
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
Expand All @@ -238,7 +254,7 @@ public void configure(Map<String, ?> arguments) {
*/
public FillMaskTranslator build() throws IOException {
return new FillMaskTranslator(
tokenizer, maskedToken, topK, includeTokenTypes, batchifier);
tokenizer, maskedToken, topK, includeTokenTypes, int32, batchifier);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,19 @@ public class QuestionAnsweringTranslator implements Translator<QAInput, String>

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private boolean int32;
private Batchifier batchifier;
private boolean detail;

/**
 * Constructs a {@code QuestionAnsweringTranslator} from the builder's settings.
 *
 * @param tokenizer the tokenizer used to encode the question and paragraph
 * @param includeTokenTypes true to include the token type ids in the model input
 * @param int32 true to create the model input with the int32 datatype
 * @param batchifier the {@link Batchifier} used to batchify the inputs
 * @param detail the detail option configured through {@code Builder.optDetail}
 */
QuestionAnsweringTranslator(
        HuggingFaceTokenizer tokenizer,
        boolean includeTokenTypes,
        boolean int32,
        Batchifier batchifier,
        boolean detail) {
    // Collaborators first, then the boolean switches.
    this.tokenizer = tokenizer;
    this.batchifier = batchifier;

    this.includeTokenTypes = includeTokenTypes;
    this.int32 = int32;
    this.detail = detail;
}
Expand All @@ -62,7 +65,7 @@ public Batchifier getBatchifier() {
public NDList processInput(TranslatorContext ctx, QAInput input) {
Encoding encoding = tokenizer.encode(input.getQuestion(), input.getParagraph());
ctx.setAttachment("encoding", encoding);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes, int32);
}

/** {@inheritDoc} */
Expand All @@ -77,7 +80,7 @@ public NDList batchProcessInput(TranslatorContext ctx, List<QAInput> inputs) {
ctx.setAttachment("encodings", encodings);
NDList[] batch = new NDList[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
batch[i] = encodings[i].toNDList(manager, includeTokenTypes, int32);
}
return batchifier.batchify(batch);
}
Expand Down Expand Up @@ -190,6 +193,7 @@ public static final class Builder {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private boolean int32;
private Batchifier batchifier = Batchifier.STACK;
private boolean detail;

Expand All @@ -208,6 +212,17 @@ public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
return this;
}

/**
* Sets whether to use the int32 datatype for the {@link Translator}.
*
* @param int32 true to use int32 datatype
* @return this builder
*/
public Builder optInt32(boolean int32) {
this.int32 = int32;
return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
Expand Down Expand Up @@ -237,6 +252,7 @@ public Builder optDetail(boolean detail) {
*/
public void configure(Map<String, ?> arguments) {
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
optInt32(ArgumentsUtil.booleanValue(arguments, "int32"));
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optDetail(ArgumentsUtil.booleanValue(arguments, "detail"));
optBatchifier(Batchifier.fromString(batchifierStr));
Expand All @@ -250,7 +266,7 @@ public void configure(Map<String, ?> arguments) {
*/
public QuestionAnsweringTranslator build() throws IOException {
return new QuestionAnsweringTranslator(
tokenizer, includeTokenTypes, batchifier, detail);
tokenizer, includeTokenTypes, int32, batchifier, detail);
}
}
}
Loading

0 comments on commit f0ee0ad

Please sign in to comment.