diff --git a/core/trino-main/src/main/java/io/trino/json/ir/SqlJsonLiteralConverter.java b/core/trino-main/src/main/java/io/trino/json/ir/SqlJsonLiteralConverter.java index d665ca744ada..9b12b83b3eeb 100644 --- a/core/trino-main/src/main/java/io/trino/json/ir/SqlJsonLiteralConverter.java +++ b/core/trino-main/src/main/java/io/trino/json/ir/SqlJsonLiteralConverter.java @@ -80,43 +80,35 @@ public static Optional getTextTypedValue(JsonNode jsonNode) public static Optional getNumericTypedValue(JsonNode jsonNode) { if (jsonNode.getNodeType() == JsonNodeType.NUMBER) { - if (jsonNode instanceof BigIntegerNode) { - if (jsonNode.canConvertToInt()) { - return Optional.of(new TypedValue(INTEGER, jsonNode.longValue())); + return switch (jsonNode) { + case BigIntegerNode _ -> { + if (jsonNode.canConvertToInt()) { + yield Optional.of(new TypedValue(INTEGER, jsonNode.longValue())); + } + if (jsonNode.canConvertToLong()) { + yield Optional.of(new TypedValue(BIGINT, jsonNode.longValue())); + } + throw new JsonLiteralConversionException(jsonNode, "value too big"); } - if (jsonNode.canConvertToLong()) { - return Optional.of(new TypedValue(BIGINT, jsonNode.longValue())); + case DecimalNode _ -> { + BigDecimal jsonDecimal = jsonNode.decimalValue(); + int precision = jsonDecimal.precision(); + if (precision > MAX_PRECISION) { + throw new JsonLiteralConversionException(jsonNode, "precision too big"); + } + int scale = jsonDecimal.scale(); + DecimalType decimalType = createDecimalType(precision, scale); + Object value = decimalType.isShort() ? encodeShortScaledValue(jsonDecimal, scale) : encodeScaledValue(jsonDecimal, scale); + yield Optional.of(TypedValue.fromValueAsObject(decimalType, value)); } - throw new JsonLiteralConversionException(jsonNode, "value too big"); - } - if (jsonNode instanceof DecimalNode) { - BigDecimal jsonDecimal = jsonNode.decimalValue(); - int precision = jsonDecimal.precision(); - if (precision > MAX_PRECISION) { - throw new JsonLiteralConversionException(jsonNode, "precision too big"); - } - int scale = jsonDecimal.scale(); - DecimalType decimalType = createDecimalType(precision, scale); - Object value = decimalType.isShort() ? encodeShortScaledValue(jsonDecimal, scale) : encodeScaledValue(jsonDecimal, scale); - return Optional.of(TypedValue.fromValueAsObject(decimalType, value)); - } - if (jsonNode instanceof DoubleNode) { - return Optional.of(new TypedValue(DOUBLE, jsonNode.doubleValue())); - } - if (jsonNode instanceof FloatNode) { - return Optional.of(new TypedValue(REAL, floatToRawIntBits(jsonNode.floatValue()))); - } - if (jsonNode instanceof IntNode) { - return Optional.of(new TypedValue(INTEGER, jsonNode.longValue())); - } - if (jsonNode instanceof LongNode) { - return Optional.of(new TypedValue(BIGINT, jsonNode.longValue())); - } - if (jsonNode instanceof ShortNode) { - return Optional.of(new TypedValue(SMALLINT, jsonNode.longValue())); - } + case DoubleNode _ -> Optional.of(new TypedValue(DOUBLE, jsonNode.doubleValue())); + case FloatNode _ -> Optional.of(new TypedValue(REAL, floatToRawIntBits(jsonNode.floatValue()))); + case IntNode _ -> Optional.of(new TypedValue(INTEGER, jsonNode.longValue())); + case LongNode _ -> Optional.of(new TypedValue(BIGINT, jsonNode.longValue())); + case ShortNode _ -> Optional.of(new TypedValue(SMALLINT, jsonNode.longValue())); + default -> Optional.empty(); + }; } - return Optional.empty(); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java index 5e7666d97cde..093a600a9864 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java @@ -1241,16 +1241,12 @@ boolean isPrimitiveType() public BytecodeExpression initialValueExpression() { - if (initialValue == null) { - return defaultValue(type); - } - if (initialValue instanceof Number) { - return constantNumber((Number) initialValue); - } - if (initialValue instanceof Boolean) { - return constantBoolean((boolean) initialValue); - } - throw new IllegalArgumentException("Unsupported initial value type: " + initialValue.getClass()); + return switch (initialValue) { + case null -> defaultValue(type); + case Number number -> constantNumber(number); + case Boolean _ -> constantBoolean((boolean) initialValue); + default -> throw new IllegalArgumentException("Unsupported initial value type: " + initialValue.getClass()); + }; } } } diff --git a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java index ffe1882dc0b9..50036f6391be 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java +++ b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java @@ -82,22 +82,19 @@ private void validateAudience(Claims claims) } Object tokenAudience = claims.get(AUDIENCE); - if (tokenAudience == null) { - throw new InvalidClaimException(format("Expected %s claim to be: %s, but was not present in the JWT claims.", AUDIENCE, requiredAudience.get())); - } - - if (tokenAudience instanceof String) { - if (!requiredAudience.get().equals((String) tokenAudience)) { - throw new InvalidClaimException(format("Invalid Audience: %s. Allowed audiences: %s", tokenAudience, requiredAudience.get())); + switch (tokenAudience) { + case String value -> { + if (!requiredAudience.get().equals(value)) { + throw new InvalidClaimException(format("Invalid Audience: %s. Allowed audiences: %s", tokenAudience, requiredAudience.get())); + } } - } - else if (tokenAudience instanceof Collection) { - if (((Collection) tokenAudience).stream().map(String.class::cast).noneMatch(aud -> requiredAudience.get().equals(aud))) { - throw new InvalidClaimException(format("Invalid Audience: %s. Allowed audiences: %s", tokenAudience, requiredAudience.get())); + case Collection collection -> { + if (collection.stream().noneMatch(aud -> requiredAudience.get().equals(aud))) { + throw new InvalidClaimException(format("Invalid Audience: %s. Allowed audiences: %s", tokenAudience, requiredAudience.get())); + } } - } - else { - throw new InvalidClaimException(format("Invalid Audience: %s", tokenAudience)); + case null -> throw new InvalidClaimException(format("Expected %s claim to be: %s, but was not present in the JWT claims.", AUDIENCE, requiredAudience.get())); + default -> throw new InvalidClaimException(format("Invalid Audience: %s", tokenAudience)); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 6f7cc0422026..7f119f553f5e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -1903,19 +1903,12 @@ private ArgumentsAnalysis analyzeArguments(List argumentS private ArgumentAnalysis analyzeArgument(ArgumentSpecification argumentSpecification, TableFunctionArgument argument, Optional scope) { - String actualType; - if (argument.getValue() instanceof TableFunctionTableArgument) { - actualType = "table"; - } - else if (argument.getValue() instanceof TableFunctionDescriptorArgument) { - actualType = "descriptor"; - } - else if (argument.getValue() instanceof Expression) { - actualType = "expression"; - } - else { - throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "Unexpected table function argument type: %s", argument.getClass().getSimpleName()); - } + String actualType = switch (argument.getValue()) { + case TableFunctionTableArgument _ -> "table"; + case TableFunctionDescriptorArgument _ -> "descriptor"; + case Expression _ -> "expression"; + default -> throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "Unexpected table function argument type: %s", argument.getClass().getSimpleName()); + }; if (argumentSpecification instanceof TableArgumentSpecification) { if (!(argument.getValue() instanceof TableFunctionTableArgument)) { @@ -4023,66 +4016,66 @@ private void analyzeJsonTableColumns( JsonTable jsonTable) { for (JsonTableColumnDefinition column : columns) { - if (column instanceof OrdinalityColumn ordinalityColumn) { - String name = ordinalityColumn.getName().getCanonicalValue(); - if (!uniqueNames.add(name)) { - throw semanticException(DUPLICATE_COLUMN_OR_PATH_NAME, ordinalityColumn.getName(), "All column and path names in JSON_TABLE invocation must be unique"); - } - outputFields.add(Field.newUnqualified(name, BIGINT)); - orderedOutputColumns.add(NodeRef.of(ordinalityColumn)); - } - else if (column instanceof ValueColumn valueColumn) { - String name = valueColumn.getName().getCanonicalValue(); - if (!uniqueNames.add(name)) { - throw semanticException(DUPLICATE_COLUMN_OR_PATH_NAME, valueColumn.getName(), "All column and path names in JSON_TABLE invocation must be unique"); + switch (column) { + case OrdinalityColumn ordinalityColumn -> { + String name = ordinalityColumn.getName().getCanonicalValue(); + if (!uniqueNames.add(name)) { + throw semanticException(DUPLICATE_COLUMN_OR_PATH_NAME, ordinalityColumn.getName(), "All column and path names in JSON_TABLE invocation must be unique"); + } + outputFields.add(Field.newUnqualified(name, BIGINT)); + orderedOutputColumns.add(NodeRef.of(ordinalityColumn)); } - valueColumn.getEmptyDefault().ifPresent(expression -> verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, expression, "default expression for JSON_TABLE column")); - valueColumn.getErrorDefault().ifPresent(expression -> verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, expression, "default expression for JSON_TABLE column")); - JsonPathAnalysis pathAnalysis = valueColumn.getJsonPath() - .map(this::analyzeJsonPath) - .orElseGet(() -> analyzeImplicitJsonPath(getImplicitJsonPath(name), valueColumn.getLocation())); - analysis.setJsonPathAnalysis(valueColumn, pathAnalysis); - TypeAndAnalysis typeAndAnalysis = analyzeJsonValueExpression( - valueColumn, - pathAnalysis, - session, - plannerContext, - statementAnalyzerFactory, - accessControl, - enclosingScope, - analysis, - warningCollector, - correlationSupport); - // default values can contain subqueries - the subqueries are recorded under the enclosing JsonTable node - analysis.recordSubqueries(jsonTable, typeAndAnalysis.analysis()); - outputFields.add(Field.newUnqualified(name, typeAndAnalysis.type())); - orderedOutputColumns.add(NodeRef.of(valueColumn)); - } - else if (column instanceof QueryColumn queryColumn) { - String name = queryColumn.getName().getCanonicalValue(); - if (!uniqueNames.add(name)) { - throw semanticException(DUPLICATE_COLUMN_OR_PATH_NAME, queryColumn.getName(), "All column and path names in JSON_TABLE invocation must be unique"); + case ValueColumn valueColumn -> { + String name = valueColumn.getName().getCanonicalValue(); + if (!uniqueNames.add(name)) { + throw semanticException(DUPLICATE_COLUMN_OR_PATH_NAME, valueColumn.getName(), "All column and path names in JSON_TABLE invocation must be unique"); + } + valueColumn.getEmptyDefault().ifPresent(expression -> verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, expression, "default expression for JSON_TABLE column")); + valueColumn.getErrorDefault().ifPresent(expression -> verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, expression, "default expression for JSON_TABLE column")); + JsonPathAnalysis pathAnalysis = valueColumn.getJsonPath() + .map(this::analyzeJsonPath) + .orElseGet(() -> analyzeImplicitJsonPath(getImplicitJsonPath(name), valueColumn.getLocation())); + analysis.setJsonPathAnalysis(valueColumn, pathAnalysis); + TypeAndAnalysis typeAndAnalysis = analyzeJsonValueExpression( + valueColumn, + pathAnalysis, + session, + plannerContext, + statementAnalyzerFactory, + accessControl, + enclosingScope, + analysis, + warningCollector, + correlationSupport); + // default values can contain subqueries - the subqueries are recorded under the enclosing JsonTable node + analysis.recordSubqueries(jsonTable, typeAndAnalysis.analysis()); + outputFields.add(Field.newUnqualified(name, typeAndAnalysis.type())); + orderedOutputColumns.add(NodeRef.of(valueColumn)); } - JsonPathAnalysis pathAnalysis = queryColumn.getJsonPath() - .map(this::analyzeJsonPath) - .orElseGet(() -> analyzeImplicitJsonPath(getImplicitJsonPath(name), queryColumn.getLocation())); - analysis.setJsonPathAnalysis(queryColumn, pathAnalysis); - Type type = analyzeJsonQueryExpression(queryColumn, session, plannerContext, statementAnalyzerFactory, accessControl, enclosingScope, analysis, warningCollector); - outputFields.add(Field.newUnqualified(name, type)); - orderedOutputColumns.add(NodeRef.of(queryColumn)); - } - else if (column instanceof NestedColumns nestedColumns) { - nestedColumns.getPathName().ifPresent(name -> { - if (!uniqueNames.add(name.getCanonicalValue())) { - throw semanticException(DUPLICATE_COLUMN_OR_PATH_NAME, name, "All column and path names in JSON_TABLE invocation must be unique"); + case QueryColumn queryColumn -> { + String name = queryColumn.getName().getCanonicalValue(); + if (!uniqueNames.add(name)) { + throw semanticException(DUPLICATE_COLUMN_OR_PATH_NAME, queryColumn.getName(), "All column and path names in JSON_TABLE invocation must be unique"); } - }); - JsonPathAnalysis pathAnalysis = analyzeJsonPath(nestedColumns.getJsonPath()); - analysis.setJsonPathAnalysis(nestedColumns, pathAnalysis); - analyzeJsonTableColumns(nestedColumns.getColumns(), uniqueNames, outputFields, orderedOutputColumns, enclosingScope, jsonTable); - } - else { - throw new IllegalArgumentException("unexpected type of JSON_TABLE column: " + column.getClass().getSimpleName()); + JsonPathAnalysis pathAnalysis = queryColumn.getJsonPath() + .map(this::analyzeJsonPath) + .orElseGet(() -> analyzeImplicitJsonPath(getImplicitJsonPath(name), queryColumn.getLocation())); + analysis.setJsonPathAnalysis(queryColumn, pathAnalysis); + Type type = analyzeJsonQueryExpression(queryColumn, session, plannerContext, statementAnalyzerFactory, accessControl, enclosingScope, analysis, warningCollector); + outputFields.add(Field.newUnqualified(name, type)); + orderedOutputColumns.add(NodeRef.of(queryColumn)); + } + case NestedColumns nestedColumns -> { + nestedColumns.getPathName().ifPresent(name -> { + if (!uniqueNames.add(name.getCanonicalValue())) { + throw semanticException(DUPLICATE_COLUMN_OR_PATH_NAME, name, "All column and path names in JSON_TABLE invocation must be unique"); + } + }); + JsonPathAnalysis pathAnalysis = analyzeJsonPath(nestedColumns.getJsonPath()); + analysis.setJsonPathAnalysis(nestedColumns, pathAnalysis); + analyzeJsonTableColumns(nestedColumns.getColumns(), uniqueNames, outputFields, orderedOutputColumns, enclosingScope, jsonTable); + } + default -> throw new IllegalArgumentException("unexpected type of JSON_TABLE column: " + column.getClass().getSimpleName()); } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationMerge.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationMerge.java index 262cf4879d54..b484c8950e67 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationMerge.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationMerge.java @@ -78,16 +78,12 @@ public Optional mergeFirstSource() addOriginalMappings(source, i, newMappingsBuilder); } - if (node instanceof UnionNode) { - return Optional.of(new UnionNode(node.getId(), newSources, newMappingsBuilder.build(), node.getOutputSymbols())); - } - if (node instanceof IntersectNode) { - return Optional.of(new IntersectNode(node.getId(), newSources, newMappingsBuilder.build(), node.getOutputSymbols(), mergedQuantifier.get())); - } - if (node instanceof ExceptNode) { - return Optional.of(new ExceptNode(node.getId(), newSources, newMappingsBuilder.build(), node.getOutputSymbols(), mergedQuantifier.get())); - } - throw new IllegalArgumentException("unexpected node type: " + node.getClass().getSimpleName()); + return switch (node) { + case UnionNode _ -> Optional.of(new UnionNode(node.getId(), newSources, newMappingsBuilder.build(), node.getOutputSymbols())); + case IntersectNode _ -> Optional.of(new IntersectNode(node.getId(), newSources, newMappingsBuilder.build(), node.getOutputSymbols(), mergedQuantifier.get())); + case ExceptNode _ -> Optional.of(new ExceptNode(node.getId(), newSources, newMappingsBuilder.build(), node.getOutputSymbols(), mergedQuantifier.get())); + default -> throw new IllegalArgumentException("unexpected node type: " + node.getClass().getSimpleName()); + }; } /** diff --git a/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java b/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java index cbcf5bd0f448..490a2df6413e 100644 --- a/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java +++ b/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java @@ -302,30 +302,20 @@ private static MaterializedRow convertToTestTypes(MaterializedRow trinoRow) List convertedValues = new ArrayList<>(); for (int field = 0; field < trinoRow.getFieldCount(); field++) { Object trinoValue = trinoRow.getField(field); - Object convertedValue; - if (trinoValue instanceof SqlDate) { - convertedValue = LocalDate.ofEpochDay(((SqlDate) trinoValue).getDays()); - } - else if (trinoValue instanceof SqlTime) { - convertedValue = DateTimeFormatter.ISO_LOCAL_TIME.parse(trinoValue.toString(), LocalTime::from); - } - else if (trinoValue instanceof SqlTimeWithTimeZone) { - long nanos = roundDiv(((SqlTimeWithTimeZone) trinoValue).getPicos(), PICOSECONDS_PER_NANOSECOND); - int offsetMinutes = ((SqlTimeWithTimeZone) trinoValue).getOffsetMinutes(); - convertedValue = OffsetTime.of(LocalTime.ofNanoOfDay(nanos), ZoneOffset.ofTotalSeconds(offsetMinutes * 60)); - } - else if (trinoValue instanceof SqlTimestamp) { - convertedValue = ((SqlTimestamp) trinoValue).toLocalDateTime(); - } - else if (trinoValue instanceof SqlTimestampWithTimeZone) { - convertedValue = ((SqlTimestampWithTimeZone) trinoValue).toZonedDateTime(); - } - else if (trinoValue instanceof SqlDecimal) { - convertedValue = ((SqlDecimal) trinoValue).toBigDecimal(); - } - else { - convertedValue = trinoValue; - } + Object convertedValue = switch (trinoValue) { + case null -> null; + case SqlDate sqlDate -> LocalDate.ofEpochDay(sqlDate.getDays()); + case SqlTime _ -> DateTimeFormatter.ISO_LOCAL_TIME.parse(trinoValue.toString(), LocalTime::from); + case SqlTimeWithTimeZone sqlTimeWithTimeZone -> { + long nanos = roundDiv(sqlTimeWithTimeZone.getPicos(), PICOSECONDS_PER_NANOSECOND); + int offsetMinutes = sqlTimeWithTimeZone.getOffsetMinutes(); + yield OffsetTime.of(LocalTime.ofNanoOfDay(nanos), ZoneOffset.ofTotalSeconds(offsetMinutes * 60)); + } + case SqlTimestamp sqlTimestamp -> sqlTimestamp.toLocalDateTime(); + case SqlTimestampWithTimeZone sqlTimestampWithTimeZone -> sqlTimestampWithTimeZone.toZonedDateTime(); + case SqlDecimal sqlDecimal -> sqlDecimal.toBigDecimal(); + default -> trinoValue; + }; convertedValues.add(convertedValue); } return new MaterializedRow(trinoRow.getPrecision(), convertedValues); diff --git a/core/trino-main/src/main/java/io/trino/util/Failures.java b/core/trino-main/src/main/java/io/trino/util/Failures.java index 477b019f765b..f509236080b2 100644 --- a/core/trino-main/src/main/java/io/trino/util/Failures.java +++ b/core/trino-main/src/main/java/io/trino/util/Failures.java @@ -155,18 +155,12 @@ private static ErrorLocation getErrorLocation(Throwable throwable) @Nullable private static ErrorCode toErrorCode(Throwable throwable) { - requireNonNull(throwable); - - if (throwable instanceof TrinoException trinoException) { - return trinoException.getErrorCode(); - } - if (throwable instanceof Failure failure) { - return failure.getFailureInfo().getErrorCode(); - } - if (throwable instanceof ParsingException) { - return SYNTAX_ERROR.toErrorCode(); - } - return null; + return switch (requireNonNull(throwable)) { + case TrinoException trinoException -> trinoException.getErrorCode(); + case Failure failure -> failure.getFailureInfo().getErrorCode(); + case ParsingException _ -> SYNTAX_ERROR.toErrorCode(); + default -> null; + }; } public static TrinoException internalError(Throwable t) diff --git a/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java b/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java index b7b8a681e72c..d3232adae807 100644 --- a/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java +++ b/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java @@ -621,34 +621,16 @@ public static RowBlock createRowBlock(List fieldTypes, Object[]... rows) for (int fieldIndex = 0; fieldIndex < fieldTypes.size(); fieldIndex++) { Type fieldType = fieldTypes.get(fieldIndex); Object fieldValue = row[fieldIndex]; - if (fieldValue == null) { - fieldBuilders.get(fieldIndex).appendNull(); - continue; - } - - if (fieldValue instanceof String) { - fieldType.writeSlice(fieldBuilders.get(fieldIndex), utf8Slice((String) fieldValue)); - } - else if (fieldValue instanceof Slice) { - fieldType.writeSlice(fieldBuilders.get(fieldIndex), (Slice) fieldValue); - } - else if (fieldValue instanceof Double) { - fieldType.writeDouble(fieldBuilders.get(fieldIndex), (Double) fieldValue); - } - else if (fieldValue instanceof Long) { - fieldType.writeLong(fieldBuilders.get(fieldIndex), (Long) fieldValue); - } - else if (fieldValue instanceof Boolean) { - fieldType.writeBoolean(fieldBuilders.get(fieldIndex), (Boolean) fieldValue); - } - else if (fieldValue instanceof Block) { - fieldType.writeObject(fieldBuilders.get(fieldIndex), fieldValue); - } - else if (fieldValue instanceof Integer) { - fieldType.writeLong(fieldBuilders.get(fieldIndex), (Integer) fieldValue); - } - else { - throw new IllegalArgumentException(); + switch (fieldValue) { + case null -> fieldBuilders.get(fieldIndex).appendNull(); + case String s -> fieldType.writeSlice(fieldBuilders.get(fieldIndex), utf8Slice(s)); + case Slice slice -> fieldType.writeSlice(fieldBuilders.get(fieldIndex), slice); + case Double v -> fieldType.writeDouble(fieldBuilders.get(fieldIndex), v); + case Long l -> fieldType.writeLong(fieldBuilders.get(fieldIndex), l); + case Boolean b -> fieldType.writeBoolean(fieldBuilders.get(fieldIndex), b); + case Block _ -> fieldType.writeObject(fieldBuilders.get(fieldIndex), fieldValue); + case Integer i -> fieldType.writeLong(fieldBuilders.get(fieldIndex), i); + default -> throw new IllegalArgumentException(); } } }); diff --git a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java index 2c57e5259851..7ded8207bedd 100644 --- a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java +++ b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java @@ -646,25 +646,27 @@ private static Object getNonNullValueForType(Type type) if (type.getJavaType() == LongTimestampWithTimeZone.class) { return LongTimestampWithTimeZone.fromEpochSecondsAndFraction(1, 0, UTC_KEY); } - if (type instanceof ArrayType arrayType) { - Type elementType = arrayType.getElementType(); - Object elementNonNullValue = getNonNullValueForType(elementType); - return arrayBlockOf(elementType, elementNonNullValue); - } - if (type instanceof MapType mapType) { - Type keyType = mapType.getKeyType(); - Type valueType = mapType.getValueType(); - Object keyNonNullValue = getNonNullValueForType(keyType); - Object valueNonNullValue = getNonNullValueForType(valueType); - Map map = ImmutableMap.of(keyNonNullValue, valueNonNullValue); - return sqlMapOf(keyType, valueType, map); - } - if (type instanceof RowType rowType) { - List elementTypes = rowType.getTypeParameters(); - Object[] elementNonNullValues = elementTypes.stream().map(AbstractTestType::getNonNullValueForType).toArray(Object[]::new); - return toRow(elementTypes, elementNonNullValues); - } - throw new IllegalStateException("Unsupported Java type " + type.getJavaType() + " (for type " + type + ")"); + switch (type) { + case ArrayType arrayType -> { + Type elementType = arrayType.getElementType(); + Object elementNonNullValue = getNonNullValueForType(elementType); + return arrayBlockOf(elementType, elementNonNullValue); + } + case MapType mapType -> { + Type keyType = mapType.getKeyType(); + Type valueType = mapType.getValueType(); + Object keyNonNullValue = getNonNullValueForType(keyType); + Object valueNonNullValue = getNonNullValueForType(valueType); + Map map = ImmutableMap.of(keyNonNullValue, valueNonNullValue); + return sqlMapOf(keyType, valueType, map); + } + case RowType rowType -> { + List elementTypes = rowType.getTypeParameters(); + Object[] elementNonNullValues = elementTypes.stream().map(AbstractTestType::getNonNullValueForType).toArray(Object[]::new); + return toRow(elementTypes, elementNonNullValues); + } + default -> throw new IllegalStateException("Unsupported Java type " + type.getJavaType() + " (for type " + type + ")"); + } } private Block toBlock(Object value)