Spec dynamic_pad (openxla#2306)
This includes:
- Adding partial support to the interpreter (only when the result is
static)
- Cleaning up and completing the verifier
- Updating the ODS
- Adding relevant tests

I made a few drive-by fixes to related logic/ops as well.

Constraint coverage:

```
I1. operand is tensor or per-tensor quantized tensor: ODS
I2. padding_value is 0-dimensional tensor or per-tensor quantized tensor: ODS
I3. edge_padding_low is integer tensor: ODS
I4. edge_padding_high is integer tensor: ODS
I5. interior_padding is integer tensor: ODS
C1. element_type(operand) = element_type(padding_value) = element_type(result): ODS
C2. size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand): ODS + verifier
C3. 0 <= interior_padding: verifier
C4. shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high: verifier
```
ref openxla#2267
Fixes openxla#2292
mlevesquedion authored May 9, 2024
1 parent 73c32a3 commit aa576d6
Showing 11 changed files with 233 additions and 93 deletions.
67 changes: 63 additions & 4 deletions docs/spec.md
@@ -332,8 +332,7 @@ in StableHLO programs. In the meanwhile, here is the list of these operations:
`trace` ([#604](https://github.com/openxla/stablehlo/issues/604)).
* "Dynamism" category of StableHLO operations - they were bootstrapped from
MHLO, and we are in the process of speccing them: `dynamic_broadcast_in_dim`,
`dynamic_conv`, `dynamic_gather`, `dynamic_pad`, `real_dynamic_slice`,
`set_dimension_size`.
`dynamic_conv`, `dynamic_gather`, `real_dynamic_slice`, `set_dimension_size`.
([#8](https://github.com/openxla/stablehlo/issues/8)).
* Shape computations, including `arith`, `shape` and `tensor` operations
([#8](https://github.com/openxla/stablehlo/issues/8)).
@@ -2648,7 +2647,7 @@ op, but the result shape is specified dynamically via `output_shape`.

| Label | Name | Type | Constraints |
|-------|------------------|----------------------------------------------|-------------|
| (I1) | `output_shape` | 1-dimensional tensor constant of type `si64` | (C1), (C2) |
| (I1) | `output_shape` | 1-dimensional tensor of integer type | (C1), (C2) |
| (I2) | `iota_dimension` | `si64` | (C1) |

#### Outputs
@@ -2680,6 +2679,66 @@ op, but the result shape is specified dynamically via `output_shape`.

&nbsp;[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/dynamic_iota.mlir)

### dynamic_pad

#### Semantics

This operation is functionally identical to
[pad](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad)
op, but with `edge_padding_low`, `edge_padding_high` and `interior_padding`
specified dynamically as values.

#### Inputs

| Label | Name | Type | Constraints |
|-------|---------------------|-----------------------------------------------------|------------------|
| (I1) | `operand` | tensor or per-tensor quantized tensor | (C1), (C2), (C4) |
| (I2) | `padding_value` | 0-dimensional tensor or per-tensor quantized tensor | (C1) |
| (I3) | `edge_padding_low` | 1-dimensional tensor of integer type | (C1), (C4) |
| (I4) | `edge_padding_high` | 1-dimensional tensor of integer type | (C1), (C4) |
| (I5) | `interior_padding` | 1-dimensional tensor of integer type | (C2-C4) |

#### Outputs

| Name | Type | Constraints |
|----------|---------------------------------------|-------------|
| `result` | tensor or per-tensor quantized tensor | (C3-C6) |

#### Constraints

* (C1) `element_type(operand) = element_type(padding_value) =
element_type(result)`.
* (C2) `size(edge_padding_low) = size(edge_padding_high) =
size(interior_padding) = rank(operand)`.
* (C3) `0 <= interior_padding`.
* (C4) `shape(result) = shape(operand) + edge_padding_low +
max(shape(operand) - 1, 0) * interior_padding + edge_padding_high`.

#### Examples

```mlir
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi32>, tensor<i32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
```
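
Working constraint (C4) through these inputs confirms the `tensor<5x9xi32>` result type:

```
dim 0: 2 + 0 + max(2 - 1, 0) * 1 + 2 = 5
dim 1: 3 + 1 + max(3 - 1, 0) * 2 + 1 = 9
```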

&nbsp;[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/dynamic_pad.mlir)

### dynamic_reshape

#### Semantics
@@ -2693,7 +2752,7 @@ op, but the result shape is specified dynamically via `output_shape`.
| Label | Name | Type | Constraints |
|-------|----------------|----------------------------------------------|-------------|
| (I1) | `operand` | tensor or quantized tensor | (C1-C3) |
| (I2) | `output_shape` | 1-dimensional tensor constant of type `si64` | (C4) |
| (I2) | `output_shape` | 1-dimensional tensor of integer type | (C4) |

#### Outputs

2 changes: 1 addition & 1 deletion docs/status.md
@@ -82,7 +82,7 @@ one of the following tracking labels.
| dynamic_conv | no | revisit | no | no | no |
| dynamic_gather | no | revisit | revisit | no | no |
| dynamic_iota | yes | yes | infeasible | yes | revisit |
| dynamic_pad | no | revisit | no | yes | no |
| dynamic_pad | yes | yes | infeasible | yes | revisit |
| dynamic_reshape | yes | yes | infeasible | yes | revisit |
| dynamic_slice | yes | yes | yes | yes | yes |
| dynamic_update_slice | yes | yes | yes | yes | yes |
2 changes: 2 additions & 0 deletions stablehlo/dialect/Base.td
@@ -166,6 +166,8 @@ def HLO_PredTensor : RankedTensorOf<[HLO_Pred]>;

def HLO_Tensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;

def HLO_ScalarTensor: 0DTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;

def HLO_NonQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex]>;

def HLO_TensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
38 changes: 19 additions & 19 deletions stablehlo/dialect/StablehloOps.td
@@ -2947,7 +2947,7 @@ def StableHLO_PadOp: StableHLO_ShapedInterfaceOp<"pad",
}];
let arguments = (ins
HLO_Tensor:$operand /*pad_i1*/,
HLO_Tensor:$padding_value /*pad_i2*/,
HLO_ScalarTensor:$padding_value /*pad_i2*/,
GenericDenseI64ArrayAttr:$edge_padding_low /*pad_i3*/,
GenericDenseI64ArrayAttr:$edge_padding_high /*pad_i4*/,
GenericDenseI64ArrayAttr:$interior_padding /*pad_i5*/
@@ -3416,37 +3416,37 @@ def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp<

def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad",
[ConditionallySpeculatable, NoMemoryEffect,
AllElementTypesMatch<["operand", "padding_value", "result"]>,
AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
AllElementTypesMatch<["operand", "padding_value", "result"]> /*dynamic_pad_c1*/,
AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]> /*dynamic_pad_c2*/,
AllRanksMatch<["operand", "result"]> /*dynamic_pad_c4*/]> {
let summary = "DynamicPad operation";
let description = [{
This operation is a work in progress, so it is not yet included in
the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
This operation is functionally identical to
[pad](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad)
op, but with `edge_padding_low`, `edge_padding_high` and `interior_padding`
specified dynamically as values.

Informally, this operation does the same thing as PadOp except
that `edge_padding_low`, `edge_padding_high` and `interior_padding` are
specified dynamically:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_pad

Example:
```mlir
%edge_padding_low = stablehlo.constant dense<[0, 1]> : tensor<2xi32>
%edge_padding_high = stablehlo.constant dense<[2, 1]> : tensor<2xi32>
%interior_padding = stablehlo.constant dense<[1, 2]> : tensor<2xi32>
%result = stablehlo.dynamic_pad %operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
: (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
: (tensor<2x3xi32>, tensor<i32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<5x9xi32>
```
}];
let arguments = (ins
HLO_Tensor:$operand,
HLO_Tensor:$padding_value,
HLO_DimensionTensor:$edge_padding_low,
HLO_DimensionTensor:$edge_padding_high,
HLO_DimensionTensor:$interior_padding
HLO_Tensor:$operand /*dynamic_pad_i1*/,
HLO_ScalarTensor:$padding_value /*dynamic_pad_i2*/,
HLO_StaticDimensionTensor:$edge_padding_low /*dynamic_pad_i3*/,
HLO_StaticDimensionTensor:$edge_padding_high /*dynamic_pad_i4*/,
HLO_StaticDimensionTensor:$interior_padding /*dynamic_pad_i5*/
);
let results = (outs HLO_Tensor:$result);
let description = [{
Dynamically Pads the `operand`, with amount of padding added at
low-end/high-end/interior is passed through input tensors.
}];
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
69 changes: 37 additions & 32 deletions stablehlo/dialect/TypeInference.cpp
@@ -2677,14 +2677,6 @@ LogicalResult inferPadOp(std::optional<Location> location, Type operandType,
ArrayRef<int64_t> interiorPadding,
SmallVectorImpl<Type>& inferredReturnTypes) {
auto inputType = cast<RankedTensorType>(operandType);
auto padType = cast<RankedTensorType>(paddingValueType);

// pad_i2
if (padType.getRank() != 0)
return emitOptionalError(location,
"padding value type should be a rank-0 "
"tensor, is rank ",
padType.getRank());

int64_t rank = inputType.getRank();
// pad_c2
@@ -3871,33 +3863,46 @@ LogicalResult verifyDynamicPadOp(std::optional<Location> location,
auto inputType = cast<RankedTensorType>(operand.getType());
int inputRank = inputType.getRank();

auto padType = cast<RankedTensorType>(paddingValue.getType());
if (padType.getRank() != 0)
return emitOptionalError(location, "padding value type should be a rank-0");

/*dynamic_pad_c2*/
// edgePaddingLow, edgePaddingHigh and interiorPadding are enforced to have
// the same size by ODS
auto paddingLowType = cast<RankedTensorType>(edgePaddingLow.getType());
if (paddingLowType.getNumElements() != inputRank)
return emitOptionalError(location, "edge_padding_low length(",
paddingLowType.getNumElements(),
") must match operand rank(", inputRank, ").");

auto paddingHighType = cast<RankedTensorType>(edgePaddingHigh.getType());
if (paddingHighType.getNumElements() != inputRank)
return emitOptionalError(location, "edge_padding_high length(",
paddingHighType.getNumElements(),
") must match operand rank(", inputRank, ").");

auto interiorPaddingType = cast<RankedTensorType>(interiorPadding.getType());
if (interiorPaddingType.getNumElements() != inputRank)
return emitOptionalError(location, "edge_padding_interior length(",
interiorPaddingType.getNumElements(),
") must match operand rank(", inputRank, ").");
auto paddingSize = paddingLowType.getDimSize(0);
if (paddingSize != inputRank)
return emitOptionalError(location, "padding operands size (", paddingSize,
") must match operand rank (", inputRank, ")");

/*dynamic_pad_c3*/
SmallVector<int64_t> interiorPaddingValues;
auto interiorPaddingMatched =
matchInts(interiorPadding, interiorPaddingValues);
if (succeeded(interiorPaddingMatched)) {
if (llvm::any_of(interiorPaddingValues, [](int64_t i) { return i < 0; }))
return emitOptionalError(
location, "interior_padding must be non-negative, but got ",
interiorPaddingValues);
}

auto outputType = cast<RankedTensorType>(result.getType());
int outputRank = outputType.getRank();
if (inputRank != outputRank)
return emitOptionalError(location, "operand rank(", inputRank,
") must match result(", outputRank, ").");
if (!inputType.hasStaticShape() || !outputType.hasStaticShape() ||
failed(interiorPaddingMatched))
return success();

SmallVector<int64_t> edgePaddingLowValues;
if (failed(matchInts(edgePaddingLow, edgePaddingLowValues))) return success();
SmallVector<int64_t> edgePaddingHighValues;
if (failed(matchInts(edgePaddingHigh, edgePaddingHighValues)))
return success();

/*dynamic_pad_c4*/
for (auto [i, in, out, low, high, interior] : llvm::enumerate(
inputType.getShape(), outputType.getShape(), edgePaddingLowValues,
edgePaddingHighValues, interiorPaddingValues)) {
auto want = in + low + std::max<int64_t>(in - 1, 0) * interior + high;
if (out != want)
return emitOptionalError(location, "expected output dimension at index ",
i, " to equal ", want, ", but got ", out);
}

return success();
}
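
For reference, here is a minimal standalone C++ sketch of the `dynamic_pad_c4` arithmetic (a hypothetical helper for illustration, not part of this commit):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring dynamic_pad_c4:
// shape(result) = shape(operand) + edge_padding_low +
//     max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
std::vector<int64_t> expectedPadShape(const std::vector<int64_t> &in,
                                      const std::vector<int64_t> &low,
                                      const std::vector<int64_t> &high,
                                      const std::vector<int64_t> &interior) {
  std::vector<int64_t> out(in.size());
  for (size_t i = 0; i < in.size(); ++i)
    out[i] = in[i] + low[i] +
             std::max<int64_t>(in[i] - 1, 0) * interior[i] + high[i];
  return out;
}

// expectedPadShape({2, 3}, {0, 1}, {2, 1}, {1, 2}) returns {5, 9}, matching
// the tensor<5x9xi32> result in the spec example above.
```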
40 changes: 20 additions & 20 deletions stablehlo/reference/Ops.cpp
@@ -321,6 +321,14 @@ SmallVector<InterpreterValue> eval(Region &region,
scope.add(block.getArguments(), args);

for (Operation &operation : block) {
if (!llvm::all_of(operation.getResults(), [](OpResult r) {
if (auto shaped = dyn_cast<ShapedType>(r.getType()))
return shaped.hasStaticShape();
return true;
}))
llvm::report_fatal_error(
"dynamic result types are not supported at the moment");

if (auto op = dyn_cast<AbsOp>(operation)) {
auto operand = scope.findTensor(op.getOperand());
auto result = absOp(operand, op.getType());
@@ -581,13 +589,21 @@ SmallVector<InterpreterValue> eval(Region &region,
scope.add(op.getResult(), result);
} else if (auto op = dyn_cast<DynamicIotaOp>(operation)) {
auto iotaDimension = op.getIotaDimension();
auto outputShape = scope.findTensor(op.getOutputShape());
auto result = dynamicIotaOp(iotaDimension, outputShape, op.getType());
auto result = iotaOp(iotaDimension, op.getType());
scope.add(op.getResult(), result);
} else if (auto op = dyn_cast<DynamicPadOp>(operation)) {
auto operand = scope.findTensor(op.getOperand());
auto paddingValue = scope.findTensor(op.getPaddingValue());
auto edgePaddingLow = scope.findTensor(op.getEdgePaddingLow());
auto edgePaddingHigh = scope.findTensor(op.getEdgePaddingHigh());
auto interiorPadding = scope.findTensor(op.getInteriorPadding());
auto result =
padOp(operand, paddingValue, makeSizes(edgePaddingLow),
makeSizes(edgePaddingHigh), makeSizes(interiorPadding));
scope.add(op.getResult(), result);
} else if (auto op = dyn_cast<DynamicReshapeOp>(operation)) {
auto operand = scope.findTensor(op.getOperand());
auto outputShape = scope.findTensor(op.getOutputShape());
auto result = dynamicReshapeOp(operand, outputShape, op.getType());
auto result = reshapeOp(operand, op.getType());
scope.add(op.getResult(), result);
} else if (auto op = dyn_cast<DynamicSliceOp>(operation)) {
auto operand = scope.findTensor(op.getOperand());
@@ -1586,22 +1602,6 @@ Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs,
return result;
}

Tensor dynamicIotaOp(Axis iotaDimension, const Tensor &outputShape,
ShapedType resultType) {
if (resultType.hasStaticShape()) return iotaOp(iotaDimension, resultType);

llvm::report_fatal_error(
"dynamic result types are not supported at the moment");
}

Tensor dynamicReshapeOp(const Tensor &operand, const Tensor &outputShape,
ShapedType resultType) {
if (resultType.hasStaticShape()) return reshapeOp(operand, resultType);

llvm::report_fatal_error(
"dynamic result types are not supported at the moment");
}

Tensor dynamicSliceOp(const Tensor &operand, ArrayRef<Tensor> startIndices,
const Sizes &sliceSizes, ShapedType resultType) {
Tensor result(resultType);
4 changes: 0 additions & 4 deletions stablehlo/reference/Ops.h
@@ -91,10 +91,6 @@ Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs,
const Axes &lhsContractingDimensions,
const Axes &rhsContractingDimensions,
ShapedType resultType);
Tensor dynamicIotaOp(Axis iotaDimension, const Tensor &outputShape,
ShapedType resultType);
Tensor dynamicReshapeOp(const Tensor &operand, const Tensor &outputShape,
ShapedType resultType);
Tensor dynamicSliceOp(const Tensor &operand, ArrayRef<Tensor> startIndices,
const Sizes &sliceSizes, ShapedType resultType);
Tensor dynamicUpdateSliceOp(const Tensor &operand, const Tensor &update,
18 changes: 18 additions & 0 deletions stablehlo/reference/Tensor.cpp
@@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/DebugStringHelper.h"
#include "stablehlo/reference/Errors.h"
#include "stablehlo/reference/Index.h"
#include "stablehlo/reference/Types.h"

namespace mlir {
@@ -561,5 +562,22 @@ DenseElementsAttr makeDenseElementsAttr(Tensor tensor) {
llvm::report_fatal_error("Only FloatType and IntType are handled currently.");
}

Sizes makeSizes(Tensor tensor) {
if (tensor.getRank() != 1 || !isa<IntegerType>(tensor.getElementType())) {
std::string str;
llvm::raw_string_ostream os(str);
os << "makeSizes(Tensor) only accepts integer tensors of rank 1, but got: ";
tensor.print(os);
llvm::report_fatal_error(str.c_str());
}
SmallVector<int64_t> values;
values.reserve(tensor.getNumElements());
for (auto it = tensor.index_begin(), end = tensor.index_end(); it != end;
it++) {
values.push_back(tensor.get(*it).getIntegerValue().getSExtValue());
}
return Sizes(values);
}

} // namespace stablehlo
} // namespace mlir
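
A minimal usage sketch for the new `makeSizes` helper (hypothetical values; assumes an `MLIRContext` named `ctx` with the builtin dialect loaded):

```cpp
// Hypothetical sketch: convert a rank-1 i64 constant into Sizes using the
// makeTensor/makeSizes helpers declared in stablehlo/reference/Tensor.h.
auto type = mlir::RankedTensorType::get(
    {3}, mlir::IntegerType::get(&ctx, /*width=*/64));
auto attr =
    mlir::DenseElementsAttr::get(type, llvm::ArrayRef<int64_t>{0, 1, 2});
mlir::stablehlo::Tensor tensor = mlir::stablehlo::makeTensor(attr);
mlir::stablehlo::Sizes sizes = mlir::stablehlo::makeSizes(tensor);  // {0, 1, 2}
```

This mirrors how the interpreter converts the `edge_padding_low`, `edge_padding_high` and `interior_padding` operands of `dynamic_pad` before forwarding them to `padOp`.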
3 changes: 3 additions & 0 deletions stablehlo/reference/Tensor.h
@@ -144,6 +144,9 @@ Tensor makeTensor(DenseElementsAttr attr);
/// Creates a DenseElementsAttr from a Tensor.
DenseElementsAttr makeDenseElementsAttr(Tensor tensor);

/// Creates a Sizes from a Tensor.
Sizes makeSizes(Tensor tensor);

} // namespace stablehlo
} // namespace mlir
