Skip to content

Commit

Permalink
handle surrogate pairs, add tests for composite chars or surrogate pairs
Browse files Browse the repository at this point in the history
  • Loading branch information
huan233usc committed Jan 2, 2025
1 parent ad73031 commit 5b3ec6b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ public String getString(int rowId) {
String inputString = input.getString(rowId);
int position = positionVector.getInt(rowId);
Optional<Integer> length = lengthVector.map(columnVector -> columnVector.getInt(rowId));
if (position > inputString.length() || (length.isPresent() && length.get() < 1)) {
if (position > getStringLengthWithCodePoint(inputString)
|| (length.isPresent() && length.get() < 1)) {
return "";
}
int startPosition = buildStartPosition(inputString, position);
Expand All @@ -116,10 +117,13 @@ public String getString(int rowId) {
len -> {
// endIndex should be less than the length of input string, but positive.
// e.g. Substring("aaa", -100, 95), should be read as Substring("aaa", 0, 0)
int endIndex = Math.min(inputString.length(), Math.max(startPosition + len, 0));
return inputString.substring(startIndex, endIndex);
int endIndex =
Math.min(
getStringLengthWithCodePoint(inputString),
Math.max(startPosition + len, 0));
return subStringWithCodePoint(inputString, startIndex, Optional.of(endIndex));
})
.orElse(inputString.substring(startIndex));
.orElse(subStringWithCodePoint(inputString, startIndex, Optional.empty()));
}
};
}
Expand All @@ -136,9 +140,20 @@ public String getString(int rowId) {
// Converts the caller-supplied substring position into a 0-based start offset.
// A negative pos is counted back from the end of the string; the result of that branch can
// itself be negative when |pos| exceeds the length, so the caller presumably clamps it — verify.
private static int buildStartPosition(String inputString, int pos) {
// Handles the negative position (substring("abc", -2, 1), the start position should be 1("b"))
if (pos < 0) {
// NOTE(review): the two consecutive returns below look like the removed and the added side of
// a diff hunk; presumably only the code-point-based one is meant to remain — confirm.
return inputString.length() + pos;
return getStringLengthWithCodePoint(inputString) + pos;
}
// Pos is 1 based and pos = 0 is treated as 1.
return Math.max(pos - 1, 0);
}

/**
 * Counts the Unicode code points in {@code s}, so that a surrogate pair contributes one to the
 * result instead of the two {@code char} units reported by {@link String#length()}.
 *
 * @param s the string to measure
 * @return the number of code points in {@code s}
 */
private static int getStringLengthWithCodePoint(String s) {
  final int utf16Length = s.length();
  return Character.codePointCount(s, 0, utf16Length);
}

/**
 * Returns the part of {@code s} delimited by code point positions, so that a surrogate pair is
 * never split in the middle.
 *
 * @param s the string to slice
 * @param start 0-based code point position of the first code point to keep
 * @param end 0-based, exclusive code point position to stop at; when empty, the slice runs to
 *     the end of {@code s}
 * @return the selected substring
 * @throws IndexOutOfBoundsException if {@code start} or {@code end} lies outside the code
 *     points of {@code s}, or if {@code end} is before {@code start}
 */
private static String subStringWithCodePoint(String s, int start, Optional<Integer> end) {
  // Translate the code point position into a UTF-16 char index.
  int startIndex = s.offsetByCodePoints(/* index = */ 0, start);
  if (!end.isPresent()) {
    // Only build the open-ended tail when it is actually the result.
    return s.substring(startIndex);
  }
  // Continue scanning from startIndex rather than walking the prefix a second time.
  int endIndex = s.offsetByCodePoints(startIndex, end.get() - start);
  return s.substring(startIndex, endIndex);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -774,8 +774,9 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa
)

test("evaluate expression: substring") {
// scalastyle:off nonascii
val data = Seq[String](
null, "one", "two", "three", "four", null, null, "seven", "eight")
null, "one", "two", "three", "four", null, null, "seven", "eight", "😉", "")
val col = stringVector(data)
val col_name = "str_col"
val schema = new StructType().add(col_name, StringType.STRING)
Expand All @@ -795,112 +796,126 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa
checkSubString(
input,
substring(new Column(col_name), 0),
Seq[String](null, "one", "two", "three", "four", null, null, "seven", "eight"))
// scalastyle:off nonascii
Seq[String](null, "one", "two", "three", "four", null, null, "seven", "eight", "😉", ""))

checkSubString(
input,
substring(new Column(col_name), 1),
Seq[String](null, "one", "two", "three", "four", null, null, "seven", "eight"))
// scalastyle:off nonascii
Seq[String](null, "one", "two", "three", "four", null, null, "seven", "eight", "😉", ""))

checkSubString(
input,
substring(new Column(col_name), 2),
Seq[String](null, "ne", "wo", "hree", "our", null, null, "even", "ight"))
Seq[String](null, "ne", "wo", "hree", "our", null, null, "even", "ight", "", "̈"))

checkSubString(
input,
substring(new Column(col_name), -1),
Seq[String](null, "e", "o", "e", "r", null, null, "n", "t"))
// scalastyle:off nonascii
Seq[String](null, "e", "o", "e", "r", null, null, "n", "t", "😉", "̈"))

checkSubString(
input,
substring(new Column(col_name), -1000),
Seq[String](null, "one", "two", "three", "four", null, null, "seven", "eight"))
// scalastyle:off nonascii
Seq[String](null, "one", "two", "three", "four", null, null, "seven", "eight", "😉", ""))

checkSubString(
input,
substring(new Column(col_name), 0, Option(4)),
Seq[String](null, "one", "two", "thre", "four", null, null, "seve", "eigh"))
// scalastyle:off nonascii
Seq[String](null, "one", "two", "thre", "four", null, null, "seve", "eigh", "😉", ""))

checkSubString(
input,
substring(new Column(col_name), 2, Option(0)),
Seq[String](null, "", "", "", "", null, null, "", ""))
Seq[String](null, "", "", "", "", null, null, "", "", "", ""))

checkSubString(
input,
substring(new Column(col_name), 1, Option(1)),
// scalastyle:off nonascii
Seq[String](null, "o", "t", "t", "f", null, null, "s", "e", "😉", "e"))

checkSubString(
input,
substring(new Column(col_name), 2, Option(1)),
Seq[String](null, "n", "w", "h", "o", null, null, "e", "i"))
Seq[String](null, "n", "w", "h", "o", null, null, "e", "i", "", "̈"))

checkSubString(
input,
substring(new Column(col_name), 2, Option(10000)),
Seq[String](null, "ne", "wo", "hree", "our", null, null, "even", "ight"))
Seq[String](null, "ne", "wo", "hree", "our", null, null, "even", "ight", "", "̈"))

checkSubString(
input,
substring(new Column(col_name), 1000),
Seq[String](null, "", "", "", "", null, null, "", ""))
Seq[String](null, "", "", "", "", null, null, "", "", "", ""))

checkSubString(
input,
substring(new Column(col_name), 1000, Option(10000)),
Seq[String](null, "", "", "", "", null, null, "", ""))
Seq[String](null, "", "", "", "", null, null, "", "", "", ""))

checkSubString(
input,
substring(new Column(col_name), 2, Option(-10)),
Seq[String](null, "", "", "", "", null, null, "", ""))
Seq[String](null, "", "", "", "", null, null, "", "", "", ""))

checkSubString(
input,
substring(new Column(col_name), -2, Option(1)),
Seq[String](null, "n", "w", "e", "u", null, null, "e", "h"))
Seq[String](null, "n", "w", "e", "u", null, null, "e", "h", "", "e"))

checkSubString(
input,
substring(new Column(col_name), -2, Option(2)),
Seq[String](null, "ne", "wo", "ee", "ur", null, null, "en", "ht"))
// scalastyle:off nonascii
Seq[String](null, "ne", "wo", "ee", "ur", null, null, "en", "ht", "😉", ""))

checkSubString(
input,
substring(new Column(col_name), -4, Option(3)),
Seq[String](null, "on", "tw", "hre", "fou", null, null, "eve", "igh"))
Seq[String](null, "on", "tw", "hre", "fou", null, null, "eve", "igh", "", "e"))

checkSubString(
input,
substring(new Column(col_name), -100, Option(95)),
Seq[String](null, "", "", "", "", null, null, "", ""))
Seq[String](null, "", "", "", "", null, null, "", "", "", ""))

checkSubString(
input,
substring(new Column(col_name), -100, Option(98)),
Seq[String](null, "o", "t", "thr", "fo", null, null, "sev", "eig"))
Seq[String](null, "o", "t", "thr", "fo", null, null, "sev", "eig", "", ""))

checkSubString(
input,
substring(new Column(col_name), -100, Option(108)),
Seq[String](null, "one", "two", "three", "four", null, null, "seven", "eight"))
// scalastyle:off nonascii
Seq[String](null, "one", "two", "three", "four", null, null, "seven", "eight", "😉", ""))

checkSubString(
input,
substring(new Column(col_name), 2147483647, Option(10000)),
Seq[String](null, "", "", "", "", null, null, "", ""))
Seq[String](null, "", "", "", "", null, null, "", "", "", ""))

checkSubString(
input,
substring(new Column(col_name), 2147483647),
Seq[String](null, "", "", "", "", null, null, "", ""))
Seq[String](null, "", "", "", "", null, null, "", "", "", ""))

checkSubString(
input,
substring(new Column(col_name), -2147483648, Option(10000)),
Seq[String](null, "", "", "", "", null, null, "", ""))
Seq[String](null, "", "", "", "", null, null, "", "", "", ""))

checkSubString(
input,
substring(new Column(col_name), -2147483648),
Seq[String](null, "one", "two", "three", "four", null, null, "seven", "eight"))
// scalastyle:off nonascii
Seq[String](null, "one", "two", "three", "four", null, null, "seven", "eight", "😉", ""))

val outputVectorForEmptyInput = evaluator(
schema,
Expand Down

0 comments on commit 5b3ec6b

Please sign in to comment.