From 966e4fba5fcf5268c13232ff2d9ad09d8aee9b0c Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 13 May 2024 18:18:55 -0400
Subject: [PATCH] Address remaining feedback from #2312 (#2327)

---
 docs/spec.md                                  | 168 ++---
 docs/status.md                                |   4 +-
 stablehlo/dialect/Base.td                     |   3 +
 stablehlo/dialect/StablehloOps.td             |  19 +-
 stablehlo/dialect/VhloOps.td                  |   2 +
 stablehlo/reference/Ops.cpp                   |  12 +-
 stablehlo/tests/interpret/convolution.mlir    |  33 +-
 stablehlo/tests/interpret/dynamic_conv.mlir   |  27 +-
 stablehlo/tests/verify_dynamic_conv.mlir      | 646 ++++++------------
 .../transforms/StablehloLegalizeToVhlo.cpp    |  47 +-
 .../transforms/VhloLegalizeToStablehlo.cpp    |  23 +-
 stablehlo/transforms/VhloToVersion.cpp        |  13 +-
 stablehlo/transforms/VhloToVersionPatterns.td |   6 +-
 13 files changed, 355 insertions(+), 648 deletions(-)

diff --git a/docs/spec.md b/docs/spec.md
index c8426716fa..84cff18b0c 100644
--- a/docs/spec.md
+++ b/docs/spec.md
@@ -2296,11 +2296,11 @@ For hybrid quantized types, performs `hybrid_dequantize_then_op(
 //        ]
 //       ]]
 //
-// %rhs : [
-//        [[[1]], [[1]], [[1]]],
-//        [[[1]], [[1]], [[1]]],
-//        [[[1]], [[1]], [[1]]]
-//        ]
+// %rhs: [
+//       [[[1]], [[1]], [[1]]],
+//       [[[1]], [[1]], [[1]]],
+//       [[[1]], [[1]], [[1]]]
+//       ]
 %result = "stablehlo.convolution"(%lhs, %rhs) {
   window_strides = array<i64: 4, 4>,
   padding = dense<0> : tensor<2x2xi64>,
@@ -2706,79 +2706,9 @@ If not specified, all dimensions are assumed to be possibly expanding.
 
 #### Semantics
 
-Computes dot products between windows of `lhs` and slices of `rhs` and produces
-`result`. The following diagram shows how elements in `result` are computed from
-`lhs` and `rhs` using a concrete example.
-
-![convolution](images/spec/convolution.svg)
-
-More formally, consider the following reframing of the inputs in terms of `lhs`
-in order to be able to express windows of `lhs`. Additionally, padding is
-specified dynamically via `d_padding`:
-
-* `lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))`.
-* `lhs_window_strides = lhs_shape(1, window_strides, 1)`.
-* `lhs_padding = lhs_shape([0, 0], padding, [0, 0])`.
-* `lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)`.
-* `lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)`.
-
-This reframing uses the following helper functions:
-
-* `lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])`.
-* `result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])`.
-* `permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]` where `j[d] = i[permutation[d]]`.
-
-If `feature_group_count = 1` and `batch_group_count = 1`, then for all
-`output_spatial_index` in `index_space(dim(result, output_spatial_dimensions...))`,
-`result[result_shape(:, output_spatial_index, :)] = dot_product` where:
-
-* `padding_value = constant(0, element_type(lhs))`.
-* `padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)`.
-* `lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides`.
-* `lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)`.
-* `reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])`.
-  This feature appears to be unused, so in the future we are planning to remove
-  it ([#1181](https://github.com/openxla/stablehlo/issues/1181)).
-* `dot_product = dot_general(reversed_lhs_window, rhs,
-    lhs_batching_dimensions=[],
-    lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension],
-    rhs_batching_dimensions=[],
-    rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])`.
-
-If `feature_group_count > 1`:
-
-* `lhses = split(lhs, feature_group_count, input_feature_dimension)`.
-* `rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)`.
-* `results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)`.
-* `result = concatenate(results, output_feature_dimension)`.
-
-If `batch_group_count > 1`:
-
-* `lhses = split(lhs, batch_group_count, input_batch_dimension)`.
-* `rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)`.
-* `results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)`.
-* `result = concatenate(results, output_feature_dimension)`.
-
-For quantized types, performs `dequantize_op_quantize(
-    lambda lhs, rhs: convolution(lhs, rhs, d_padding, window_strides,
-    lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
-    input_feature_dimension, input_spatial_dimensions,
-    kernel_input_feature_dimension, kernel_output_feature_dimension,
-    kernel_spatial_dimensions, output_batch_dimension,
-    output_feature_dimension, output_spatial_dimensions,
-    feature_group_count, batch_group_count, precision_config), lhs, rhs,
-    type(result))`.
-
-For hybrid quantized types, performs `hybrid_dequantize_then_op(
-    lambda lhs, rhs: convolution(lhs, rhs, d_padding, window_strides,
-    lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
-    input_feature_dimension, input_spatial_dimensions,
-    kernel_input_feature_dimension, kernel_output_feature_dimension,
-    kernel_spatial_dimensions, output_batch_dimension,
-    output_feature_dimension, output_spatial_dimensions,
-    feature_group_count, batch_group_count, precision_config), lhs, rhs)`.
+This operation is functionally identical to the
+[convolution](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution)
+op, except that the padding is specified dynamically via `d_padding`.
 
 #### Inputs
 
 | Label | Name              | Type                                         | Constraints                                                |
 |-------|-------------------|----------------------------------------------|------------------------------------------------------------|
 | (I1)  | `lhs`             | tensor or per-tensor quantized tensor        | (C1), (C10-C11), (C14), (C25), (C26-C27), (C30-C31), (C33) |
 | (I2)  | `rhs`             | tensor or quantized tensor                   | (C1), (C14-C16), (C26-C28), (C30-C33)                      |
-| (I3)  | `d_padding`       | 2-dimensional tensor of type `si64`          | (C4)                                                       |
+| (I3)  | `d_padding`       | 2-dimensional tensor of integer type         | (C4)                                                       |
 | (I4)  | `window_strides`  | 1-dimensional tensor constant of type `si64` | (C2-C3)                                                    |
 | (I5)  | `lhs_dilation`    | 1-dimensional tensor constant of type `si64` | (C5-C6)                                                    |
 | (I6)  | `rhs_dilation`    | 1-dimensional tensor constant of type `si64` | (C7-C8)                                                    |
@@ -2846,22 +2776,34 @@ For hybrid quantized types, performs `hybrid_dequantize_then_op(
 * (C22) `0 < batch_group_count`.
 * (C23) `feature_group_count = 1 or batch_group_count = 1`.
 * (C24) `size(precision_config) = 2`.
-* (C25) `rank(result) = N`.
+* (C25) `dim(result, result_dim)` is defined as:
+  * `dim(lhs, input_batch_dimension) / batch_group_count` if `result_dim = output_batch_dimension`.
+  * `dim(rhs, kernel_output_feature_dimension)` if `result_dim = output_feature_dimension`.
+  * `num_windows` otherwise, where:
+    * `output_spatial_dimensions[spatial_dim] = result_dim`.
+    * `lhs_dim = input_spatial_dimensions[spatial_dim]`.
+    * `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
+    * `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
+    * `padded_input_shape[lhs_dim] = d_padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + d_padding[spatial_dim, 1]`.
+    * `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
+    * `is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]`.
+    * `num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
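+  For intuition, the shapes in the example below give, for each spatial
+  dimension: `dilated_input_shape = (4 - 1) * 2 + 1 = 7`,
+  `padded_input_shape = 1 + 7 + 1 = 9`,
+  `dilated_window_shape = (3 - 1) * 1 + 1 = 3`, and
+  `num_windows = floor((9 - 3) / 4) + 1 = 2`, which matches the
+  `tensor<1x2x2x1xi64>` result type.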
+* (C26) `rank(result) = N`.
 * If the operation uses non-quantized tensors:
-  * (C26) `element_type(lhs) = element_type(rhs) = element_type(result)`.
+  * (C27) `element_type(lhs) = element_type(rhs) = element_type(result)`.
 * If the operation uses quantized tensors:
-  * (C27) `is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)`.
-  * (C28) If `is_per_axis_quantized(rhs)`,
+  * (C28) `is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)`.
+  * (C29) If `is_per_axis_quantized(rhs)`,
     then `quantization_dimension(rhs) = kernel_output_feature_dimension`.
-  * (C29) If `is_per_axis_quantized(result)`, then
+  * (C30) If `is_per_axis_quantized(result)`, then
     `quantization_dimension(result) = output_feature_dimension`.
   * If `is_quantized(lhs)`:
-    * (C30) `storage_type(lhs) = storage_type(rhs)`.
-    * (C31) `expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)`.
-    * (C32) If `is_per_tensor_quantized(rhs)`, then
+    * (C31) `storage_type(lhs) = storage_type(rhs)`.
+    * (C32) `expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)`.
+    * (C33) If `is_per_tensor_quantized(rhs)`, then
       `is_per_tensor_quantized(result)`.
   * If `!is_quantized(lhs)`:
-    * (C33) `element_type(lhs) = expressed_type(rhs) = element_type(result)`.
+    * (C34) `element_type(lhs) = expressed_type(rhs) = element_type(result)`.
 
 #### Examples
 
@@ -2874,30 +2816,36 @@ For hybrid quantized types, performs `hybrid_dequantize_then_op(
 //        [[12], [13], [16], [17]]
 //       ]]
 //
-// %rhs : [
+// %rhs: [
 //        [[[1]], [[1]], [[1]]],
 //        [[[1]], [[1]], [[1]]],
 //        [[[1]], [[1]], [[1]]]
 //        ]
-%result = "stablehlo.convolution"(%lhs, %rhs) {
+// %d_padding: [[1, 1],
+//              [1, 1]]
+%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %d_padding) {
   window_strides = array<i64: 4, 4>,
-  padding = dense<0> : tensor<2x2xi64>,
   lhs_dilation = array<i64: 2, 2>,
   rhs_dilation = array<i64: 1, 1>,
   window_reversal = array<i1: false, false>,
-  // In the StableHLO dialect, dimension numbers are encoded via:
-  // `[]x[]->[output dimensions]`.
-  // "b" is batch dimension, "f" is feature dimension,
-  // "i" is input feature dimension, "o" is output feature dimension,
-  // "0/1/etc" are spatial dimensions.
-  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
-  batch_group_count = 1 : i64,
+  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
   feature_group_count = 1 : i64,
+  batch_group_count = 1 : i64,
   precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
+} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
 // %result: [[
-//          [[10], [26]],
-//          [[46], [62]]
+//          [[1], [5]],
+//          [[10], [14]]
 //         ]]
 ```
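+Because `d_padding` is an operand rather than an attribute, it need not come
+from a constant. A minimal sketch (hypothetical function name; same attributes
+as above) where the padding arrives at runtime, so the spatial result
+dimensions are unknown until then:
+
+```mlir
+func.func @dynamic_conv_runtime_padding(%lhs: tensor<1x4x4x1xi64>,
+    %rhs: tensor<3x3x1x1xi64>, %d_padding: tensor<2x2xi64>) -> tensor<1x?x?x1xi64> {
+  %result = "stablehlo.dynamic_conv"(%lhs, %rhs, %d_padding) {
+    window_strides = array<i64: 4, 4>,
+    lhs_dilation = array<i64: 2, 2>,
+    rhs_dilation = array<i64: 1, 1>,
+    window_reversal = array<i1: false, false>,
+    dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
+    feature_group_count = 1 : i64,
+    batch_group_count = 1 : i64,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
+  } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x?x?x1xi64>
+  func.return %result : tensor<1x?x?x1xi64>
+}
+```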
@@ -3002,10 +2950,10 @@ op, but the result shape is specified dynamically via `output_shape`.
 
 #### Inputs
 
-| Label | Name             | Type                                          | Constraints |
-|-------|------------------|-----------------------------------------------|-------------|
-| (I1)  | `output_shape`   | 1-dimensional tensor of integer type          | (C1), (C2)  |
-| (I2)  | `iota_dimension` | `si64`                                        | (C1)        |
+| Label | Name             | Type                                 | Constraints |
+|-------|------------------|--------------------------------------|-------------|
+| (I1)  | `output_shape`   | 1-dimensional tensor of integer type | (C1), (C2)  |
+| (I2)  | `iota_dimension` | `si64`                               | (C1)        |
 
 #### Outputs
 
@@ -3105,10 +3053,10 @@ op, but the result shape is specified dynamically via `output_shape`.
 
 #### Inputs
 
-| Label | Name           | Type                                          | Constraints |
-|-------|----------------|-----------------------------------------------|-------------|
-| (I1)  | `operand`      | tensor or quantized tensor                    | (C1-C3)     |
-| (I2)  | `output_shape` | 1-dimensional tensor of integer type          | (C4)        |
+| Label | Name           | Type                                 | Constraints |
+|-------|----------------|--------------------------------------|-------------|
+| (I1)  | `operand`      | tensor or quantized tensor           | (C1-C3)     |
+| (I2)  | `output_shape` | 1-dimensional tensor of integer type | (C4)        |
 
 #### Outputs
 
@@ -3142,7 +3090,7 @@ op, but the result shape is specified dynamically via `output_shape`.
 
 ```mlir
 // %operand: [[1, 2, 3], [4, 5, 6]]
-%output_shape = stablehlo.constant dense<[3, 2]> : tensor<2xi64>
+// %output_shape: [3, 2]
 %result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
 // %result: [[1, 2], [3, 4], [5, 6]]
 ```
diff --git a/docs/status.md b/docs/status.md
index a3053431b2..f1707ae75a 100644
--- a/docs/status.md
+++ b/docs/status.md
@@ -79,8 +79,8 @@ one of the following tracking labels.
 | dot                      | no  | revisit | infeasible | yes     | revisit |
 | dot_general              | yes | revisit | infeasible | no      | yes     |
 | dynamic_broadcast_in_dim | yes | yes     | infeasible | yes     | revisit |
-| dynamic_conv             | yes | yes     | infeasible | yes     | revisit |
-| dynamic_gather           | yes | yes     | infeasible | yes     | revisit |
+| dynamic_conv             | yes | yes     | infeasible | revisit | revisit |
+| dynamic_gather           | yes | yes     | infeasible | no      | revisit |
 | dynamic_iota             | yes | yes     | infeasible | yes     | revisit |
 | dynamic_pad              | yes | yes     | infeasible | yes     | revisit |
 | dynamic_reshape          | yes | yes     | infeasible | yes     | revisit |
diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td
index 2b31360401..5387956c2e 100644
--- a/stablehlo/dialect/Base.td
+++ b/stablehlo/dialect/Base.td
@@ -142,6 +142,9 @@ def HLO_Token : Type<CPred<"::llvm::isa<::mlir::stablehlo::TokenType>($_self)">, "token">;
 // Any integer tensor types
 def HLO_IntTensor : RankedTensorOf<[HLO_Int]>;
 
+// Any integer tensor type with rank 2.
+def HLO_2DIntTensor : TensorRankOf<[HLO_Int], [2]>;
+
 // Any integer tensor type with rank 0 (i.e. representing a single integer).
 def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>;
 
diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td
index ea44f3efac..905f548b19 100644
--- a/stablehlo/dialect/StablehloOps.td
+++ b/stablehlo/dialect/StablehloOps.td
@@ -3492,19 +3492,21 @@ def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
     [HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> {
   let summary = "DynamicConv operation";
   let description = [{
-    Computes dot products between windows of `lhs` and slices of `rhs` and
-    produces `result`.
+    This operation is functionally identical to the
+    [convolution](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution)
+    op, except that the padding is specified dynamically via `d_padding`.
 
     Example:
     ```mlir
+    %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64>
     %result = "stablehlo.dynamic_conv"(%lhs, %rhs, %d_padding) {
       window_strides = array<i64: 4, 4>,
       lhs_dilation = array<i64: 2, 2>,
       rhs_dilation = array<i64: 1, 1>,
       window_reversal = array<i1: false, false>,
       dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
-      feature_group_count = 1 : i64, batch_group_count = 1 : i64,
+      feature_group_count = 1 : i64,
+      batch_group_count = 1 : i64,
       precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
     } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
     ```
   }];
 
@@ -3513,7 +3515,7 @@ def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
   let arguments = (ins
     HLO_Tensor:$lhs, /*dynamic_conv_i1*/
     HLO_Tensor:$rhs, /*dynamic_conv_i2*/
-    HLO_Tensor:$d_padding, /*dynamic_conv_i3*/
+    HLO_2DIntTensor:$d_padding, /*dynamic_conv_i3*/
    OptionalAttr<DenseI64ArrayAttr>:$window_strides, /*dynamic_conv_i4*/
    OptionalAttr<DenseI64ArrayAttr>:$lhs_dilation, /*dynamic_conv_i5*/
    OptionalAttr<DenseI64ArrayAttr>:$rhs_dilation, /*dynamic_conv_i6*/
@@ -3532,15 +3534,6 @@ def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
       return reversal.has_value() && llvm::any_of(reversal.value(), [](bool v) { return v; });
     }
   }];
-
-  let assemblyFormat = [{
-    `(`operands`)`
-    `dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,`
-    `window` `=` `{` custom<WindowAttributes>($window_strides,
-      $lhs_dilation, $rhs_dilation,
-      $window_reversal) `}`
-    attr-dict `:` functional-type(operands, results)
-  }];
 }
 
 #endif // STABLEHLO_DIALECT_STABLEHLO_OPS
diff --git a/stablehlo/dialect/VhloOps.td b/stablehlo/dialect/VhloOps.td
index a550b2ca99..a05d2ce26e 100644
--- a/stablehlo/dialect/VhloOps.td
+++ b/stablehlo/dialect/VhloOps.td
@@ -453,6 +453,8 @@ def VHLO_DynamicConvOpV1 : VHLO_Op<"dynamic_conv_v1", "0.9.0", "0.19.0"> {
   let results = (outs VHLO_AnyType:$result);
 }
 
+// `d_padding` should be used instead of `padding` for dynamic convolution, so
+// `padding` is removed for clarity.
 def VHLO_DynamicConvOpV2 : VHLO_Op<"dynamic_conv_v2", "0.20.0", "current"> {
   let arguments = (ins
     VHLO_AnyType:$lhs,
diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp
index e992cf0935..14a51b2523 100644
--- a/stablehlo/reference/Ops.cpp
+++ b/stablehlo/reference/Ops.cpp
@@ -55,6 +55,13 @@ Index evalIndex(Tensor tensor) {
   return result;
 }
 
+template <typename T>
+SmallVector<T> extractAttributeOrDefault(std::optional<ArrayRef<T>> attr,
+                                         int64_t size, T value) {
+  if (attr.has_value()) return llvm::to_vector(attr.value());
+  return SmallVector<T>(size, value);
+}
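+// Illustrative use (not part of the patch): when the attribute is absent,
+// e.g. extractAttributeOrDefault<int64_t>(std::nullopt, /*size=*/2, 1)
+// produces the default-filled vector {1, 1}.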
+
 Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs,
                     const Axes &lhsContractingDimensions,
                     const Axes &rhsContractingDimensions) {
@@ -519,9 +526,8 @@ SmallVector<InterpreterValue> eval(Region &region,
       auto rhs = scope.findTensor(op.getRhs());
       auto rank = lhs.getRank();
 
-      SmallVector<int64_t> windowStrides(rank - 2, 1);
-      if (auto windowStridesAttr = op.getWindowStrides())
-        windowStrides = SmallVector<int64_t>(windowStridesAttr.value());
+      SmallVector<int64_t> windowStrides = extractAttributeOrDefault<int64_t>(
+          op.getWindowStrides(), rank - 2, 1);
 
       SmallVector<std::pair<int64_t, int64_t>> padding(rank - 2, {0, 0});
       if (auto paddingAttr = op.getPaddingAttr()) {
diff --git a/stablehlo/tests/interpret/convolution.mlir b/stablehlo/tests/interpret/convolution.mlir
index 4b7fdc3ad0..bee430e3cf 100644
--- a/stablehlo/tests/interpret/convolution.mlir
+++ b/stablehlo/tests/interpret/convolution.mlir
@@ -20,8 +20,37 @@ func.func @convolution_op_test_si64() {
     } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
   check.expect_eq_const %result, dense<[[
-    [[10], [26]],
-    [[46], [62]]
+      [[10], [26]],
+      [[46], [62]]
     ]]> : tensor<1x2x2x1xi64>
   func.return
 }
+
+// -----
+
+func.func @convolution_op_test_padding() {
+  %lhs = stablehlo.constant dense<[[
+    [[1], [2], [5], [6]],
+    [[3], [4], [7], [8]],
+    [[10], [11], [14], [15]],
+    [[12], [13], [16], [17]]
+  ]]> : tensor<1x4x4x1xi64>
+  %rhs = stablehlo.constant dense<1> : tensor<3x3x1x1xi64>
+  %result = stablehlo.convolution(%lhs, %rhs)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {
+      stride = [4, 4],
+      pad = [[1, 1], [1, 1]],
+      lhs_dilate = [2, 2]
+    } {
+      batch_group_count = 1 : i64,
+      feature_group_count = 1 : i64,
+      precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
+    }
+    : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
+  check.expect_eq_const %result, dense<[[
+      [[1], [5]],
+      [[10], [14]]
+    ]]> : tensor<1x2x2x1xi64>
+  func.return
+}
diff --git a/stablehlo/tests/interpret/dynamic_conv.mlir b/stablehlo/tests/interpret/dynamic_conv.mlir
index a4fe213c4b..0e7d1d2c3c 100644
--- a/stablehlo/tests/interpret/dynamic_conv.mlir
+++ b/stablehlo/tests/interpret/dynamic_conv.mlir
@@ -8,21 +8,20 @@ func.func @dynamic_conv_op_test_si64() {
     [[12], [13], [16], [17]]
   ]]> : tensor<1x4x4x1xi64>
   %rhs = stablehlo.constant dense<1> : tensor<3x3x1x1xi64>
-  %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64>
-  %result = stablehlo.dynamic_conv(%lhs, %rhs, %d_padding)
-    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
-    window = {
-      stride = [4, 4],
-      lhs_dilate = [2, 2]
-    } {
-      batch_group_count = 1 : i64,
-      feature_group_count = 1 : i64,
-      precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-    }
-    : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
+  %d_padding = stablehlo.constant dense<1> : tensor<2x2xi64>
+  %result = "stablehlo.dynamic_conv"(%lhs, %rhs, %d_padding) {
+    window_strides = array<i64: 4, 4>,
+    lhs_dilation = array<i64: 2, 2>,
+    rhs_dilation = array<i64: 1, 1>,
+    window_reversal = array<i1: false, false>,
+    dimension_numbers = #stablehlo.conv<[b, 0,
1, i, o]->[b, 0, 1, f]>, + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64> check.expect_eq_const %result, dense<[[ - [[10], [26]], - [[46], [62]] + [[1], [5]], + [[10], [14]] ]]> : tensor<1x2x2x1xi64> func.return } diff --git a/stablehlo/tests/verify_dynamic_conv.mlir b/stablehlo/tests/verify_dynamic_conv.mlir index 908a27828c..02845087b0 100644 --- a/stablehlo/tests/verify_dynamic_conv.mlir +++ b/stablehlo/tests/verify_dynamic_conv.mlir @@ -5,24 +5,11 @@ func.func @dynamic_conv(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x3 tensor<100x28x28x1xf32> { %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x2xi64>) -> - tensor<100x28x28x1xf32> + tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> } @@ -30,24 +17,18 @@ func.func @dynamic_conv(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x3 // CHECK: func @dynamic_conv_empty_spatial_dimensions // CHECK: stablehlo.dynamic_conv -// CHECK-SAME: dim_numbers = [b, f]x[i, o]->[b, f] -// CHECK-SAME: window = {stride = [], lhs_dilate = [], -// CHECK-SAME: rhs_dilate = [], reverse = []} +// CHECK: batch_group_count = 1 : i64, +// CHECK: dimension_numbers = #stablehlo.conv<[b, f]x[i, o]->[b, f]>, +// CHECK: feature_group_count = 1 : i64 func.func @dynamic_conv_empty_spatial_dimensions(%arg0: tensor<3x2xf16>, - %arg1: tensor<2x2xf16>) -> tuple> { + %arg1: tensor<2x2xf16>) -> tensor<3x2xf16> { %d_padding = stablehlo.constant dense<0> : tensor<0x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, f]x[i, o]->[b, f], - window = {stride = [], lhs_dilate = [], rhs_dilate = [], - reverse = []} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } - : (tensor<3x2xf16>, tensor<2x2xf16>, tensor<0x2xi64>) -> tensor<3x2xf16> - %1 = "stablehlo.tuple"(%0) : (tensor<3x2xf16>) -> tuple> - func.return %1 : tuple> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + dimension_numbers = #stablehlo.conv<[b, f]x[i, o]->[b, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<3x2xf16>, tensor<2x2xf16>, tensor<0x2xi64>) -> tensor<3x2xf16> + func.return %result : tensor<3x2xf16> } // ----- @@ -57,82 +38,27 @@ func.func @dynamic_conv_upcast(%arg0 : tensor<100x26x26x32xi8>, %arg1 : tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> { %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array - } : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>, tensor<2x2xi64>) -> tensor<100x28x28x1xi32> + batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>, tensor<2x2xi64>) -> + tensor<100x28x28x1xi32> func.return %result : 
tensor<100x28x28x1xi32> } // ----- -func.func @dynamic_conv(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { - // expected-error@+4{{Unexpected keyword stide}} - %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stide = [2, 1], rhs_dilate = [1, 2]} - { batch_group_count = 1 : i64, feature_group_count = 1 : i64} - : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>, tensor<2x2xi64>) -> tensor<3x5x5x4xf32> - func.return %0 : tensor<3x5x5x4xf32> -} - -// ----- - -func.func @dynamic_conv(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { - // expected-error@+4{{expected integer value}} - %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [2, b], rhs_dilate = [1, 2]} - { batch_group_count = 1 : i64, feature_group_count = 1 : i64} - : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>, tensor<2x2xi64>) -> tensor<3x5x5x4xf32> - func.return %0 : tensor<3x5x5x4xf32> -} - -// ----- - -func.func @dynamic_conv(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { - // expected-error@+4{{Unexpected keyword stride}} - %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [2, 1], rhs_dilate = [1, 2], stride=[2,1]} - { batch_group_count = 1 : i64, feature_group_count = 1 : i64} - : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>, tensor<2x2xi64>) -> tensor<3x5x5x4xf32> - func.return %0 : tensor<3x5x5x4xf32> -} - -// ----- - func.func @dynamic_conv_c1(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects convolution arguments to have same number of dimensions. 
Got: 'tensor<1x8x8x207xf32>' and 'tensor<3x3x207xf32>'.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x207xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -141,17 +67,14 @@ func.func @dynamic_conv_c2(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects window-strides to have same dimension-size as size of window dimensions (2), but got: 1.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + window_strides = array, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -160,44 +83,14 @@ func.func @dynamic_conv_c3(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects window to have positive stride for 1-th window dimension, but got 0.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 0], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> -} - -// ----- - -func.func @dynamic_conv_c4_i3_invalid_d_padding_rank(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> - tensor<100x28x28x1xf32> { - // expected-error@+2 {{expects d_padding to be of rank 2 but got 3}} - %d_padding = stablehlo.constant dense<2> : tensor<2x2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + window_strides = array, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array - } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x2x2xi64>) -> - tensor<100x28x28x1xf32> - func.return %result : 
tensor<100x28x28x1xf32> + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -207,24 +100,11 @@ func.func @dynamic_conv_c4_invalid_d_padding_dim_0(%arg0 : tensor<100x26x26x32xf // expected-error@+2 {{expects d_padding to be of shape [2, 2], but got [3, 2]}} %d_padding = stablehlo.constant dense<2> : tensor<3x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<3x2xi64>) -> - tensor<100x28x28x1xf32> + tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> } @@ -235,24 +115,11 @@ func.func @dynamic_conv_c4_invalid_d_padding_dim_1(%arg0 : tensor<100x26x26x32xf // expected-error@+2 {{expects d_padding to be of shape [2, 2], but got [2, 3]}} %d_padding = stablehlo.constant dense<2> : tensor<2x3xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x3xi64>) -> - tensor<100x28x28x1xf32> + tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> } @@ -261,24 +128,11 @@ func.func @dynamic_conv_c4_invalid_d_padding_dim_1(%arg0 : tensor<100x26x26x32xf func.func @dynamic_conv_c4_dynamic_d_padding(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>, %arg2 : tensor) -> tensor<100x28x28x1xf32> { %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor) -> - tensor<100x28x28x1xf32> + tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> } @@ -288,16 +142,14 @@ func.func @dynamic_conv_c5(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects base-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + lhs_dilation = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) 
-> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -306,17 +158,14 @@ func.func @dynamic_conv_c6(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects window to have positive base dilation factor for 0-th window dimension, but got 0.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [0, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + lhs_dilation = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -325,16 +174,14 @@ func.func @dynamic_conv_c7(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects window-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + rhs_dilation = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -343,17 +190,14 @@ func.func @dynamic_conv_c8(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects window to have positive window dilation factor for 0-th window dimension, but got 0.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [0, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + rhs_dilation = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -362,16 +206,14 @@ func.func @dynamic_conv_c9(%arg0: 
tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects window-reversal to have same dimension-size as size of window dimensions (2), but got: 1.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + window_reversal = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -380,17 +222,13 @@ func.func @dynamic_conv_c10(%arg0: tensor<5x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects input batch dimension (5) to be divisible by batch_group_count. Got batch_group_count = 2.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 2 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<5x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 2 : i64 + } : (tensor<5x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -399,17 +237,13 @@ func.func @dynamic_conv_c11(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects input feature dimension (207) to be a multiple of feature_group_count. 
Got feature_group_count = 2.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 2 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 2 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -417,7 +251,7 @@ func.func @dynamic_conv_c11(%arg0: tensor<1x8x8x207xf32>, func.func @dynamic_conv_c12(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> { // expected-error@+2{{expects convolution arguments to have 4 dimensions. Got: 5}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) {batch_group_count = 1 : i64, + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { dimension_numbers = #stablehlo.conv, %arg1: tensor<3x3x3 output_batch_dimension = 1, output_feature_dimension = 4, output_spatial_dimensions = [2, 3] - >, feature_group_count = 1 : i64, lhs_dilation = array, precision_config = [#stablehlo, #stablehlo], rhs_dilation = array, window_strides = array} : - (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>, tensor<2x2xi64>) -> tensor<32x1x8x8x16xf32> - func.return %0 : tensor<32x1x8x8x16xf32> + >, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>, tensor<2x2xi64>) -> + tensor<32x1x8x8x16xf32> + func.return %result : tensor<32x1x8x8x16xf32> } // ----- @@ -439,17 +276,12 @@ func.func @dynamic_conv_c13(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects convolution arguments to have >= 2 dimensions. 
Got: 'tensor<1xf32>' and 'tensor<3xf32>'.}} %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1xf32>, tensor<3xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1xf32>, tensor<3xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -459,7 +291,6 @@ func.func @dynamic_conv_c13(%arg0 : tensor<100x26x26x32xf32>, // expected-error@+2 {{expects input dimension-numbers to be unique, got {0, 0, 1, 2}.}} %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, dimension_numbers = #stablehlo.conv, output_spatial_dimensions = [1, 2] >, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x2xi64>) -> tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> @@ -487,7 +316,6 @@ func.func @dynamic_conv_c13(%arg0 : tensor<100x26x26x32xf32>, // expected-error@+2 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, dimension_numbers = #stablehlo.conv, output_spatial_dimensions = [1, 2] >, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x2xi64>) -> tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> @@ -515,7 +341,6 @@ func.func @dynamic_conv_c13(%arg0 : tensor<100x26x26x32xf32>, // expected-error@+2 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, dimension_numbers = #stablehlo.conv, output_spatial_dimensions = [1, 2] >, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x2xi64>) -> tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> @@ -542,17 +365,13 @@ func.func @dynamic_conv_c14(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects input feature dimension (207) / feature_group_count = kernel input feature dimension (20). 
Got feature_group_count = 1.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -561,17 +380,13 @@ func.func @dynamic_conv_c15(%arg0: tensor<3x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<3x8x8x16xf32> { // expected-error@+2 {{expects output feature dimension size (16) to be a multiple of batch_group_count. Got batch_group_count = 3.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 3 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<3x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<3x8x8x16xf32> - func.return %0 : tensor<3x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 3 : i64 + } : (tensor<3x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> + tensor<3x8x8x16xf32> + func.return %result : tensor<3x8x8x16xf32> } // ----- @@ -580,17 +395,13 @@ func.func @dynamic_conv_c16(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x69x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects kernel output feature dimension (16) to be divisible by feature_group_count. 
For feature_group_count = 3.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 3 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x69x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 3 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x69x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -599,17 +410,13 @@ func.func @dynamic_conv_c17(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 3, and 2 resp.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, 2, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, 2, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -619,7 +426,6 @@ func.func @dynamic_conv_c18(%arg0 : tensor<100x26x26x32xf32>, // expected-error@+2 {{expects kernel dimension-numbers to be unique, got {3, 2, 0, 0}.}} %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, dimension_numbers = #stablehlo.conv, output_spatial_dimensions = [1, 2] >, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x2xi64>) -> - tensor<100x28x28x1xf32> + tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> } @@ -647,7 +451,6 @@ func.func @dynamic_conv_c18(%arg0 : tensor<100x26x26x32xf32>, // expected-error@+2 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, dimension_numbers = #stablehlo.conv, output_spatial_dimensions = [1, 2] >, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x2xi64>) -> - tensor<100x28x28x1xf32> + tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> } @@ -675,7 +476,6 @@ func.func @dynamic_conv_c18(%arg0 : tensor<100x26x26x32xf32>, // expected-error@+2 {{expects input, 
kernel, and output dimension-numbers to be in-range [0, 4).}} %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, dimension_numbers = #stablehlo.conv, output_spatial_dimensions = [1, 2] >, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x2xi64>) -> tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> @@ -702,17 +500,13 @@ func.func @dynamic_conv_c19(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+2 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 2, and 3 resp.}} %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64> - %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, 2, f], - window = {stride = [1, 1], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, 2, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> + tensor<1x8x8x16xf32> + func.return %result : tensor<1x8x8x16xf32> } // ----- @@ -722,7 +516,6 @@ func.func @dynamic_conv_c20(%arg0 : tensor<100x26x26x32xf32>, // expected-error@+2 {{expects output dimension-numbers to be unique, got {0, 3, 0, 3}.}} %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, dimension_numbers = #stablehlo.conv, output_spatial_dimensions = [0, 3] >, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x2xi64>) -> - tensor<100x28x28x1xf32> + tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> } @@ -750,7 +541,6 @@ func.func @dynamic_conv_c20(%arg0 : tensor<100x26x26x32xf32>, // expected-error@+2 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, dimension_numbers = #stablehlo.conv, output_spatial_dimensions = [1, 2] >, feature_group_count = 1 : i64, - lhs_dilation = array, - rhs_dilation = array, - window_strides = array + batch_group_count = 1 : i64 } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x2xi64>) -> - tensor<100x28x28x1xf32> + tensor<100x28x28x1xf32> func.return %result : tensor<100x28x28x1xf32> } @@ -778,7 +566,6 @@ func.func @dynamic_conv_c20(%arg0 : tensor<100x26x26x32xf32>, // expected-error@+2 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} %d_padding = stablehlo.constant dense<2> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) { - batch_group_count = 1 : i64, dimension_numbers = #stablehlo.conv, output_spatial_dimensions = [1, 2] >, feature_group_count = 1 : i64, 
-      lhs_dilation = array<i64: 1, 1>,
-      rhs_dilation = array<i64: 1, 1>,
-      window_strides = array<i64: 1, 1>
+    batch_group_count = 1 : i64
   } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>, tensor<2x2xi64>) ->
-       tensor<100x28x28x1xf32>
+      tensor<100x28x28x1xf32>
   func.return %result : tensor<100x28x28x1xf32>
 }
@@ -805,17 +590,13 @@ func.func @dynamic_conv_c21(%arg0: tensor<1x8x8x207xf32>,
                             %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
   // expected-error@+2 {{expects feature_group_count to be a positive number, got 0.}}
   %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64>
-  %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding)
-         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
-         window = {stride = [1, 1],
-                   lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
-         {
-           batch_group_count = 1 : i64,
-           feature_group_count = 0 : i64,
-           precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-         } :
-       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32>
-  func.return %0 : tensor<1x8x8x16xf32>
+  %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) {
+    dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
+    feature_group_count = 0 : i64,
+    batch_group_count = 1 : i64
+  } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) ->
+      tensor<1x8x8x16xf32>
+  func.return %result : tensor<1x8x8x16xf32>
 }
 
 // -----
@@ -824,17 +605,13 @@ func.func @dynamic_conv_c22(%arg0: tensor<1x8x8x207xf32>,
                             %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
   // expected-error@+2 {{expects batch_group_count to be a positive number, got 0.}}
   %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64>
-  %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding)
-         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
-         window = {stride = [1, 1],
-                   lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
-         {
-           batch_group_count = 0 : i64,
-           feature_group_count = 1 : i64,
-           precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-         } :
-       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32>
-  func.return %0 : tensor<1x8x8x16xf32>
+  %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) {
+    dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
+    feature_group_count = 1 : i64,
+    batch_group_count = 0 : i64
+  } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) ->
+      tensor<1x8x8x16xf32>
+  func.return %result : tensor<1x8x8x16xf32>
 }
 
 // -----
@@ -843,37 +620,27 @@ func.func @dynamic_conv_c23(%arg0: tensor<1x8x8x207xf32>,
                             %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
   // expected-error@+2 {{expects batch_group_count and feature_group_count not to be both greater than 1. Got 2 and 2 resp.}}
   %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64>
-  %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding)
-         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
-         window = {stride = [1, 1],
-                   lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
-         {
-           batch_group_count = 2 : i64,
-           feature_group_count = 2 : i64,
-           precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-         } :
-       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32>
-  func.return %0 : tensor<1x8x8x16xf32>
+  %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) {
+    dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
+    feature_group_count = 2 : i64,
+    batch_group_count = 2 : i64
+  } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) ->
+      tensor<1x8x8x16xf32>
+  func.return %result : tensor<1x8x8x16xf32>
 }
 
 // -----
 
-func.func @dynamic_conv_c24(%arg0: tensor<3x2xf16>,
-                            %arg1: tensor<2x2xf16>) -> tuple<tensor<3x2xf16>> {
+func.func @dynamic_conv_c24(%arg0: tensor<3x2xf16>, %arg1: tensor<2x2xf16>) -> tensor<3x2xf16> {
   // expected-error@+2{{expects precision config to be empty or have <= 2 elements}}
   %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64>
-  %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding)
-         dim_numbers = [b, f]x[i, o]->[b, f],
-         window = {stride = [], lhs_dilate = [], rhs_dilate = [],
-                   reverse = []}
-         {
-           batch_group_count = 1 : i64,
-           feature_group_count = 1 : i64,
-           precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-         }
-       : (tensor<3x2xf16>, tensor<2x2xf16>, tensor<2x2xi64>) -> tensor<3x2xf16>
-  %1 = "stablehlo.tuple"(%0) : (tensor<3x2xf16>) -> tuple<tensor<3x2xf16>>
-  func.return %1 : tuple<tensor<3x2xf16>>
+  %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) {
+    dimension_numbers = #stablehlo.conv<[b, f]x[i, o]->[b, f]>,
+    batch_group_count = 1 : i64,
+    feature_group_count = 1 : i64,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
+  } : (tensor<3x2xf16>, tensor<2x2xf16>, tensor<2x2xi64>) -> tensor<3x2xf16>
+  func.return %result : tensor<3x2xf16>
 }
 
 // -----
@@ -882,17 +649,15 @@ func.func @dynamic_conv_c27(%arg0: tensor<1x4x4x1xi64>,
                             %arg1: tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi64> {
   // expected-error@+2 {{expects lhs and rhs to have compatible element type. Got: 'i64' and 'i32'}}
   %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64>
-  %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) {
+  %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) {
     window_strides = array<i64: 1, 1>,
     lhs_dilation = array<i64: 1, 1>,
-    rhs_dilation = array<i64: 1, 1>,
-    window_reversal = array<i1: false, false>,
     dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
     feature_group_count = 1 : i64,
-    batch_group_count = 1 : i64,
-    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-  } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi32>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
-  func.return %0 : tensor<1x2x2x1xi64>
+    batch_group_count = 1 : i64
+  } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi32>, tensor<2x2xi64>) ->
+      tensor<1x2x2x1xi64>
+  func.return %result : tensor<1x2x2x1xi64>
 }
 
 // -----
@@ -901,14 +666,11 @@ func.func @dynamic_conv_invalid_window_attributes(%arg0: tensor<1x8x8x207xf32>,
                             %arg1: tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
   // expected-error@+2 {{expects window to have positive value for 0-th window dimension, but got 0.}}
   %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64>
-  %0 = stablehlo.dynamic_conv(%arg0, %arg1, %d_padding)
-         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
-         window = {stride = [1, 1],
-                   lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
-         {
-           batch_group_count = 1 : i64,
-           feature_group_count = 1 : i64,
-           precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
-       (tensor<1x8x8x207xf32>, tensor<0x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x8x8x16xf32>
-  func.return %0 : tensor<1x8x8x16xf32>
+  %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) {
+    dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
+    feature_group_count = 1 : i64,
+    batch_group_count = 1 : i64
+  } : (tensor<1x8x8x207xf32>, tensor<0x3x207x16xf32>, tensor<2x2xi64>) ->
+      tensor<1x8x8x16xf32>
+  func.return %result : tensor<1x8x8x16xf32>
 }
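For contrast with the negative tests above, a minimal well-formed `dynamic_conv`
in the same generic form is sketched below. This example is illustrative only
and is not part of the patch: the function name and shapes are hypothetical,
chosen so that a 3x3 kernel over an 8x8 input with zero padding and default
strides yields a 6x6 result.

func.func @dynamic_conv_valid(%arg0: tensor<1x8x8x207xf32>,
                              %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> {
  // Zero padding for both spatial dimensions, supplied as an operand.
  %d_padding = stablehlo.constant dense<0> : tensor<2x2xi64>
  %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %d_padding) {
    dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
    feature_group_count = 1 : i64,
    batch_group_count = 1 : i64
  } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) ->
      tensor<1x6x6x16xf32>
  func.return %result : tensor<1x6x6x16xf32>
}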
diff --git a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
index fcd78d5f1a..277a61e93c 100644
--- a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
+++ b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
@@ -623,47 +623,24 @@ LogicalResult addDefaults(const OpConversionPattern<StablehloOpTy>& pattern,
     if (!stablehloOp.getCompositeAttributesAttr())
       addDefaultAttr("composite_attributes", builder.getDictionaryAttr({}));
   }
-  if constexpr (std::is_same<StablehloOpTy, stablehlo::ConvolutionOp>::value) {
-    auto numSpatialDimensions = static_cast<int64_t>(
-        stablehloOp.getDimensionNumbers().getInputSpatialDimensions().size());
-    if (!stablehloOp.getWindowStridesAttr())
-      addDefaultAttr("window_strides",
-                     builder.getDenseI64ArrayAttr(
-                         SmallVector<int64_t>(numSpatialDimensions, 1ll)));
-    if (!stablehloOp.getPaddingAttr())
-      addDefaultAttr("padding",
-                     DenseIntElementsAttr::get(
-                         RankedTensorType::get({numSpatialDimensions, 2},
-                                               builder.getI64Type()),
-                         SmallVector<int64_t>(numSpatialDimensions * 2, 0ll)));
-    if (!stablehloOp.getLhsDilationAttr())
-      addDefaultAttr("lhs_dilation",
-                     builder.getDenseI64ArrayAttr(
-                         SmallVector<int64_t>(numSpatialDimensions, 1ll)));
-    if (!stablehloOp.getRhsDilationAttr())
-      addDefaultAttr("rhs_dilation",
-                     builder.getDenseI64ArrayAttr(
-                         SmallVector<int64_t>(numSpatialDimensions, 1ll)));
-    if (!stablehloOp.getWindowReversalAttr())
-      addDefaultAttr("window_reversal",
-                     DenseIntElementsAttr::get(
-                         RankedTensorType::get({numSpatialDimensions},
-                                               builder.getI1Type()),
-                         SmallVector<bool>(numSpatialDimensions, false)));
-    if (!stablehloOp.getPrecisionConfigAttr())
-      addDefaultAttr(
-          "precision_config",
-          builder.getArrayAttr(SmallVector<Attribute>(
-              2, stablehlo::PrecisionAttr::get(
-                     pattern.getContext(), stablehlo::Precision::DEFAULT))));
-  }
-  if constexpr (std::is_same<StablehloOpTy, stablehlo::DynamicConvOp>::value) {
+  if constexpr (std::is_same<StablehloOpTy, stablehlo::ConvolutionOp>::value ||
+                std::is_same<StablehloOpTy, stablehlo::DynamicConvOp>::value) {
     auto numSpatialDimensions = static_cast<int64_t>(
         stablehloOp.getDimensionNumbers().getInputSpatialDimensions().size());
     if (!stablehloOp.getWindowStridesAttr())
       addDefaultAttr("window_strides",
                      builder.getDenseI64ArrayAttr(
                          SmallVector<int64_t>(numSpatialDimensions, 1ll)));
+    if constexpr (std::is_same<StablehloOpTy, stablehlo::ConvolutionOp>::value) {
+      if (!stablehloOp.getPaddingAttr())
+        addDefaultAttr(
+            "padding",
+            DenseIntElementsAttr::get(
+                RankedTensorType::get({numSpatialDimensions, 2},
+                                      builder.getI64Type()),
+                SmallVector<int64_t>(numSpatialDimensions * 2, 0ll)));
+    }
     if (!stablehloOp.getLhsDilationAttr())
       addDefaultAttr("lhs_dilation",
                      builder.getDenseI64ArrayAttr(
diff --git a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
index 712ea02d43..3318aec2b0 100644
--- a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
+++ b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
@@ -700,25 +700,14 @@ LogicalResult removeDefaults(const OpConversionPattern<VhloOpTy>& pattern,
       eraseAttrs(vhloAttrs, "composite_attributes");
     }
   }
-  if constexpr (std::is_same<VhloOpTy, vhlo::ConvolutionOpV1>::value) {
-    if (isSplatTensor(pattern, vhloOp.getWindowStridesAttr(), 1ll))
-      eraseAttrs(vhloAttrs, "window_strides");
-    if (isSplatTensor(pattern, vhloOp.getPaddingAttr(), 0ll))
-      eraseAttrs(vhloAttrs, "padding");
-    if (isSplatTensor(pattern, vhloOp.getLhsDilationAttr(), 1ll))
-      eraseAttrs(vhloAttrs, "lhs_dilation");
-    if (isSplatTensor(pattern, vhloOp.getRhsDilationAttr(), 1ll))
-      eraseAttrs(vhloAttrs, "rhs_dilation");
-    if (isSplatTensor(pattern, vhloOp.getWindowReversalAttr(), false))
-      eraseAttrs(vhloAttrs, "window_reversal");
-    if (isSplatArray(vhloOp.getPrecisionConfigAttr(),
-                     vhlo::PrecisionV1Attr::get(pattern.getContext(),
-                                                vhlo::PrecisionV1::DEFAULT)))
-      eraseAttrs(vhloAttrs, "precision_config");
-  }
-  if constexpr (std::is_same<VhloOpTy, vhlo::DynamicConvOpV1>::value) {
+  if constexpr (std::is_same<VhloOpTy, vhlo::ConvolutionOpV1>::value ||
+                std::is_same<VhloOpTy, vhlo::DynamicConvOpV2>::value) {
     if (isSplatTensor(pattern, vhloOp.getWindowStridesAttr(), 1ll))
       eraseAttrs(vhloAttrs, "window_strides");
+    if constexpr (std::is_same<VhloOpTy, vhlo::ConvolutionOpV1>::value) {
+      if (isSplatTensor(pattern, vhloOp.getPaddingAttr(), 0ll))
+        eraseAttrs(vhloAttrs, "padding");
+    }
     if (isSplatTensor(pattern, vhloOp.getLhsDilationAttr(), 1ll))
       eraseAttrs(vhloAttrs, "lhs_dilation");
     if (isSplatTensor(pattern, vhloOp.getRhsDilationAttr(), 1ll))
       eraseAttrs(vhloAttrs, "rhs_dilation");
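Taken together, these two hunks are duals: `addDefaults` materializes defaulted
attributes when legalizing StableHLO to VHLO, and `removeDefaults` strips the
same values on the way back. The sketch below shows the fully defaulted form
for a dynamic_conv with two spatial dimensions; it is illustrative MLIR written
from the defaults constructed above, not verbatim pass output.

%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %d_padding) {
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  // Defaults added by addDefaults and elided again by removeDefaults:
  window_strides = array<i64: 1, 1>,
  lhs_dilation = array<i64: 1, 1>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64
} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) ->
    tensor<1x6x6x16xf32>

Note that no static `padding` appears: after this change the padding default is
materialized only for `convolution`, since `dynamic_conv` carries its padding
as the `%d_padding` operand.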
diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp
index 649acf91dd..d538da56d6 100644
--- a/stablehlo/transforms/VhloToVersion.cpp
+++ b/stablehlo/transforms/VhloToVersion.cpp
@@ -271,19 +271,20 @@ struct VhloToVersionPass : public VhloToVersionPassBase<VhloToVersionPass> {
 /// Upgrade and Downgrade Definitions ///
 /////////////////////////////////////////
 
-TensorV1Attr getDefaultPadding(OpBuilder& builder, Value lhs) {
+TensorV1Attr getDefaultConvPadding(OpBuilder& builder, Value lhs) {
   auto lhsType = dyn_cast<RankedTensorV1Type>(lhs.getType());
   if (!lhsType) return TensorV1Attr();
   // Convert to DenseElements for getRawData handling.
-  int64_t rankMinusTwo = lhsType.getShape().size() - 2;
+  SmallVector<int64_t> paddingShape{
+      static_cast<int64_t>(lhsType.getShape().size() - 2), 2};
   auto denseElements = DenseIntElementsAttr::get(
-      RankedTensorType::get({rankMinusTwo, 2}, builder.getI64Type()),
-      SmallVector<int64_t>(rankMinusTwo * 2, 0ll));
+      RankedTensorType::get(paddingShape, builder.getI64Type()),
+      SmallVector<int64_t>(paddingShape[0] * 2, 0ll));
   return TensorV1Attr::get(
       builder.getContext(),
-      RankedTensorV1Type::get(builder.getContext(), {rankMinusTwo, 2},
+      RankedTensorV1Type::get(builder.getContext(), paddingShape,
                               IntegerSI64V1Type::get(builder.getContext()),
                               nullptr),
       denseElements.getRawData());
@@ -298,8 +299,6 @@ namespace stablehlo {
 void populateVhloToVersionPatterns(RewritePatternSet* patterns,
                                    TypeConverter* converter,
                                    MLIRContext* context) {
-  // Currently empty because we're starting from a clean slate in v0.9.0 and
-  // changes so far are additive.
   vhlo::populateWithGenerated(*patterns);
 }
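For a rank-4 `lhs` (two spatial dimensions), the helper above produces an
all-zero tensor attribute of shape [rank(lhs) - 2, 2] = [2, 2]. At the
StableHLO level, the re-materialized attribute is equivalent to the following
(illustrative):

padding = dense<0> : tensor<2x2xi64>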
diff --git a/stablehlo/transforms/VhloToVersionPatterns.td b/stablehlo/transforms/VhloToVersionPatterns.td
index c8684a61db..f3836f418a 100644
--- a/stablehlo/transforms/VhloToVersionPatterns.td
+++ b/stablehlo/transforms/VhloToVersionPatterns.td
@@ -16,12 +16,12 @@ limitations under the License.
 include "mlir/IR/OpBase.td"
 include "stablehlo/dialect/VhloOps.td"
 
-def VHLO_GetDefaultPadding : NativeCodeCall<"getDefaultPadding($_builder, $0)">;
+def VHLO_GetDefaultConvPadding : NativeCodeCall<"getDefaultConvPadding($_builder, $0)">;
 
 def DynamicConvUpgradeV1ToV2:
   Pat<(VHLO_DynamicConvOpV1 $lhs, $rhs, $d_padding, $window_strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, $input_batch_dimension, $input_feature_dimension, $input_spatial_dimensions, $kernel_input_feature_dimension, $kernel_output_feature_dimension, $kernel_spatial_dimensions, $output_batch_dimension, $output_feature_dimension, $output_spatial_dimensions, $feature_group_count, $batch_group_count, $precision_config),
       (VHLO_DynamicConvOpV2 $lhs, $rhs, $d_padding, $window_strides, $lhs_dilation, $rhs_dilation, $window_reversal, $input_batch_dimension, $input_feature_dimension, $input_spatial_dimensions, $kernel_input_feature_dimension, $kernel_output_feature_dimension, $kernel_spatial_dimensions, $output_batch_dimension, $output_feature_dimension, $output_spatial_dimensions, $feature_group_count, $batch_group_count, $precision_config)>;
 
 def DynamicConvDowngradeV2ToV1:
-  Pat<(VHLO_DynamicConvOpV2:$result $lhs, $rhs, $d_padding, $window_strides, $lhs_dilation, $rhs_dilation, $window_reversal, $input_batch_dimension, $input_feature_dimension, $input_spatial_dimensions, $kernel_input_feature_dimension, $kernel_output_feature_dimension, $kernel_spatial_dimensions, $output_batch_dimension, $output_feature_dimension, $output_spatial_dimensions, $feature_group_count, $batch_group_count, $precision_config),
-  (VHLO_DynamicConvOpV1 $lhs, $rhs, $d_padding, $window_strides, (VHLO_GetDefaultPadding $lhs), $lhs_dilation, $rhs_dilation, $window_reversal, $input_batch_dimension, $input_feature_dimension, $input_spatial_dimensions, $kernel_input_feature_dimension, $kernel_output_feature_dimension, $kernel_spatial_dimensions, $output_batch_dimension, $output_feature_dimension, $output_spatial_dimensions, $feature_group_count, $batch_group_count, $precision_config)>;
+  Pat<(VHLO_DynamicConvOpV2 $lhs, $rhs, $d_padding, $window_strides, $lhs_dilation, $rhs_dilation, $window_reversal, $input_batch_dimension, $input_feature_dimension, $input_spatial_dimensions, $kernel_input_feature_dimension, $kernel_output_feature_dimension, $kernel_spatial_dimensions, $output_batch_dimension, $output_feature_dimension, $output_spatial_dimensions, $feature_group_count, $batch_group_count, $precision_config),
+  (VHLO_DynamicConvOpV1 $lhs, $rhs, $d_padding, $window_strides, (VHLO_GetDefaultConvPadding $lhs), $lhs_dilation, $rhs_dilation, $window_reversal, $input_batch_dimension, $input_feature_dimension, $input_spatial_dimensions, $kernel_input_feature_dimension, $kernel_output_feature_dimension, $kernel_spatial_dimensions, $output_batch_dimension, $output_feature_dimension, $output_spatial_dimensions, $feature_group_count, $batch_group_count, $precision_config)>;
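Read together, the two patterns are inverses up to defaults: the upgrade drops
the now-unused static padding attribute from v1, and the downgrade
re-materializes it as zeros via `getDefaultConvPadding`. A hypothetical
before/after for the downgrade direction is sketched below; the op mnemonics
assume the usual VHLO `<name>_vN` naming, and operand types plus the remaining
attributes are deliberately elided.

// v2 form: padding is carried only by the %d_padding operand.
%r = "vhlo.dynamic_conv_v2"(%lhs, %rhs, %d_padding) { ... }
// v1 form after downgrade: the static padding attribute reappears as an
// all-zero tensor of shape [rank(lhs) - 2, 2].
%r = "vhlo.dynamic_conv_v1"(%lhs, %rhs, %d_padding) { padding = <all-zeros>, ... }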