From dc9754b5e11c2d4f6aea8eb6a0b7789da11afb6e Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Mon, 30 Dec 2024 11:30:37 -0800 Subject: [PATCH] resolve comments and remove the support for binary type by changing it to a todo --- .../expressions/DefaultExpressionUtils.java | 19 +++++++ .../expressions/SubstringEvaluator.java | 56 +++++++------------ .../DefaultExpressionEvaluatorSuite.scala | 43 ++++++-------- 3 files changed, 55 insertions(+), 63 deletions(-) diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java index b59db8689ab..20bb8d8a44c 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java @@ -15,12 +15,14 @@ */ package io.delta.kernel.defaults.internal.expressions; +import static io.delta.kernel.defaults.internal.DefaultEngineErrors.unsupportedExpressionException; import static io.delta.kernel.internal.util.Preconditions.checkArgument; import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.Literal; import io.delta.kernel.internal.util.Utils; import io.delta.kernel.types.*; import java.math.BigDecimal; @@ -383,4 +385,21 @@ private ColumnVector getVector(int rowId) { } }; } + + /** + * Checks if the specified expression is an integer literal, throws {@code + * unsupportedExpressionException} if not. + * + * @param expr expression to be checked. + * @param context string describing the context, used for constructing error message. + * @param baseExpression expression whose evaluation triggers this check. Used for constructing + * error message. 
+ */ + static void checkIntegerLiteral(Expression expr, String context, Expression baseExpression) { + if ((expr instanceof Literal) && IntegerType.INTEGER.equals(((Literal) expr).getDataType())) { + return; + } + throw unsupportedExpressionException( + baseExpression, String.format("%s, expects an integral numeric", context)); + } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/SubstringEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/SubstringEvaluator.java index 2451d30be92..72cd1a0c799 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/SubstringEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/SubstringEvaluator.java @@ -19,7 +19,6 @@ import io.delta.kernel.data.ColumnVector; import io.delta.kernel.expressions.Expression; -import io.delta.kernel.expressions.Literal; import io.delta.kernel.expressions.ScalarExpression; import io.delta.kernel.internal.util.Utils; import io.delta.kernel.types.*; @@ -28,9 +27,11 @@ /** Utility methods to evaluate {@code substring} expression. */ public class SubstringEvaluator { + private SubstringEvaluator() {} + + // TODO: support binary type. private static final Set SUBSTRING_SUPPORTED_TYPE = - Collections.unmodifiableSet( - new HashSet<>(Arrays.asList(StringType.STRING, BinaryType.BINARY))); + Collections.unmodifiableSet(new HashSet<>(Collections.singletonList(StringType.STRING))); /** Validates and transforms the {@code substring} expression. 
*/ static ScalarExpression validateAndTransform( @@ -47,25 +48,17 @@ static ScalarExpression validateAndTransform( if (!SUBSTRING_SUPPORTED_TYPE.contains(childrenOutputTypes.get(0))) { throw unsupportedExpressionException( - substring, "Invalid type of first input of SUBSTRING: expects BINARY or STRING"); + substring, "Invalid type of first input of SUBSTRING: expects STRING"); } Expression posExpression = childrenExpressions.get(1); - if (!isIntegerLiteral(posExpression)) { - throw unsupportedExpressionException( - substring, - "Invalid type of second input of SUBSTRING: " - + "expects an integral numeric expression specifying the starting position."); - } - + DefaultExpressionUtils.checkIntegerLiteral( + posExpression, /* context= */ "Invalid `pos` argument type for SUBSTRING", substring); if (childrenSize == 3) { Expression lengthExpression = childrenExpressions.get(2); - if (!isIntegerLiteral(lengthExpression)) { - throw unsupportedExpressionException( - substring, "Invalid type of third input of SUBSTRING: expects an integral numeric."); - } + DefaultExpressionUtils.checkIntegerLiteral( + lengthExpression, /* context= */ "Invalid `len` argument type for SUBSTRING", substring); } - return new ScalarExpression(substring.getName(), childrenExpressions); } @@ -92,28 +85,28 @@ public int getSize() { @Override public void close() { - if (lengthVector.isPresent()) { - Utils.closeCloseables(input, positionVector, lengthVector.get()); - } else { - Utils.closeCloseables(input, positionVector); - } + // Utils.closeCloseables method will ignore the null element. 
+ Utils.closeCloseables(input, positionVector, lengthVector.orElse(null)); } @Override public boolean isNullAt(int rowId) { + if (rowId < 0 || rowId >= getSize()) { + throw new IllegalArgumentException( + String.format( + "Unexpected rowId %d, expected between 0 and the size of the column vector", + rowId)); + } return input.isNullAt(rowId); } @Override public String getString(int rowId) { - if (isNullAt(rowId) || rowId < 0 || rowId >= getSize()) { + if (isNullAt(rowId)) { return null; } - String inputString = - input.getDataType() == BinaryType.BINARY - ? new String(input.getBinary(rowId)) - : input.getString(rowId); + String inputString = input.getString(rowId); int position = positionVector.getInt(rowId); Optional length = lengthVector.map(columnVector -> columnVector.getInt(rowId)); @@ -145,20 +138,11 @@ public String getString(int rowId) { * is not normalized, i.e. could be less than 0. */ private static int buildStartPosition(String inputString, int pos) { + // Handles a negative position: for substring("abc", -2, 1) the 0-based start index is 1 ("b"). if (pos < 0) { return inputString.length() + pos; } // Pos is 1 based and pos = 0 is treated as 1. 
return Math.max(pos - 1, 0); } - - private static boolean isIntegerLiteral(Expression expression) { - if (!(expression instanceof Literal)) { - return false; - } - Literal literal = (Literal) expression; - return IntegerType.INTEGER.equals(literal.getDataType()); - } - - private SubstringEvaluator() {} } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala index 33ccd97e307..c80b1306267 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala @@ -776,12 +776,10 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa test("evaluate expression: substring") { val data = Seq[String]( null, "one", "two", "three", "four", null, null, "seven", "eight") - val col1 = stringVector(data) - val col2 = binaryVector(data.map(str => if (str != null) str.getBytes else null)) - val schema = new StructType() - .add("str_col", StringType.STRING) - .add("binary_col", BinaryType.BINARY) - val input = new DefaultColumnarBatch(col1.getSize, schema, Array(col1, col2)) + val col = stringVector(data) + val col_name = "str_col" + val schema = new StructType().add(col_name, StringType.STRING) + val input = new DefaultColumnarBatch(col.getSize, schema, Array(col)) def checkSubString( input: DefaultColumnarBatch, @@ -794,7 +792,6 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa checkStringVectors(actOutputVector, expOutputVector) } - Seq("str_col", "binary_col").foreach(col_name => { checkSubString( input, substring(new Column(col_name), 0), @@ -879,13 +876,12 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with 
ExpressionSuiteBa input, substring(new Column(col_name), -100, Option(108)), Seq[String](null, "one", "two", "three", "four", null, null, "seven", "eight")) - }) val outputVectorForEmptyInput = evaluator( schema, new ScalarExpression("SUBSTRING", util.Arrays.asList( - new Column("str_col"), Literal.ofInt(1), Literal.ofInt(1))), + new Column(col_name), Literal.ofInt(1), Literal.ofInt(1))), StringType.STRING ).eval( new DefaultColumnarBatch(/* size= */0, schema, @@ -896,36 +892,35 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa def checkUnsupportedColumnTypes(colType: DataType): Unit = { val schema = new StructType() - .add("col1", colType) + .add(col_name, colType) val batch = new DefaultColumnarBatch(5, schema, Array(testColumnVector(5, colType))) val e = intercept[UnsupportedOperationException] { evaluator( schema, new ScalarExpression("SUBSTRING", - util.Arrays.asList(new Column("col1"), Literal.ofInt(1))), + util.Arrays.asList(new Column(col_name), Literal.ofInt(1))), StringType.STRING ).eval(batch) } assert( - e.getMessage.contains("Invalid type of first input of SUBSTRING: expects BINARY or STRING")) + e.getMessage.contains("Invalid type of first input of SUBSTRING: expects STRING")) } checkUnsupportedColumnTypes(IntegerType.INTEGER) checkUnsupportedColumnTypes(ByteType.BYTE) checkUnsupportedColumnTypes(BooleanType.BOOLEAN) + checkUnsupportedColumnTypes(BinaryType.BINARY) val badLiteralSize = intercept[UnsupportedOperationException] { evaluator( schema, new ScalarExpression("SUBSTRING", util.Arrays.asList( - new Column("str_col"), Literal.ofInt(1), Literal.ofInt(1), Literal.ofInt(1))), + new Column(col_name), Literal.ofInt(1), Literal.ofInt(1), Literal.ofInt(1))), StringType.STRING ).eval( new DefaultColumnarBatch(/* size= */5, schema, - Array( - testColumnVector(/* size= */5, StringType.STRING), - testColumnVector(/* size= */5, BinaryType.BINARY)))) + Array(testColumnVector(/* size= */5, StringType.STRING)))) } assert( 
badLiteralSize.getMessage.contains( @@ -940,28 +935,22 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa StringType.STRING ).eval( new DefaultColumnarBatch(/* size= */5, schema, - Array( - testColumnVector(/* size= */5, StringType.STRING), - testColumnVector(/* size= */5, BinaryType.BINARY)))) + Array(testColumnVector(/* size= */5, StringType.STRING)))) } - assert( - badPosType.getMessage.contains("Invalid type of second input of SUBSTRING")) + assert(badPosType.getMessage.contains("Invalid `pos` argument type for SUBSTRING")) val badLenType = intercept[UnsupportedOperationException] { evaluator( schema, new ScalarExpression("SUBSTRING", util.Arrays.asList( - new Column("str_col"), Literal.ofInt(1), Literal.ofBoolean(true))), + new Column(col_name), Literal.ofInt(1), Literal.ofBoolean(true))), StringType.STRING ).eval( new DefaultColumnarBatch(/* size= */5, schema, - Array( - testColumnVector(/* size= */5, StringType.STRING), - testColumnVector(/* size= */5, BinaryType.BINARY)))) + Array(testColumnVector(/* size= */5, StringType.STRING)))) } - assert( - badLenType.getMessage.contains("Invalid type of third input of SUBSTRING")) + assert(badLenType.getMessage.contains("Invalid `len` argument type for SUBSTRING")) } test("evaluate expression: comparators `byte` with other implicit types") {