[SPARK-40760][SQL] Migrate type check failures of interval expressions onto error classes

### What changes were proposed in this pull request?
In this PR, I propose to add new error sub-classes of the error class `DATATYPE_MISMATCH` and to use them for type check failures of some interval expressions.
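
As a rough sketch of the migration pattern (illustrative only; the real changes are in the diff below, and the helper `checkEndpointCount` is hypothetical), a type check that used to return a free-form `TypeCheckFailure` now returns a structured `DataTypeMismatch` whose sub-class and parameters are substituted into the template defined in `error-classes.json`:
```
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}

// Hypothetical helper (not part of this PR) showing the shape of the migration.
// Before: TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")
// After:  a DATATYPE_MISMATCH.WRONG_NUM_ENDPOINTS error with structured parameters.
def checkEndpointCount(numEndpoints: Int): TypeCheckResult = {
  if (numEndpoints >= 2) {
    TypeCheckSuccess
  } else {
    DataTypeMismatch(
      errorSubClass = "WRONG_NUM_ENDPOINTS",
      messageParameters = Map("actualNumber" -> numEndpoints.toString))
  }
}
```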

### Why are the changes needed?
Migration onto error classes unifies Spark SQL error messages and improves the searchability of errors.

### Does this PR introduce _any_ user-facing change?
Yes. The PR changes user-facing error messages.
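
For example, per the updated `ExpressionTypeCheckingSuite` in this PR, an aggregate over a boolean column now fails with the error class `DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE` instead of the old free-form "function sum requires numeric or interval types" message. A minimal sketch of how that surfaces to callers (assuming an active `SparkSession` named `spark`; the query and handling below are illustrative and not taken from the PR):
```
import org.apache.spark.sql.AnalysisException

try {
  // A SUM over a BOOLEAN column is a data type mismatch.
  spark.sql("SELECT sum(c) FROM VALUES (true), (false) AS t(c)").collect()
} catch {
  case e: AnalysisException =>
    // The failure is now keyed by an error class; message parameters such as
    // inputType = "BOOLEAN" and requiredType = "NUMERIC" or "ANSI INTERVAL"
    // back the rendered message (see the test expectations in the diff below).
    assert(e.getErrorClass == "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
}
```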

### How was this patch tested?
By running the affected test suites:
```
$ build/sbt "test:testOnly *AnalysisSuite"
$ build/sbt "test:testOnly *ExpressionTypeCheckingSuite"
$ build/sbt "test:testOnly *ApproxCountDistinctForIntervalsSuite"
```

Closes apache#38237 from MaxGekk/type-check-fails-interval-exprs.

Authored-by: Max Gekk <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
MaxGekk authored and SandishKumarHN committed Dec 12, 2022
1 parent 7fe8d46 commit 1ccb642
Showing 9 changed files with 123 additions and 43 deletions.
5 changes: 5 additions & 0 deletions core/src/main/resources/error/error-classes.json
@@ -263,6 +263,11 @@
"The <exprName> must be between <valueRange> (current value = <currentValue>)"
]
},
"WRONG_NUM_ENDPOINTS" : {
"message" : [
"The number of endpoints must be >= 2 to construct intervals but the actual number is <actualNumber>."
]
},
"WRONG_NUM_PARAMS" : {
"message" : [
"The <functionName> requires <expectedNum> parameters but the actual number is <actualNum>."
@@ -21,10 +21,11 @@ import java.util

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow}
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform

@@ -49,7 +50,10 @@ case class ApproxCountDistinctForIntervals(
relativeSD: Double = 0.05,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes with BinaryLike[Expression] {
extends TypedImperativeAggregate[Array[Long]]
with ExpectsInputTypes
with BinaryLike[Expression]
with QueryErrorsBase {

def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
this(
@@ -77,19 +81,32 @@
if (defaultCheck.isFailure) {
defaultCheck
} else if (!endpointsExpression.foldable) {
TypeCheckFailure("The endpoints provided must be constant literals")
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "endpointsExpression",
"inputType" -> toSQLType(endpointsExpression.dataType)))
} else {
endpointsExpression.dataType match {
case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType |
_: AnsiIntervalType, _) =>
if (endpoints.length < 2) {
TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")
DataTypeMismatch(
errorSubClass = "WRONG_NUM_ENDPOINTS",
messageParameters = Map("actualNumber" -> endpoints.length.toString))
} else {
TypeCheckSuccess
}
case _ =>
TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
"interval year to month or interval day to second) type")
case inputType =>
val requiredElemTypes = toSQLType(TypeCollection(
NumericType, DateType, TimestampType, TimestampNTZType, AnsiIntervalType))
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> "2",
"requiredType" -> s"ARRAY OF $requiredElemTypes",
"inputSql" -> toSQLExpr(endpointsExpression),
"inputType" -> toSQLType(inputType)))
}
}
}
@@ -54,7 +54,7 @@ case class Average(
Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "average")
TypeUtils.checkForAnsiIntervalOrNumericType(child)

override def nullable: Boolean = true

@@ -67,7 +67,7 @@ case class Sum(
Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, prettyName)
TypeUtils.checkForAnsiIntervalOrNumericType(child)

final override val nodePatterns: Seq[TreePattern] = Seq(SUM)

@@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.types._

/**
* Functions to help with checking for valid data types and value comparison of various types.
*/
object TypeUtils {
object TypeUtils extends QueryErrorsBase {

def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
if (RowOrdering.isOrderable(dt)) {
@@ -70,13 +69,18 @@
}
}

def checkForAnsiIntervalOrNumericType(
dt: DataType, funcName: String): TypeCheckResult = dt match {
def checkForAnsiIntervalOrNumericType(input: Expression): TypeCheckResult = input.dataType match {
case _: AnsiIntervalType | NullType =>
TypeCheckResult.TypeCheckSuccess
case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
case other => TypeCheckResult.TypeCheckFailure(
s"function $funcName requires numeric or interval types, not ${other.catalogString}")
case other =>
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> "1",
"requiredType" -> Seq(NumericType, AnsiIntervalType).map(toSQLType).mkString(" or "),
"inputSql" -> toSQLExpr(input),
"inputType" -> toSQLType(other)))
}

def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
@@ -233,3 +233,12 @@ private[sql] abstract class DatetimeType extends AtomicType
* The interval type which conforms to the ANSI SQL standard.
*/
private[sql] abstract class AnsiIntervalType extends AtomicType

private[spark] object AnsiIntervalType extends AbstractDataType {
override private[sql] def simpleString: String = "ANSI interval"

override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[AnsiIntervalType]

override private[sql] def defaultConcreteType: DataType = DayTimeIntervalType()
}
@@ -1163,25 +1163,39 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

test("SPARK-38118: Func(wrong_type) in the HAVING clause should throw data mismatch error") {
assertAnalysisError(parsePlan(
s"""
|WITH t as (SELECT true c)
|SELECT t.c
|FROM t
|GROUP BY t.c
|HAVING mean(t.c) > 0d""".stripMargin),
Seq(s"cannot resolve 'mean(t.c)' due to data type mismatch"),
false)
assertAnalysisErrorClass(
inputPlan = parsePlan(
s"""
|WITH t as (SELECT true c)
|SELECT t.c
|FROM t
|GROUP BY t.c
|HAVING mean(t.c) > 0d""".stripMargin),
expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
expectedMessageParameters = Map(
"sqlExpr" -> "\"mean(c)\"",
"paramIndex" -> "1",
"inputSql" -> "\"c\"",
"inputType" -> "\"BOOLEAN\"",
"requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
caseSensitive = false)

assertAnalysisError(parsePlan(
s"""
|WITH t as (SELECT true c, false d)
|SELECT (t.c AND t.d) c
|FROM t
|GROUP BY t.c, t.d
|HAVING mean(c) > 0d""".stripMargin),
Seq(s"cannot resolve 'mean(t.c)' due to data type mismatch"),
false)
assertAnalysisErrorClass(
inputPlan = parsePlan(
s"""
|WITH t as (SELECT true c, false d)
|SELECT (t.c AND t.d) c
|FROM t
|GROUP BY t.c, t.d
|HAVING mean(c) > 0d""".stripMargin),
expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
expectedMessageParameters = Map(
"sqlExpr" -> "\"mean(c)\"",
"paramIndex" -> "1",
"inputSql" -> "\"c\"",
"inputType" -> "\"BOOLEAN\"",
"requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
caseSensitive = false)

assertAnalysisErrorClass(
inputPlan = parsePlan(
@@ -396,9 +396,29 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer
"dataType" -> "\"MAP<STRING, BIGINT>\""
)
)
assertError(Sum($"booleanField"), "function sum requires numeric or interval types")
assertError(Average($"booleanField"),
"function average requires numeric or interval types")

checkError(
exception = intercept[AnalysisException] {
assertSuccess(Sum($"booleanField"))
},
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
"sqlExpr" -> "\"sum(booleanField)\"",
"paramIndex" -> "1",
"inputSql" -> "\"booleanField\"",
"inputType" -> "\"BOOLEAN\"",
"requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""))
checkError(
exception = intercept[AnalysisException] {
assertSuccess(Average($"booleanField"))
},
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
"sqlExpr" -> "\"avg(booleanField)\"",
"paramIndex" -> "1",
"inputSql" -> "\"booleanField\"",
"inputType" -> "\"BOOLEAN\"",
"requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""))
}

test("check types for others") {
@@ -22,7 +22,7 @@ import java.time.LocalDateTime

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, CreateArray, Literal, SpecificInternalRow}
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils}
import org.apache.spark.sql.types._
@@ -48,20 +48,31 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
AttributeReference("a", DoubleType)(),
endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)())))
assert(wrongEndpoints.checkInputDataTypes() ==
TypeCheckFailure("The endpoints provided must be constant literals"))
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "endpointsExpression",
"inputType" -> "\"ARRAY<DOUBLE>\"")))

wrongEndpoints = ApproxCountDistinctForIntervals(
AttributeReference("a", DoubleType)(),
endpointsExpression = CreateArray(Array(10L).map(Literal(_))))
assert(wrongEndpoints.checkInputDataTypes() ==
TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals"))
DataTypeMismatch("WRONG_NUM_ENDPOINTS", Map("actualNumber" -> "1")))

wrongEndpoints = ApproxCountDistinctForIntervals(
AttributeReference("a", DoubleType)(),
endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
// scalastyle:off line.size.limit
assert(wrongEndpoints.checkInputDataTypes() ==
TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
"interval year to month or interval day to second) type"))
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> "2",
"requiredType" -> "ARRAY OF (\"NUMERIC\" or \"DATE\" or \"TIMESTAMP\" or \"TIMESTAMP_NTZ\" or \"ANSI INTERVAL\")",
"inputSql" -> "\"array(foobar)\"",
"inputType" -> "\"ARRAY<STRING>\"")))
// scalastyle:on line.size.limit
}

/** Create an ApproxCountDistinctForIntervals instance and an input and output buffer. */