dynamic_reshape op spec (openxla#2284)
* Constraints in the spec:
```
I1 operand is tensor or quantized tensor
I2 output_shape is 1-dimensional tensor constant of type si64
```
* Test coverage:
```
I1.1  operand is tensor or quantized tensor.  (covered by ODS)
I2.1  output_shape is not a 1-dimensional tensor constant of type si64.  (covered by ODS)
C1.1  if !is_per_axis_quantized(operand), element_type(result) = element_type(operand).  (added verifier logic and a test)
C1.2  if is_per_axis_quantized(operand), element_type(result) = element_type(operand) except that
      quantization_dimension(operand) and quantization_dimension(result) may differ.  (validated; a test can't be added yet because the interpreter doesn't support dynamic shapes)
C2.1  size(operand) = size(result).  (can't be fully verified statically for dynamic shapes)
C3.1  quantization constraints.  (verifier logic is shared between reshape and dynamic_reshape; tests already exist for reshape, so none are needed for dynamic_reshape)
C4.1  rank(result) = size(output_shape).  (already present, validated)
```

* Reference interpreter:
Partial implementation supporting only static shapes.




Ref: openxla#2267
Fixes: openxla#2293
abhigunj authored May 8, 2024
1 parent 4213d63 commit bac0cd1
Showing 9 changed files with 206 additions and 72 deletions.
64 changes: 59 additions & 5 deletions docs/spec.md
@@ -2646,16 +2646,16 @@ op, but the result shape is specified dynamically via `output_shape`.

#### Inputs

| Label | Name | Type | Constraints |
|-------|------------------|----------------------------------------------|-------------|
| (I1) | `output_shape` | 1-dimensional tensor constant of type `si64` | (C1), (C2) |
| (I2) | `iota_dimension` | `si64` | (C1) |

#### Outputs

| Name | Type | Constraints |
|----------|-----------------------------------------------------------------------------------|-------------|
| `result` | tensor of integer, floating-point, or complex type or per-tensor quantized tensor | (C2) |

#### Constraints

@@ -2680,6 +2680,60 @@ op, but the result shape is specified dynamically via `output_shape`.

 [More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/dynamic_iota.mlir)

### dynamic_reshape

#### Semantics

This operation is functionally identical to
[reshape](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape)
op, but the result shape is specified dynamically via `output_shape`.

#### Inputs

| Label | Name | Type | Constraints |
|-------|----------------|----------------------------------------------|-------------|
| (I1) | `operand` | tensor or quantized tensor | (C1-C3) |
| (I2) | `output_shape` | 1-dimensional tensor constant of type `si64` | (C4) |

#### Outputs

| Name | Type | Constraints |
|----------|----------------------------|-------------|
| `result` | tensor or quantized tensor | (C1-C4) |

#### Constraints

* (C1) `element_type(result)` is given by:
* `element_type(operand)`, if `!is_per_axis_quantized(operand)`.
* `element_type(operand)` except that `quantization_dimension(operand)` and
`quantization_dimension(result)` may differ, otherwise.
* (C2) `size(operand) = size(result)`.
* (C3) If `is_per_axis_quantized(operand)`:
* `reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]),
init_values=1, dimensions=[0], body=lambda x, y: x * y) =
reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]),
init_values=1, dimensions=[0], body=lambda x, y: x * y)`.
* `dim(operand, quantization_dimension(operand)) =
dim(result, quantization_dimension(result))`.
* `reduce(dims(operand,
[quantization_dimension(operand) + 1, ..., rank(operand) - 1]),
init_values=1, dimensions=[0], body=lambda x, y: x * y) =
reduce(dims(result,
[quantization_dimension(result) + 1, ..., rank(result) - 1]),
init_values=1, dimensions=[0], body=lambda x, y: x * y)`.
* (C4) `size(output_shape) = rank(result)`.
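As an illustration (not part of the spec), the shape-only constraints (C2) and (C4) can be checked for statically shaped types roughly as follows; `check_dynamic_reshape` and its argument names are hypothetical:

```python
import math

def check_dynamic_reshape(operand_shape, output_shape, result_shape):
    """Sketch of constraints (C2) and (C4) for statically shaped types."""
    # (C4) size(output_shape) = rank(result)
    if len(output_shape) != len(result_shape):
        raise ValueError("size(output_shape) must equal rank(result)")
    # (C2) size(operand) = size(result): total element counts agree
    if math.prod(operand_shape) != math.prod(result_shape):
        raise ValueError("size(operand) must equal size(result)")

# e.g. tensor<2x3xi64> reshaped to tensor<3x2xi64>: 6 elements either way
check_dynamic_reshape((2, 3), [3, 2], (3, 2))
```

(C1) and (C3) additionally constrain element types and quantization parameters and are not modeled here.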

#### Examples

```mlir
// %operand: [[1, 2, 3], [4, 5, 6]]
%output_shape = stablehlo.constant dense<[3, 2]> : tensor<2xi64>
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
```
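For intuition, the same reshape can be reproduced in plain Python; this is an illustration of the row-major regrouping only, not StableHLO semantics in full:

```python
operand = [[1, 2, 3], [4, 5, 6]]
output_shape = [3, 2]

# Flatten in row-major order, then regroup into rows of the new shape.
flat = [x for row in operand for x in row]
rows, cols = output_shape
result = [flat[i * cols:(i + 1) * cols] for i in range(rows)]
# result == [[1, 2], [3, 4], [5, 6]]
```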

&nbsp;[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/dynamic_reshape.mlir)

### dynamic_slice

#### Semantics
2 changes: 1 addition & 1 deletion docs/status.md
@@ -83,7 +83,7 @@ one of the following tracking labels.
| dynamic_gather | no | revisit | revisit | no | no |
| dynamic_iota | yes | yes | infeasible | yes | revisit |
| dynamic_pad | no | revisit | no | yes | no |
| dynamic_reshape | no | revisit | infeasible | yes | no |
| dynamic_reshape | yes | yes | infeasible | yes | revisit |
| dynamic_slice | yes | yes | yes | yes | yes |
| dynamic_update_slice | yes | yes | yes | yes | yes |
| einsum | no | revisit | no | yes | revisit |
20 changes: 12 additions & 8 deletions stablehlo/dialect/StablehloOps.td
@@ -2663,21 +2663,25 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape",
[HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> {
let summary = "DynamicReshape operation";
let description = [{
This operation is a work in progress, so it is not yet included in
the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
This operation is functionally identical to
[reshape](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape)
op, but the result shape is specified dynamically via `output_shape`.

Informally, this operation does the same thing as ReshapeOp except that the
result shape is specified dynamically via `output_shape`:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape
See:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_reshape

Example:
```mlir
%0 = stablehlo.dynamic_reshape %arg0, %shape : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%output_shape = stablehlo.constant dense<[3, 2]> : tensor<2xi64>
%result = stablehlo.dynamic_reshape %operand, %output_shape : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
```
}];

let arguments = (ins HLO_Tensor:$operand, HLO_StaticDimensionTensor:$output_shape);
let results = (outs HLO_Tensor:$result);
let arguments = (ins
HLO_TensorOrPerAxisQuantizedTensor:$operand /*dynamic_reshape_i1*/,
HLO_StaticDimensionTensor:$output_shape /*dynamic_reshape_i2*/
);
let results = (outs HLO_TensorOrPerAxisQuantizedTensor:$result);

let hasVerifier = 1;

140 changes: 88 additions & 52 deletions stablehlo/dialect/TypeInference.cpp
@@ -255,6 +255,61 @@ LogicalResult verifyQPerAxisScaleAndZeroPointConstraints(
return success();
}

LogicalResult verifyReshapeOpQuantizationConstraints(
std::optional<Location> location, Type operandTy, Type resultTy) {
// dynamic_reshape_c1, reshape_c1
if (failed(verifyQPerTensorScaleAndZeroPointConstraints(location, operandTy,
resultTy)))
return failure();

// dynamic_reshape_c1, reshape_c1
if (failed(verifyQPerAxisScaleAndZeroPointConstraints(location, operandTy,
resultTy)))
return failure();

// dynamic_reshape_c3, reshape_c3
if (allQuantized<quant::UniformQuantizedPerAxisType>(operandTy, resultTy)) {
auto operandQDim = cast<quant::UniformQuantizedPerAxisType>(
getElementTypeOrSelf(operandTy))
.getQuantizedDimension();
auto resultQDim =
cast<quant::UniformQuantizedPerAxisType>(getElementTypeOrSelf(resultTy))
.getQuantizedDimension();

auto operandShapeTy = cast<ShapedType>(operandTy);
auto resultShapeTy = cast<ShapedType>(resultTy);
if (!operandShapeTy.isDynamicDim(operandQDim) &&
!resultShapeTy.isDynamicDim(resultQDim) &&
operandShapeTy.getDimSize(operandQDim) !=
resultShapeTy.getDimSize(resultQDim)) {
return emitOptionalError(
location,
"expect same quantization dimension size for operand and result ",
operandTy, " and ", resultTy);
}

if (operandShapeTy.hasStaticShape() && resultShapeTy.hasStaticShape()) {
uint64_t operandProd = 1;
std::for_each(
operandShapeTy.getShape().begin(),
operandShapeTy.getShape().begin() + operandQDim,
[&operandProd](int32_t dimSize) { operandProd *= dimSize; });
uint64_t resultProd = 1;
std::for_each(resultShapeTy.getShape().begin(),
resultShapeTy.getShape().begin() + resultQDim,
[&resultProd](int32_t dimSize) { resultProd *= dimSize; });
if (operandProd != resultProd)
return emitOptionalError(
location,
"product of dimensions before quantization dimension must match "
"between operand and result for ",
operandProd, " and ", resultProd);
}
}

return success();
}
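Outside the verifier, the per-axis constraint this function enforces (dynamic_reshape_c3 / reshape_c3) amounts to three product checks over static shapes. A minimal Python sketch, with a hypothetical helper name:

```python
import math

def check_per_axis_reshape(operand_shape, operand_qdim, result_shape, result_qdim):
    """Sketch of C3: the quantization dimension keeps its size, and the
    element counts before and after it are preserved across the reshape."""
    if operand_shape[operand_qdim] != result_shape[result_qdim]:
        raise ValueError("quantization dimension sizes must match")
    if math.prod(operand_shape[:operand_qdim]) != math.prod(result_shape[:result_qdim]):
        raise ValueError("products of dims before the quantization dimension must match")
    if math.prod(operand_shape[operand_qdim + 1:]) != math.prod(result_shape[result_qdim + 1:]):
        raise ValueError("products of dims after the quantization dimension must match")

# e.g. a 1x2x3x4 tensor with qdim 2 reshaped to 2x3x4 with qdim 1
check_per_axis_reshape((1, 2, 3, 4), 2, (2, 3, 4), 1)
```

The C++ verifier above skips these checks whenever the relevant dimensions are dynamic; this sketch assumes fully static shapes.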

} // namespace

//===----------------------------------------------------------------------===//
@@ -3850,14 +3905,40 @@ LogicalResult verifyDynamicPadOp(std::optional<Location> location,
LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
Value operand, Value outputShape,
Value result) {
// dynamic_reshape_c1
if (!anyQuantized<quant::QuantizedType>(
{operand.getType(), result.getType()}) &&
!isCompatibleElementTypeForHloTypeInference(operand.getType(),
result.getType()))
return emitOptionalError(
location,
"expects operand and result to have compatible element type. Got: ",
operand.getType(), " and ", result.getType());

// dynamic_reshape_c2
auto resultType = cast<ShapedType>(result.getType());
auto operandType = cast<ShapedType>(operand.getType());
if (resultType.hasStaticShape() && operandType.hasStaticShape()) {
int64_t numResultElements = resultType.getNumElements();
int64_t numOperandElements = operandType.getNumElements();
if (numResultElements != numOperandElements)
return emitOptionalError(location, "number of output elements (",
numResultElements,
") doesn't match expected number of elements (",
numOperandElements, ")");
}

// dynamic_reshape_c4
if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
resultType)))
return failure();

auto outputShapeType = cast<ShapedType>(outputShape.getType());
if (outputShapeType.getDimSize(0) != resultType.getRank())
return emitOptionalError(location,
"output should have a rank equal to the number of "
"result should have a rank equal to the number of "
"elements in output_shape");

auto operandType = cast<RankedTensorType>(operand.getType());
if (SmallVector<int64_t> shape; operandType.hasStaticShape() &&
matchInts(outputShape, shape).succeeded()) {
int64_t operandCount = operandType.getNumElements();
@@ -3872,9 +3953,11 @@ LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
}
}

if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
resultType)))
return failure();
// dynamic_reshape_c1, dynamic_reshape_c3
if (anyQuantized<quant::QuantizedType>(operand.getType(), result.getType()))
return verifyReshapeOpQuantizationConstraints(location, operand.getType(),
result.getType());

return success();
}

@@ -4171,53 +4254,6 @@ LogicalResult verifyReduceWindowOp(
return success();
}

LogicalResult verifyReshapeOpQuantizationConstraints(
std::optional<Location> location, Type operandTy, Type resultTy) {
// reshape_c1
if (failed(verifyQPerTensorScaleAndZeroPointConstraints(location, operandTy,
resultTy)))
return failure();

// reshape_c1
if (failed(verifyQPerAxisScaleAndZeroPointConstraints(location, operandTy,
resultTy)))
return failure();

// reshape_c3
if (allQuantized<quant::UniformQuantizedPerAxisType>(operandTy, resultTy)) {
auto operandQDim = cast<quant::UniformQuantizedPerAxisType>(
getElementTypeOrSelf(operandTy))
.getQuantizedDimension();
auto resultQDim =
cast<quant::UniformQuantizedPerAxisType>(getElementTypeOrSelf(resultTy))
.getQuantizedDimension();
auto operandShape = cast<ShapedType>(operandTy).getShape();
auto resultShape = cast<ShapedType>(resultTy).getShape();

if (cast<ShapedType>(operandTy).getDimSize(operandQDim) !=
cast<ShapedType>(resultTy).getDimSize(resultQDim))
return emitOptionalError(
location,
"expect same quantization dimension size for operand and result ",
operandTy, " and ", resultTy);

uint64_t operandProd = 1;
std::for_each(operandShape.begin(), operandShape.begin() + operandQDim,
[&operandProd](int32_t dimSize) { operandProd *= dimSize; });
uint64_t resultProd = 1;
std::for_each(resultShape.begin(), resultShape.begin() + resultQDim,
[&resultProd](int32_t dimSize) { resultProd *= dimSize; });
if (operandProd != resultProd)
return emitOptionalError(
location,
"product of dimensions before quantization dimension must match "
"between operand and result for ",
operandProd, " and ", resultProd);
}

return success();
}

LogicalResult verifyReshapeOp(std::optional<Location> location, Value operand,
Value result) {
// If the operand type is dynamically shaped there is nothing to verify.
13 changes: 13 additions & 0 deletions stablehlo/reference/Ops.cpp
@@ -584,6 +584,11 @@ SmallVector<InterpreterValue> eval(Region &region,
auto outputShape = scope.findTensor(op.getOutputShape());
auto result = dynamicIotaOp(iotaDimension, outputShape, op.getType());
scope.add(op.getResult(), result);
} else if (auto op = dyn_cast<DynamicReshapeOp>(operation)) {
auto operand = scope.findTensor(op.getOperand());
auto outputShape = scope.findTensor(op.getOutputShape());
auto result = dynamicReshapeOp(operand, outputShape, op.getType());
scope.add(op.getResult(), result);
} else if (auto op = dyn_cast<DynamicSliceOp>(operation)) {
auto operand = scope.findTensor(op.getOperand());
auto startIndices = scope.findTensors(op.getStartIndices());
@@ -1589,6 +1594,14 @@ Tensor dynamicIotaOp(Axis iotaDimension, const Tensor &outputShape,
"dynamic result types are not supported at the moment");
}

Tensor dynamicReshapeOp(const Tensor &operand, const Tensor &outputShape,
ShapedType resultType) {
if (resultType.hasStaticShape()) return reshapeOp(operand, resultType);

llvm::report_fatal_error(
"dynamic result types are not supported at the moment");
}
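The interpreter's partial support above (delegate to an ordinary reshape when the result shape is static, otherwise fail) can be sketched in Python; here `None` models a dynamic dimension and both function names are hypothetical:

```python
def reshape(flat, shape):
    """Regroup a flat, row-major element list into nested lists of `shape`."""
    if len(shape) <= 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [reshape(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]

def dynamic_reshape_op(operand_flat, result_shape):
    # Mirrors the partial implementation: only static result shapes work.
    if any(dim is None for dim in result_shape):
        raise NotImplementedError(
            "dynamic result types are not supported at the moment")
    return reshape(operand_flat, result_shape)

print(dynamic_reshape_op([1, 2, 3, 4, 5, 6], (6,)))  # [1, 2, 3, 4, 5, 6]
```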

Tensor dynamicSliceOp(const Tensor &operand, ArrayRef<Tensor> startIndices,
const Sizes &sliceSizes, ShapedType resultType) {
Tensor result(resultType);
2 changes: 2 additions & 0 deletions stablehlo/reference/Ops.h
@@ -93,6 +93,8 @@ Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs,
ShapedType resultType);
Tensor dynamicIotaOp(Axis iotaDimension, const Tensor &outputShape,
ShapedType resultType);
Tensor dynamicReshapeOp(const Tensor &operand, const Tensor &outputShape,
ShapedType resultType);
Tensor dynamicSliceOp(const Tensor &operand, ArrayRef<Tensor> startIndices,
const Sizes &sliceSizes, ShapedType resultType);
Tensor dynamicUpdateSliceOp(const Tensor &operand, const Tensor &update,
9 changes: 9 additions & 0 deletions stablehlo/tests/interpret/dynamic_reshape.mlir
@@ -0,0 +1,9 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

func.func @dynamic_reshape_op_test_si64() {
%operand = stablehlo.constant dense<[[1, 2, 3, 4, 5, 6]]> : tensor<1x6xi64>
%output_shape = stablehlo.constant dense<[6]> : tensor<1xi64>
%result = stablehlo.dynamic_reshape %operand, %output_shape : (tensor<1x6xi64>, tensor<1xi64>) -> tensor<6xi64>
check.expect_eq_const %result, dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
func.return
}
18 changes: 17 additions & 1 deletion stablehlo/tests/ops_stablehlo.mlir
@@ -3181,8 +3181,24 @@ func.func @dynamic_reshape(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> ten

// -----

func.func @dynamic_reshape_c1(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf64> {
// expected-error @+1 {{expects operand and result to have compatible element type}}
%0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf64>
func.return %0 : tensor<?x?xf64>
}

// -----

func.func @dynamic_reshape_c2(%arg0: tensor<11xf32>, %shape: tensor<2xindex>) -> tensor<2x5xf32> {
// expected-error @+1 {{number of output elements (10) doesn't match expected number of elements}}
%0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor<11xf32>, tensor<2xindex>) -> tensor<2x5xf32>
func.return %0 : tensor<2x5xf32>
}

// -----

func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?xf32> {
// expected-error @+1 {{output should have a rank equal to the number of elements in output_shape}}
// expected-error @+1 {{result should have a rank equal to the number of elements in output_shape}}
%0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?xf32>
func.return %0 : tensor<?xf32>
}