Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Add STARTS_WITH expression and default impl #4007

Merged
merged 5 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@
* <li>SQL semantic: <code>expr1 IS NOT DISTINCT FROM expr2</code>
* <li>Since version: 3.3.0
* </ul>
* <li>Name: <code>STARTS_WITH</code>
* <ul>
* <li>SQL semantic: <code>expr STARTS_WITH expr</code>
* <li>Since version: 3.4.0
* </ul>
* </ol>
*
* @since 3.0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,18 @@ ExpressionTransformResult visitLike(final Predicate like) {
return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN);
}

@Override
ExpressionTransformResult visitStartsWith(Predicate startsWith) {
List<ExpressionTransformResult> children =
startsWith.getChildren().stream().map(this::visit).collect(toList());
Predicate transformedExpression =
StartsWithExpressionEvaluator.validateAndTransform(
startsWith,
children.stream().map(e -> e.expression).collect(toList()),
children.stream().map(e -> e.outputType).collect(toList()));
return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN);
}

private Predicate validateIsPredicate(
Expression baseExpression, ExpressionTransformResult result) {
checkArgument(
Expand Down Expand Up @@ -610,6 +622,12 @@ ColumnVector visitLike(final Predicate like) {
children, children.stream().map(this::visit).collect(toList()));
}

@Override
ColumnVector visitStartsWith(Predicate startsWith) {
return StartsWithExpressionEvaluator.eval(
startsWith.getChildren().stream().map(this::visit).collect(toList()));
}

/**
* Utility method to evaluate inputs to the binary input expression. Also validates the
* evaluated expression result {@link ColumnVector}s are of the same size.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
*/
package io.delta.kernel.defaults.internal.expressions;

import static io.delta.kernel.defaults.internal.DefaultEngineErrors.unsupportedExpressionException;
import static io.delta.kernel.internal.util.Preconditions.checkArgument;

import io.delta.kernel.data.ArrayValue;
import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.MapValue;
import io.delta.kernel.expressions.Expression;
import io.delta.kernel.expressions.Literal;
import io.delta.kernel.internal.util.Utils;
import io.delta.kernel.types.*;
import java.math.BigDecimal;
Expand Down Expand Up @@ -383,4 +385,29 @@ private ColumnVector getVector(int rowId) {
}
};
}

/**
* Checks the argument count of an expression. throws {@code unsupportedExpressionException} if
* argument count mismatched.
*/
static void checkArgsCount(Expression expr, int expectedCount, String exprName, String context) {
if (expr.getChildren().size() != expectedCount) {
throw unsupportedExpressionException(
expr, String.format("Invalid number of inputs of %s expression, %s", exprName, context));
}
}

static void checkIsStringType(DataType dataType, Expression parentExpr, String errorMessage) {
if (StringType.STRING.equals(dataType)) {
return;
}
throw unsupportedExpressionException(parentExpr, errorMessage);
}

static void checkIsLiteral(Expression expr, Expression parentExpr, String errorMessage) {
if (expr instanceof Literal) {
return;
}
throw unsupportedExpressionException(parentExpr, errorMessage);
huan233usc marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ abstract class ExpressionVisitor<R> {

abstract R visitLike(Predicate predicate);

abstract R visitStartsWith(Predicate predicate);

final R visit(Expression expression) {
if (expression instanceof PartitionValueExpression) {
return visitPartitionValue((PartitionValueExpression) expression);
Expand Down Expand Up @@ -113,6 +115,8 @@ private R visitScalarExpression(ScalarExpression expression) {
return visitTimeAdd(expression);
case "LIKE":
return visitLike(new Predicate(name, children));
case "STARTS_WITH":
return visitStartsWith(new Predicate(name, children));
default:
throw new UnsupportedOperationException(
String.format("Scalar expression `%s` is not supported.", name));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (2023) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.delta.kernel.defaults.internal.expressions;

import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.*;

import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.expressions.Expression;
import io.delta.kernel.expressions.Predicate;
import io.delta.kernel.internal.util.Utils;
import io.delta.kernel.types.BooleanType;
import io.delta.kernel.types.DataType;
import java.util.List;

public class StartsWithExpressionEvaluator {

/** Validates and transforms the {@code starts_with} expression. */
static Predicate validateAndTransform(
Predicate startsWith,
List<Expression> childrenExpressions,
List<DataType> childrenOutputTypes) {
checkArgsCount(
startsWith,
/* expectedCount= */ 2,
startsWith.getName(),
"Example usage: STARTS_WITH(column, 'test')");
for (DataType dataType : childrenOutputTypes) {
checkIsStringType(dataType, startsWith, "'STARTS_WITH' expects STRING type inputs");
}

// TODO: support non literal as the second input of starts with.
checkIsLiteral(
childrenExpressions.get(1),
startsWith,
"'STARTS_WITH' expects literal as the second input");
return new Predicate(startsWith.getName(), childrenExpressions);
}

static ColumnVector eval(List<ColumnVector> childrenVectors) {
return new ColumnVector() {
final ColumnVector left = childrenVectors.get(0);
final ColumnVector right = childrenVectors.get(1);

@Override
public DataType getDataType() {
return BooleanType.BOOLEAN;
}

@Override
public int getSize() {
return left.getSize();
}

@Override
public void close() {
Utils.closeCloseables(left, right);
}

@Override
public boolean getBoolean(int rowId) {
if (isNullAt(rowId)) {
// The return value is undefined and can be anything, if the slot for rowId is null.
return false;
}
return left.getString(rowId).startsWith(right.getString(rowId));
}

@Override
public boolean isNullAt(int rowId) {
return left.isNullAt(rowId) || right.isNullAt(rowId);
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,80 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa
"LIKE expression has invalid escape sequence"))
}

test("evaluate expression: starts with") {
val col1 = stringVector(Seq[String]("one", "two", "t%hree", "four", null, null, "%"))
val col2 = stringVector(Seq[String]("o", "t", "T", "4", "f", null, null))
val schema = new StructType()
.add("col1", StringType.STRING)
.add("col2", StringType.STRING)
val input = new DefaultColumnarBatch(col1.getSize, schema, Array(col1, col2))

val startsWithExpressionLiteral = startsWith(new Column("col1"), Literal.ofString("t%"))
val expOutputVectorLiteral =
booleanVector(Seq[BooleanJ](false, false, true, false, null, null, false))
checkBooleanVectors(new DefaultExpressionEvaluator(
schema, startsWithExpressionLiteral, BooleanType.BOOLEAN).eval(input), expOutputVectorLiteral)

val startsWithExpressionNullLiteral = startsWith(new Column("col1"), Literal.ofString(null))
val allNullVector =
booleanVector(Seq[BooleanJ](null, null, null, null, null, null, null))
checkBooleanVectors(new DefaultExpressionEvaluator(
schema, startsWithExpressionNullLiteral, BooleanType.BOOLEAN).eval(input), allNullVector)

// Two literal expressions on both sides
val startsWithExpressionAlwaysTrue = startsWith(Literal.ofString("ABC"), Literal.ofString("A"))
val allTrueVector = booleanVector(Seq[BooleanJ](true, true, true, true, true, true, true))
checkBooleanVectors(new DefaultExpressionEvaluator(
schema, startsWithExpressionAlwaysTrue, BooleanType.BOOLEAN).eval(input), allTrueVector)

val startsWithExpressionAlwaysFalse =
startsWith(Literal.ofString("ABC"), Literal.ofString("_B%"))
val allFalseVector =
booleanVector(Seq[BooleanJ](false, false, false, false, false, false, false))
checkBooleanVectors(new DefaultExpressionEvaluator(
schema, startsWithExpressionAlwaysFalse, BooleanType.BOOLEAN).eval(input), allFalseVector)

// scalastyle:off nonascii
val colUnicode = stringVector(Seq[String]("中文", "中", "文"))
val schemaUnicode = new StructType().add("col", StringType.STRING)
val inputUnicode = new DefaultColumnarBatch(colUnicode.getSize,
schemaUnicode, Array(colUnicode))
val startsWithExpressionUnicode = startsWith(new Column("col"), Literal.ofString("中"))
val expOutputVectorLiteralUnicode = booleanVector(Seq[BooleanJ](true, true, false))
checkBooleanVectors(new DefaultExpressionEvaluator(schemaUnicode,
startsWithExpressionUnicode,
BooleanType.BOOLEAN).eval(inputUnicode), expOutputVectorLiteralUnicode)

val startsWithExpressionExpression = startsWith(new Column("col1"), new Column("col2"))
huan233usc marked this conversation as resolved.
Show resolved Hide resolved
val e = intercept[UnsupportedOperationException] {
new DefaultExpressionEvaluator(
schema, startsWithExpressionExpression, BooleanType.BOOLEAN).eval(input)
}
assert(e.getMessage.contains("'STARTS_WITH' expects literal as the second input"))


def checkUnsupportedTypes(colType: DataType, literalType: DataType): Unit = {
val schema = new StructType()
.add("col", colType)
val expr = startsWith(new Column("col"), Literal.ofNull(literalType))
val input = new DefaultColumnarBatch(5, schema,
Array(testColumnVector(5, colType)))

val e = intercept[UnsupportedOperationException] {
new DefaultExpressionEvaluator(
schema, expr, BooleanType.BOOLEAN).eval(input)
}
assert(e.getMessage.contains("'STARTS_WITH' expects STRING type inputs"))
}

checkUnsupportedTypes(BooleanType.BOOLEAN, BooleanType.BOOLEAN)
checkUnsupportedTypes(LongType.LONG, LongType.LONG)
checkUnsupportedTypes(IntegerType.INTEGER, IntegerType.INTEGER)
checkUnsupportedTypes(StringType.STRING, BooleanType.BOOLEAN)
checkUnsupportedTypes(StringType.STRING, IntegerType.INTEGER)
checkUnsupportedTypes(StringType.STRING, LongType.LONG)
}

test("evaluate expression: comparators (=, <, <=, >, >=)") {
val ASCII_MAX_CHARACTER = '\u007F'
val UTF8_MAX_CHARACTER = new String(Character.toChars(Character.MAX_CODE_POINT))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ trait ExpressionSuiteBase extends TestUtils with DefaultVectorTestUtils {
new Predicate("like", children.asJava)
}

protected def startsWith(left: Expression, right: Expression): Predicate = {
new Predicate("starts_with", left, right)
}

protected def comparator(symbol: String, left: Expression, right: Expression): Predicate = {
new Predicate(symbol, left, right)
}
Expand Down
Loading