Skip to content

Commit

Permalink
[CALCITE-6550] Improve SQL function overloading
Browse files Browse the repository at this point in the history
* RexImpTable can better handle collisions in the scalar function map
* Return the implementor if only one implementor is found for an operator key
* If there are multiple implementors for an operator key, look for one with the exact same operator
* An operator key is the operator name and kind
  • Loading branch information
normanj-bitquill committed Sep 10, 2024
1 parent 83d2dd9 commit cc6ec96
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
import org.apache.calcite.rex.RexPatternFieldRef;
import org.apache.calcite.rex.RexWindowExclusion;
import org.apache.calcite.runtime.FlatLists;
import org.apache.calcite.runtime.ImmutablePairList;
import org.apache.calcite.runtime.PairList;
import org.apache.calcite.runtime.SqlFunctions;
import org.apache.calcite.schema.FunctionContext;
import org.apache.calcite.schema.ImplementableAggFunction;
Expand Down Expand Up @@ -543,15 +545,20 @@ public class RexImpTable {
public static final MemberExpression BOXED_TRUE_EXPR =
Expressions.field(null, Boolean.class, "TRUE");

private final ImmutableMap<SqlOperator, RexCallImplementor> map;
private final ImmutableMap<SqlOperator, ImmutablePairList<SqlOperator, RexCallImplementor>> map;
private final ImmutableMap<SqlAggFunction, Supplier<? extends AggImplementor>> aggMap;
private final ImmutableMap<SqlAggFunction, Supplier<? extends WinAggImplementor>> winAggMap;
private final ImmutableMap<SqlMatchFunction, Supplier<? extends MatchImplementor>> matchMap;
private final ImmutableMap<SqlOperator, Supplier<? extends TableFunctionCallImplementor>>
tvfImplementorMap;

private RexImpTable(Builder builder) {
this.map = ImmutableMap.copyOf(builder.map);
final ImmutableMap.Builder<SqlOperator, ImmutablePairList<SqlOperator, RexCallImplementor>>
mapBuilder = ImmutableMap.builder();
builder.map.forEach((k, v) -> {
mapBuilder.put(k, v.immutable());
});
this.map = ImmutableMap.copyOf(mapBuilder.build());
this.aggMap = ImmutableMap.copyOf(builder.aggMap);
this.winAggMap = ImmutableMap.copyOf(builder.winAggMap);
this.matchMap = ImmutableMap.copyOf(builder.matchMap);
Expand Down Expand Up @@ -848,7 +855,6 @@ void populate1() {
new SafeArithmeticImplementor(BuiltInMethod.SAFE_SUBTRACT.method));

define(PI, new PiImplementor());
populate2();
}

/** Second step of population. */
Expand Down Expand Up @@ -1277,7 +1283,8 @@ private static <T> Supplier<T> constructorSupplier(Class<T> klass) {

/** Holds intermediate state from which a RexImpTable can be constructed. */
private static class Builder extends AbstractBuilder {
private final Map<SqlOperator, RexCallImplementor> map = new HashMap<>();
private final Map<SqlOperator, PairList<SqlOperator, RexCallImplementor>> map =
new HashMap<>();
private final Map<SqlAggFunction, Supplier<? extends AggImplementor>> aggMap =
new HashMap<>();
private final Map<SqlAggFunction, Supplier<? extends WinAggImplementor>> winAggMap =
Expand All @@ -1288,13 +1295,28 @@ private static class Builder extends AbstractBuilder {
tvfImplementorMap = new HashMap<>();

@Override protected RexCallImplementor get(SqlOperator operator) {
return requireNonNull(map.get(operator),
() -> "no implementor for " + operator);
final PairList<SqlOperator, RexCallImplementor> implementors =
requireNonNull(map.get(operator));
if (implementors.size() == 1) {
return implementors.get(0).getValue();
} else {
for (Map.Entry<SqlOperator, RexCallImplementor> entry : implementors) {
if (operator == entry.getKey()) {
return entry.getValue();
}
}
throw new NullPointerException();
}
}

@Override <T extends RexCallImplementor> T define(SqlOperator operator,
T implementor) {
map.put(operator, requireNonNull(implementor, "implementor"));
if (map.containsKey(operator)) {
map.get(operator).add(operator, implementor);
} else {
map.put(operator, PairList.<SqlOperator, RexCallImplementor>builder()
.add(operator, implementor).build());
}
return implementor;
}

Expand Down Expand Up @@ -1369,9 +1391,27 @@ private static RexCallImplementor wrapAsRexCallImplementor(
((ImplementableFunction) udf).getImplementor();
return wrapAsRexCallImplementor(implementor);
} else if (operator instanceof SqlTypeConstructorFunction) {
return map.get(SqlStdOperatorTable.ROW);
final ImmutablePairList<SqlOperator, RexCallImplementor> implementors =
map.get(SqlStdOperatorTable.ROW);
if (implementors != null && implementors.size() == 1) {
return implementors.get(0).getValue();
}
} else {
final ImmutablePairList<SqlOperator, RexCallImplementor> implementors =
map.get(operator);
if (implementors != null) {
if (implementors.size() == 1) {
return implementors.get(0).getValue();
} else {
for (Map.Entry<SqlOperator, RexCallImplementor> entry : implementors) {
if (operator == entry.getKey()) {
return entry.getValue();
}
}
}
}
}
return map.get(operator);
return null;
}

public @Nullable AggImplementor get(final SqlAggFunction aggregation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public class SqlBasicFunction extends SqlFunction {
* @param category Categorization for function
* @param monotonicityInference Strategy to infer monotonicity of a call
*/
protected SqlBasicFunction(String name, SqlKind kind, SqlSyntax syntax,
private SqlBasicFunction(String name, SqlKind kind, SqlSyntax syntax,
boolean deterministic, SqlReturnTypeInference returnTypeInference,
@Nullable SqlOperandTypeInference operandTypeInference,
SqlOperandHandler operandHandler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeTransforms;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlMonotonicity;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.util.Litmus;
Expand Down Expand Up @@ -577,10 +576,8 @@ static RelDataType deriveTypeSplit(SqlOperatorBinding operatorBinding,
* {@code rep} and returns modified value. */
@LibraryOperator(libraries = {REDSHIFT})
public static final SqlFunction REGEXP_REPLACE_2 =
new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.STRING_STRING, 0,
SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
OperandTypes.STRING_STRING, SqlFunctionCategory.STRING);

/** The "REGEXP_REPLACE(value, regexp, rep)"
* function. Replaces all substrings of value that match regexp with
Expand All @@ -596,11 +593,10 @@ static RelDataType deriveTypeSplit(SqlOperatorBinding operatorBinding,
* pos. */
@LibraryOperator(libraries = {MYSQL, ORACLE, REDSHIFT})
public static final SqlFunction REGEXP_REPLACE_4 =
new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
SqlTypeFamily.STRING, SqlTypeFamily.INTEGER),
0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING,
SqlTypeFamily.INTEGER),
SqlFunctionCategory.STRING);

/** The "REGEXP_REPLACE(value, regexp, rep, pos, [ occurrence | matchType ])"
* function. Replaces all substrings of value that match regexp with
Expand All @@ -609,27 +605,24 @@ static RelDataType deriveTypeSplit(SqlOperatorBinding operatorBinding,
* is a string of flags to apply to the search. */
@LibraryOperator(libraries = {MYSQL, REDSHIFT})
public static final SqlFunction REGEXP_REPLACE_5 =
new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
OperandHandlers.DEFAULT,
SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
OperandTypes.or(
OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.STRING)),
0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlFunctionCategory.STRING);

/** The "REGEXP_REPLACE(value, regexp, rep, pos, occurrence)"
* function. Replaces all substrings of value that match regexp with
* {@code rep} and returns modified value. Start searching value from character position
* pos. Replace only the occurrence match or all matches if occurrence is 0. */
@LibraryOperator(libraries = {ORACLE})
public static final SqlFunction REGEXP_REPLACE_5_ORACLE =
new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
SqlFunctionCategory.STRING);

/** The "REGEXP_REPLACE(value, regexp, rep, pos, occurrence, matchType)"
* function. Replaces all substrings of value that match regexp with
Expand All @@ -638,41 +631,34 @@ static RelDataType deriveTypeSplit(SqlOperatorBinding operatorBinding,
* is a string of flags to apply to the search. */
@LibraryOperator(libraries = {MYSQL, ORACLE, REDSHIFT})
public static final SqlFunction REGEXP_REPLACE_6 =
new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.STRING),
0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING,
SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.STRING),
SqlFunctionCategory.STRING);

/** The "REGEXP_REPLACE(value, regexp, rep)"
* function. Replaces all substrings of value that match regexp with
* {@code rep} and returns modified value. */
@LibraryOperator(libraries = {BIG_QUERY})
public static final SqlFunction REGEXP_REPLACE_BIG_QUERY_3 =
new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.STRING_STRING_STRING, 0,
SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
OperandTypes.STRING_STRING_STRING, SqlFunctionCategory.STRING);

/** The "REGEXP_REPLACE(value, regexp, rep)"
* function. Replaces all substrings of value that match regexp with
* {@code rep} and returns modified value. */
@LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = REDSHIFT)
public static final SqlFunction REGEXP_REPLACE_PG_3 =
new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.STRING_STRING_STRING, 0,
SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
OperandTypes.STRING_STRING_STRING, SqlFunctionCategory.STRING);

/** The "REGEXP_REPLACE(value, regexp, rep, flags)"
* function. Replaces all substrings of value that match regexp with
* {@code rep} and returns modified value. flags are applied to the search. */
@LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = REDSHIFT)
public static final SqlFunction REGEXP_REPLACE_PG_4 =
new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.STRING_STRING_STRING_STRING, 0,
SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
OperandTypes.STRING_STRING_STRING_STRING, SqlFunctionCategory.STRING);

/** The "REGEXP_SUBSTR(value, regexp[, position[, occurrence]])" function.
* Returns the substring in value that matches the regexp. Returns NULL if there is no match. */
Expand Down Expand Up @@ -1875,10 +1861,8 @@ private static RelDataType deriveTypeMapFromEntries(SqlOperatorBinding opBinding
* converts {@code timestamp} to string according to the given {@code format}. */
@LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT})
public static final SqlFunction TO_CHAR_PG =
new SqlBasicFunction("TO_CHAR", SqlKind.OTHER_FUNCTION,
SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.TIMESTAMP_STRING, 0,
SqlFunctionCategory.TIMEDATE, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlBasicFunction.create("TO_CHAR", ReturnTypes.VARCHAR_NULLABLE,
OperandTypes.TIMESTAMP_STRING, SqlFunctionCategory.TIMEDATE);

/** The "TO_DATE(string1, string2)" function; casts string1
* to a DATE using the format specified in string2. */
Expand All @@ -1893,10 +1877,8 @@ private static RelDataType deriveTypeMapFromEntries(SqlOperatorBinding opBinding
* to a DATE using the format specified in string2. */
@LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT})
public static final SqlFunction TO_DATE_PG =
new SqlBasicFunction("TO_DATE", SqlKind.OTHER_FUNCTION,
SqlSyntax.FUNCTION, true, ReturnTypes.DATE_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.STRING_STRING, 0,
SqlFunctionCategory.TIMEDATE, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlBasicFunction.create("TO_DATE", ReturnTypes.DATE_NULLABLE,
OperandTypes.STRING_STRING, SqlFunctionCategory.TIMEDATE);

/** The "TO_TIMESTAMP(string1, string2)" function; casts string1
* to a TIMESTAMP using the format specified in string2. */
Expand All @@ -1911,10 +1893,8 @@ private static RelDataType deriveTypeMapFromEntries(SqlOperatorBinding opBinding
* to a TIMESTAMP using the format specified in string2. */
@LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT})
public static final SqlFunction TO_TIMESTAMP_PG =
new SqlBasicFunction("TO_TIMESTAMP", SqlKind.OTHER_FUNCTION,
SqlSyntax.FUNCTION, true, ReturnTypes.TIMESTAMP_TZ_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.STRING_STRING, 0,
SqlFunctionCategory.TIMEDATE, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlBasicFunction.create("TO_TIMESTAMP", ReturnTypes.TIMESTAMP_TZ_NULLABLE,
OperandTypes.STRING_STRING, SqlFunctionCategory.TIMEDATE);

/**
* The "PARSE_TIME(string, string)" function (BigQuery);
Expand Down Expand Up @@ -2512,10 +2492,8 @@ private static RelDataType deriveTypeMapFromEntries(SqlOperatorBinding opBinding
* to base numeric1.*/
@LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT})
public static final SqlFunction LOG_POSTGRES =
new SqlBasicFunction("LOG", SqlKind.LOG,
SqlSyntax.FUNCTION, true, ReturnTypes.DOUBLE_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.NUMERIC_OPTIONAL_NUMERIC, 0,
SqlFunctionCategory.NUMERIC, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
SqlBasicFunction.create("LOG", ReturnTypes.DOUBLE_NULLABLE,
OperandTypes.NUMERIC_OPTIONAL_NUMERIC, SqlFunctionCategory.NUMERIC);

/** The "LOG2(numeric)" function. Returns the base 2 logarithm of numeric. */
@LibraryOperator(libraries = {MYSQL, SPARK})
Expand Down

0 comments on commit cc6ec96

Please sign in to comment.