diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh
index 17512de87a45..2f8d10beb4be 100755
--- a/build_tools/python_deploy/build_linux_packages.sh
+++ b/build_tools/python_deploy/build_linux_packages.sh
@@ -434,6 +434,8 @@ function clean_build() {
 }
 
 function build_torch_mlir() {
+  # Disable LTC build for releases to avoid linker issues
+  export TORCH_MLIR_ENABLE_LTC=0
   local torch_version="$1"
   case $torch_version in
     nightly)
diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
index 261b4df3bd09..77d94eb0f8b9 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -191,6 +191,25 @@ struct OpBinder {
     return failure();
   }
 
+  ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
+                              StringRef nameSuffix) {
+    SmallString<64> name("torch.onnx.");
+    name.append(nameSuffix);
+    auto attr = op->getAttr(name);
+    if (!attr)
+      return success();
+    if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+      for (auto element : arrayAttr) {
+        StringAttr stringAttr = element.dyn_cast<StringAttr>();
+        if (!stringAttr)
+          return failure();
+        values.push_back(stringAttr.getValue().str());
+      }
+      return success();
+    }
+    return failure();
+  }
+
   ParseResult denseElementsAttr(ElementsAttr elementsattr,
                                 StringRef nameSuffix) {
     SmallString<64> name("torch.onnx.");
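For context on what `stringArrayAttr` consumes: the ONNX importer stores node attributes under a `torch.onnx.` prefix, so a string-array attribute such as LSTM's `activations` arrives as `torch.onnx.activations`. A minimal sketch of an ONNX node carrying such an attribute, using the standard `onnx.helper` API (illustrative only, not part of this patch):

```python
import onnx
from onnx import helper

# An LSTM node whose "activations" attribute is a list of strings. After
# import, binder.stringArrayAttr(values, "activations") would yield
# ["Sigmoid", "Tanh", "Tanh"].
node = helper.make_node(
    "LSTM",
    inputs=["X", "W", "R"],
    outputs=["Y", "Y_h", "Y_c"],
    hidden_size=3,
    activations=["Sigmoid", "Tanh", "Tanh"],
)
print(node)
```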
Torch_Op<"aten.__or__.Tensor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt index 4a5015816609..ef3e51d45288 100644 --- a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt +++ b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch DefaultDomainAtoF.cpp DefaultDomainGtoP.cpp DefaultDomainQtoZ.cpp + OnnxLstmExpander.cpp Passes.cpp Patterns.cpp TorchOnnxToTorch.cpp diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 5d4e693d0a15..f661f0e02ebd 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1951,4 +1951,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); return success(); }); + patterns.onOp( + "Einsum", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + SmallVector tensors; + std::string equation; + if (binder.tensorOperands(tensors, binder.op->getNumOperands()) || + binder.customOpNameStringAttr(equation, "equation") || + binder.tensorResultType(resultType)) + return failure(); + Type listElemType = + tensors[0] + .getType() + .cast() + .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, + /*optionalDtype=*/nullptr); + Type listType = Torch::ListType::get(listElemType); + Value tensorList = rewriter.create( + binder.op->getLoc(), listType, tensors); + Value cstEquation = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getStringAttr(equation)); + Value cstNone = rewriter.create(binder.getLoc()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone); + return success(); + }); } diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 752fe97aded9..fb7332e91d37 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -195,6 +195,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, operand); return success(); }); + patterns.onOp("LSTM", 1, onnx_c::OnnxLstmExpander); patterns.onOp( "LogSoftmax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index f8c448e01a80..4a3ca533d242 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -538,7 +538,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); patterns.onOp( - "Squeeze", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Squeeze", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data; Value axes; diff --git a/lib/Conversion/TorchOnnxToTorch/OnnxLstmExpander.cpp b/lib/Conversion/TorchOnnxToTorch/OnnxLstmExpander.cpp new file mode 100644 index 000000000000..4c2ad051e0be --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/OnnxLstmExpander.cpp @@ -0,0 +1,514 @@ +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include 
"torch-mlir/Dialect/Torch/Utils/Utils.h" + +using namespace mlir; +using namespace mlir::torch::Torch; +namespace mlir::torch::onnx_c { + +Value createActivationByName(ImplicitLocOpBuilder &b, StringRef name, + Value input) { + if (name == "Sigmoid") + return b.create(input.getType(), input); + if (name == "Tanh") + return b.create(input.getType(), input); + if (name == "Relu") + return b.create(input.getType(), input); + llvm_unreachable("Unsupported activation function"); +} + +// @struct LstmWeights +// @brief A structure to hold LSTM weights. +// +// Each W_ weight matrix should have shape [hidden_size, input_size]. +// Each R_ weight matrix should have shape [hidden_size, hidden_size]. +// Each bias vector should have shape [4 * hidden_size]. +struct LstmWeights { + Value W_i, W_o, W_f, W_c; + Value R_i, R_o, R_f, R_c; + Value Wb_i, Wb_o, Wb_f, Wb_c; + Value Rb_i, Rb_o, Rb_f, Rb_c; +}; +struct LstmActivations { + std::string f; + std::string g; + std::string h; +}; + +struct LstmCellState { + Value H; + Value C; +}; +// This function represents a Long Short-Term Memory (LSTM) cell operation. +// +// @param b A builder for constructing operations. +// @param Xt The input sequence. It has a shape of [batch_size, input_size]. +// @param H_prev The previous hidden state. It has a shape of [batch_size, +// hidden_size]. +// @param C_prev The previous cell state. It has a shape of [batch_size, +// hidden_size]. +// @param weights The weights for the LSTM cell. See @ref LstmWeights for shapes +// @param activations The activation functions for the LSTM cell. Members f,g,h +// correspond to f,g,h in https://onnx.ai/onnx/operators/onnx__LSTM.html +// @return The state of the LSTM cell after the operation. +LstmCellState lstm_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev, + Value C_prev, LstmWeights weights, + LstmActivations activations) { + + auto intType = b.getType(); + auto hTy = cast(H_prev.getType()); + + Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + + // Apply linear/matmul for each gate separately + // names are consistent with ONNX LSTM documentation + Value i_x = b.create(hTy, Xt, weights.W_i, weights.Wb_i); + Value i_h = b.create(hTy, H_prev, weights.R_i, weights.Rb_i); + Value i = b.create(hTy, i_x, i_h, cstOne); + Value i_act = createActivationByName(b, activations.f, i); + + Value o_x = b.create(hTy, Xt, weights.W_o, weights.Wb_o); + Value o_h = b.create(hTy, H_prev, weights.R_o, weights.Rb_o); + Value o = b.create(hTy, o_x, o_h, cstOne); + Value o_act = createActivationByName(b, activations.f, o); + + Value f_x = b.create(hTy, Xt, weights.W_f, weights.Wb_f); + Value f_h = b.create(hTy, H_prev, weights.R_f, weights.Rb_f); + Value f = b.create(hTy, f_x, f_h, cstOne); + Value f_act = createActivationByName(b, activations.f, f); + + Value ct_x = b.create(hTy, Xt, weights.W_c, weights.Wb_c); + Value ct_h = b.create(hTy, H_prev, weights.R_c, weights.Rb_c); + Value ct = b.create(hTy, ct_x, ct_h, cstOne); + Value ct_act = createActivationByName(b, activations.g, ct); + + Value C_forget = b.create(hTy, f_act, C_prev); + Value C_input = b.create(hTy, i_act, ct_act); + + LstmCellState newCellState; + newCellState.C = b.create(hTy, C_forget, C_input, cstOne); + Value C_new_act = createActivationByName(b, activations.h, newCellState.C); + newCellState.H = b.create(hTy, o_act, C_new_act); + return newCellState; +} + +struct LstmLayerOutput { + Value Y; + Value Y_h; + Value Y_c; +}; + +// @brief This function implements the LSTM (Long Short-Term Memory) layer +// 
+
+struct LstmLayerOutput {
+  Value Y;
+  Value Y_h;
+  Value Y_c;
+};
+
+// @brief This function implements the LSTM (Long Short-Term Memory) layer
+// operation.
+//
+// The core computation is performed in a loop that iterates over the sequence
+// length. In each iteration, it selects the corresponding input, computes the
+// new hidden state and cell state using the lstm_cell function, and updates
+// the output tensor.
+//
+// @return A struct containing the hidden state history, final hidden state,
+// and final cell state.
+LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
+                           Value initial_c, LstmWeights weights,
+                           LstmActivations activations) {
+
+  Location loc = b.getLoc();
+
+  auto xTy = cast<ValueTensorType>(X.getType());
+  auto hTy = cast<ValueTensorType>(initial_h.getType());
+  // these names are snake_case for consistency with the onnx.LSTM
+  // documentation
+  int64_t seq_len = xTy.getSizes()[0];
+  int64_t batch_size = xTy.getSizes()[1];
+  int64_t input_size = xTy.getSizes()[2];
+  int64_t hidden_size = hTy.getSizes()[1];
+
+  auto cTy = hTy;
+
+  auto intType = b.getType<IntType>();
+
+  Value cstNone = b.create<ConstantNoneOp>();
+  Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
+  Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
+  Value cstSeqLen =
+      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(seq_len));
+  Value cstBatchSize =
+      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(batch_size));
+  Value cstHiddenSize =
+      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(hidden_size));
+
+  auto yTy = b.getType<ValueTensorType>(
+      SmallVector<int64_t>{seq_len, batch_size, hidden_size}, hTy.getDtype());
+
+  auto YShapeList = b.create<PrimListConstructOp>(
+      b.getType<ListType>(intType),
+      ValueRange({cstSeqLen, cstBatchSize, cstHiddenSize}));
+
+  int64_t hDtypeInt =
+      static_cast<int64_t>(getScalarTypeForType(hTy.getDtype()));
+  Value hDtypeIntVal =
+      b.create<ConstantIntOp>(loc, b.getI64IntegerAttr(hDtypeInt));
+
+  Value Y_initial = b.create<AtenZerosOp>(yTy, YShapeList, hDtypeIntVal,
+                                          cstNone, cstNone, cstNone);
+
+  // Create a for-like PrimLoopOp.
+  Value maxTripCount =
+      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(seq_len));
+  Value loopConditionTrue = b.create<ConstantBoolOp>(true);
+
+  Type loopIndexType = intType;
+  auto loop = b.create<PrimLoopOp>(
+      TypeRange({yTy, hTy, cTy}), maxTripCount, loopConditionTrue,
+      ValueRange({Y_initial, initial_h, initial_c}));
+  {
+    OpBuilder::InsertionGuard guard(b);
+    Block *loopBody =
+        b.createBlock(&loop.getRegion(), loop.getRegion().begin(),
+                      TypeRange({
+                          loopIndexType,
+                          yTy,
+                          hTy,
+                          cTy,
+                      }),
+                      {loc, loc, loc, loc} // locs for the loop body arguments
+        );
+
+    Value loopIndex = loopBody->getArgument(0);
+    Value Y_prev = loopBody->getArgument(1);
+    Value H_prev = loopBody->getArgument(2);
+    Value C_prev = loopBody->getArgument(3);
+
+    auto xTy = cast<ValueTensorType>(X.getType());
+    auto XtType = b.getType<ValueTensorType>(
+        llvm::SmallVector<int64_t>{batch_size, input_size}, xTy.getDtype());
+
+    Value Xt = b.create<AtenSelectIntOp>(XtType, X, cstZero, loopIndex);
+
+    auto [H_new, C_new] =
+        lstm_cell(b, Xt, H_prev, C_prev, weights, activations);
+
+    Type hTyUnsqueezed = b.getType<ValueTensorType>(
+        llvm::SmallVector<int64_t>{1, batch_size, hidden_size}, hTy.getDtype());
+    Value H_new_unsqueezed =
+        b.create<AtenUnsqueezeOp>(hTyUnsqueezed, H_new, cstZero);
+
+    auto loopIndexPlusOne = b.create<AtenAddIntOp>(intType, loopIndex, cstOne);
+    Value Y_new =
+        b.create<AtenSliceScatterOp>(yTy, Y_prev, H_new_unsqueezed, cstZero,
+                                     loopIndex, loopIndexPlusOne, cstOne);
+
+    b.create<PrimLoopConditionOp>(loopConditionTrue,
+                                  ValueRange({Y_new, H_new, C_new}));
+  }
+  LstmLayerOutput output;
+  output.Y = loop.getResult(0);
+  output.Y_h = loop.getResult(1);
+  output.Y_c = loop.getResult(2);
+  return output;
+}
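The `torch.prim.Loop` built above is the IR counterpart of this scan (a sketch; `cell` stands in for a closure over the weights): each iteration selects `X[t]` (AtenSelectIntOp), runs one cell step, and writes the new hidden state into row `t` of `Y` (AtenSliceScatterOp):

```python
import torch

def lstm_layer_ref(X, H0, C0, cell):
    seq_len, batch_size, _ = X.shape
    hidden_size = H0.shape[1]
    Y = torch.zeros(seq_len, batch_size, hidden_size, dtype=H0.dtype)
    H, C = H0, C0
    for t in range(seq_len):
        H, C = cell(X[t], H, C)  # one lstm_cell step
        Y[t] = H                 # slice-scatter of the hidden state
    return Y, H, C               # Y, Y_h, Y_c
```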
+// @brief Expands an ONNX LSTM operation into torch ops.
+//
+// This function primarily handles the binding of operands and slicing of the
+// weight matrix. The majority of the lowering is managed in lstm_layer and
+// lstm_cell. For the shapes and meanings of the inputs, refer to the ONNX
+// LSTM documentation at:
+// https://onnx.ai/onnx/operators/onnx__LSTM.html
+// The variable names are also consistent with the aforementioned
+// documentation.
+//
+// This is not e2e tested here but is verified to work numerically downstream
+// in SHARK-TestSuite.
+//
+// TODO: include this test case when the test infrastructure stops
+// initializing weights separately for the reference and tested layers.
+// @code{.py}
+// class LSTMModule(torch.nn.Module):
+//     def __init__(self):
+//         super().__init__()
+//         self.lstm = torch.nn.LSTM(10, 20, 1)
+//     @export
+//     @annotate_args([
+//         None,
+//         ([5, 1, 10], torch.float32, True),
+//         ([1, 1, 20], torch.float32, True),
+//         ([1, 1, 20], torch.float32, True),
+//     ])
+//     def forward(self, input, h0, c0):
+//         return self.lstm(input, (h0, c0))
+//
+// @register_test_case(module_factory=LSTMModule)
+// def LSTMModule_basic(module, tu: TestUtils):
+//     inputs = torch.zeros(5, 1, 10)
+//     h0 = torch.zeros(1, 1, 20)
+//     c0 = torch.zeros(1, 1, 20)
+//
+//     output, (hn, cn) = module.forward(inputs, h0, c0)
+// @endcode
+//
+// @param binder The OpBinder object used for binding operands.
+LogicalResult OnnxLstmExpander(OpBinder binder,
+                               ConversionPatternRewriter &rewriter) {
+  Location loc = binder.getLoc();
+  mlir::ImplicitLocOpBuilder b(loc, rewriter);
+
+  std::string direction;
+
+  ValueTensorType yTy, Y_hType, Y_cType;
+  if (binder.tensorResultTypeAtIndex(yTy, 0) ||
+      binder.tensorResultTypeAtIndex(Y_hType, 1) ||
+      binder.tensorResultTypeAtIndex(Y_cType, 2)) {
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "At least one output must be present");
+  }
+  Value X;
+  if (binder.tensorOperandAtIndex(X, 0))
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "Missing required input tensor X");
+  Value W;
+  if (binder.tensorOperandAtIndex(W, 1))
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "Missing required input tensor W");
+  Value R;
+  if (binder.tensorOperandAtIndex(R, 2))
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "Missing required input tensor R");
+  int64_t hidden_size;
+  if (binder.s64IntegerAttr(hidden_size, "hidden_size"))
+    return rewriter.notifyMatchFailure(
+        binder.op, "Missing required attribute hidden_size");
+
+  auto xTy = cast<ValueTensorType>(X.getType());
+  auto wTy = cast<ValueTensorType>(W.getType());
+  Value B;
+  if (binder.tensorOperandAtIndex(B, 3)) {
+    B = b.create<AtenZeroOp>(W.getType(), W);
+  }
+
+  llvm::SmallVector<std::string> activationsList;
+  if (binder.stringArrayAttr(activationsList, "activations"))
+    return rewriter.notifyMatchFailure(
+        binder.op, "Missing required attribute: activations");
+
+  LstmActivations activations;
+  activations.f = "Sigmoid";
+  activations.g = "Tanh";
+  activations.h = "Tanh";
+  if (activationsList.size() == 3) {
+    activations.f = activationsList[0];
+    activations.g = activationsList[1];
+    activations.h = activationsList[2];
+  } else if (activationsList.size() != 0) {
+    return rewriter.notifyMatchFailure(
+        binder.op, "activations must be empty or have 3 elements, but " +
+                       std::to_string(activationsList.size()) +
+                       " are provided.");
+  }
" + "Only 'forward' is supported but '" + + direction + "' is provided."); + int64_t num_directions = 1 + (direction == "bidirectional"); + + auto XShape = xTy.getSizes(); + int64_t batch_size = XShape[1]; + int64_t input_size = XShape[2]; + if (num_directions != wTy.getSizes()[0]) + return rewriter.notifyMatchFailure( + binder.op, "num_directions (" + std::to_string(num_directions) + + ") does not match the first dimension of wTy (" + + std::to_string(wTy.getSizes()[0]) + ")"); + if (num_directions != 1) + return rewriter.notifyMatchFailure( + binder.op, "num_directions (" + std::to_string(num_directions) + + ") is not equal to 1"); + if (4 * hidden_size != wTy.getSizes()[1]) + return rewriter.notifyMatchFailure( + binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) + + ") does not match the second dimension of wTy (" + + std::to_string(wTy.getSizes()[1]) + ")"); + if (wTy.getSizes()[2] != input_size) + return rewriter.notifyMatchFailure( + binder.op, + "The third dimension of wTy (" + std::to_string(wTy.getSizes()[2]) + + ") does not match input_size (" + std::to_string(input_size) + ")"); + + /** + * @brief Splits the input tensor based on the provided direction. + * + * This function is used to split the LSTM parameters (W, R, B) into forward + * and backward directions. The input tensor is expected to have the forward + * and backward parameters concatenated along the 0th dimension. The function + * returns a tensor that contains the parameters for the specified direction. + * + * @param direction The direction to split out. 0 for forward, 1 for backward. + * @param input The input tensor to split. + * @return The split tensor for the specified direction. + */ + auto getDirection = [&](int64_t direction, Value input) { + auto inputType = cast(input.getType()); + + // drop 0th dimension + auto outputType = cast(inputType.getWithSizesAndDtype( + llvm::SmallVector{inputType.getSizes().drop_front()}, + inputType.getDtype())); + + auto intType = b.getType(); + Value selectDim = b.create(intType, b.getI64IntegerAttr(0)); + Value cstDirection = + b.create(intType, b.getI64IntegerAttr(direction)); + return b.create(outputType, input, selectDim, + cstDirection); + }; + + Value W_forward = getDirection(0, W); + Value R_forward = getDirection(0, R); + Value B_forward = getDirection(0, B); + + auto hTy = b.getType( + llvm::SmallVector{num_directions, batch_size, hidden_size}, + xTy.getDtype()); + + auto intType = b.getType(); + + Value cstNumDirections = + b.create(intType, b.getI64IntegerAttr(num_directions)); + Value cstBatchSize = + b.create(intType, b.getI64IntegerAttr(batch_size)); + Value cstHiddenSize = + b.create(intType, b.getI64IntegerAttr(hidden_size)); + Value cstNone = b.create(); + Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); + Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + + Value hShape = b.create( + b.getType(intType), + ValueRange({cstNumDirections, cstBatchSize, cstHiddenSize})); + + Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype()); + + Value initial_h; + if (binder.tensorOperandAtIndex(initial_h, 5)) { + initial_h = + b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); + } + Value initial_c; + if (binder.tensorOperandAtIndex(initial_c, 6)) { + initial_c = + b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); + } + + Value initial_h_forward = getDirection(0, initial_h); + Value initial_c_forward = getDirection(0, initial_c); + + if (num_directions != 1) { + return rewriter.notifyMatchFailure( + 
binder.op, "Unsupported num_directions. Only 1 is supported but " + + std::to_string(num_directions) + " is provided."); + // TODO: support bidirectional LSTM by doing both directions and replacing + // Unsqueeze with Stack + } + // Everything hereon is for the forward direction, with the direction + // dimention squeezed out. + + LstmWeights weights; // weights and biases + + auto intConst = [&](int64_t val) { + return b.create(intType, b.getI64IntegerAttr(val)); + }; + + // split B into Wb and Rb + Value inputWeightsEndIdx = intConst(4 * hidden_size); + Value recurrentWeightsStartIdx = inputWeightsEndIdx; + Value recurrentWeightsEndIdx = intConst(8 * hidden_size); + auto biasType = b.getType( + llvm::SmallVector{hidden_size * 4}, wTy.getDtype()); + Value Wb = b.create(biasType, + /*input=*/B_forward, + /*dim=*/cstZero, + /*start=*/cstZero, + /*end=*/inputWeightsEndIdx, + /*step=*/cstOne); + Value Rb = b.create(biasType, + /*input=*/B_forward, + /*dim=*/cstZero, + /*start=*/recurrentWeightsStartIdx, + /*end=*/recurrentWeightsEndIdx, + /*step=*/cstOne); + + // gate splitting + auto gateBiasType = b.getType( + llvm::SmallVector{hidden_size}, + cast(Wb.getType()).getDtype()); + auto gateWeightsTypeIH = b.getType( + llvm::SmallVector{hidden_size, input_size}, + cast(W_forward.getType()).getDtype()); + auto gateWeightsTypeHH = b.getType( + llvm::SmallVector{hidden_size, hidden_size}, + cast(R_forward.getType()).getDtype()); + + Value inputGateWeightsEndIdx = intConst(hidden_size); + Value outputGateWeightsEndIdx = intConst(2 * hidden_size); + Value forgetGateWeightsEndIdx = intConst(3 * hidden_size); + Value cellGateWeightsEndIdx = intConst(4 * hidden_size); + + auto sliceIOFC = [&](std::function slicerFunction) { + // slice into 4 components and return tuple + return std::make_tuple( + slicerFunction(cstZero, inputGateWeightsEndIdx), + slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx), + slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx), + slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx)); + }; + + auto sliceGateBias = [&](Value startIdx, Value endIdx) { + return b.create(gateBiasType, Wb, cstZero, startIdx, + endIdx, cstOne); + }; + std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) = + sliceIOFC(sliceGateBias); + + auto sliceGateBiasR = [&](Value startIdx, Value endIdx) { + return b.create(gateBiasType, Rb, cstZero, startIdx, + endIdx, cstOne); + }; + std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) = + sliceIOFC(sliceGateBiasR); + + auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx) { + return b.create(gateWeightsTypeIH, W_forward, cstZero, + startIdx, endIdx, cstOne); + }; + std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) = + sliceIOFC(sliceGateWeightsIH); + + auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx) { + return b.create(gateWeightsTypeHH, R_forward, cstZero, + startIdx, endIdx, cstOne); + }; + std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) = + sliceIOFC(sliceGateWeightsHH); + LstmLayerOutput lstmLayerOutput = lstm_layer( + b, X, initial_h_forward, initial_c_forward, weights, activations); + + auto Y_h_Y_c_unsqueezed_type = b.getType( + llvm::SmallVector{num_directions, batch_size, hidden_size}, + cast(lstmLayerOutput.Y_h.getType()).getDtype()); + Value Y_h_unsqueezed = b.create( + Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_h, cstZero); + Value Y_c_unsqueezed = b.create( + Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_c, cstZero); + + // unsqueeze num_directions dim1 
+
+  LstmLayerOutput lstmLayerOutput = lstm_layer(
+      b, X, initial_h_forward, initial_c_forward, weights, activations);
+
+  auto Y_h_Y_c_unsqueezed_type = b.getType<ValueTensorType>(
+      llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
+      cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype());
+  Value Y_h_unsqueezed = b.create<AtenUnsqueezeOp>(
+      Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_h, cstZero);
+  Value Y_c_unsqueezed = b.create<AtenUnsqueezeOp>(
+      Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_c, cstZero);
+
+  // unsqueeze a num_directions dim at position 1 of Y to create the
+  // onnx.LSTM output shape [seq_length, num_directions, batch_size,
+  // hidden_size]
+  Value Y_unsqueezed =
+      b.create<AtenUnsqueezeOp>(yTy, lstmLayerOutput.Y, cstOne);
+
+  rewriter.replaceOp(binder.op, mlir::ValueRange{Y_unsqueezed, Y_h_unsqueezed,
+                                                 Y_c_unsqueezed});
+  return success();
+}
+} // namespace mlir::torch::onnx_c
diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index 1a2a6cfc32dc..3a6c5396b3f8 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -577,13 +577,24 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern<AtenOpT> {
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Value lhs = adaptor.getSelf();
+    Value rhs = adaptor.getOther();
+
+    RankedTensorType lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
+
+    if (!lhsTy)
+      return op.emitError("lhs must be a ranked tensor type");
+
     TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
                              ->convertType(op.getType())
                              .template cast<TensorType>();
-    Value lhs =
-        hlo::promoteType(rewriter, op.getLoc(), adaptor.getSelf(), outType);
-    Value rhs =
-        hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType);
+    Type outElemTy = outType.getElementType();
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
+    if (!rhsTy) {
+      rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
+    }
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
 
     DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
@@ -1861,6 +1872,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalOrOp, chlo::BroadcastOrOp);
   INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalAndOp, chlo::BroadcastAndOp);
   INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalXorOp, chlo::BroadcastXorOp);
+  INSERT_BINARY_LOGICAL_PATTERN(AtenBitwiseAndScalarOp, chlo::BroadcastAndOp);
+
 #undef INSERT_BINARY_LOGICAL_PATTERN
 
 #define INSERT_ATENOP_PATTERN(AtenOp)                                          \
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 5e6cffaae05e..eab5164a2f6d 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1872,6 +1872,18 @@ void Aten__Or__TensorOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// Aten__And__ScalarOp
+//===----------------------------------------------------------------------===//
+void Aten__And__ScalarOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](Aten__And__ScalarOp op, PatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<AtenBitwiseAndScalarOp>(
+        op, op.getType(), op.getSelf(), op.getOther());
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenScalarImplicitOp
 //===----------------------------------------------------------------------===//
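The canonicalization is sound because `aten.__and__.Scalar` and `aten.bitwise_and.Scalar` agree on tensor inputs; a quick eager-mode check:

```python
import torch

x = torch.randint(-10, 10, (3, 4), dtype=torch.int32)
assert torch.equal(torch.ops.aten.__and__(x, 12), torch.bitwise_and(x, 12))
```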
!torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.__and__.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.remainder.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11126,6 +11130,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.__and__.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.__and__.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 89dd2f70be8c..a6e8234b53b6 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -541,6 +541,7 @@ "ElementwiseNeIntTensorStaticModule_basic", "ElementwiseNegModule_basic", "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseAndScalarStaticShapeModule_basic", "ElementwisePowTensorBroadcastStaticModule_basic", "ElementwisePowTensorStaticModule_basic", "ElementwisePreluStaticModule_basic", @@ -1703,6 +1704,7 @@ } ONNX_XFAIL_SET = { + # Failure - cast error "PermuteNegativeIndexModule_basic", @@ -1826,6 +1828,8 @@ "DivIntModule_basic", "ElementwiseAcoshIntModule_basic", "ElementwiseAcoshModule_basic", + "ElementwiseAndScalarModule_basic", + "ElementwiseAndScalarStaticShapeModule_basic", "ElementwiseAsinhIntModule_basic", "ElementwiseAsinhModule_basic", "ElementwiseAtanhIntModule_basic", @@ -2138,13 +2142,6 @@ # Failure - onnx_lowering: onnx.Clip "NormalizeModule_basic", - # Failure - onnx_lowering: onnx.Einsum - "EinsumStaticContractRhsModule_basic", - "EinsumStaticFourDimensionModule_basic", - "EinsumStaticModule_basic", - "EinsumStaticWithEllipsisSlicingModule_basic", - "EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic", - # Failure - onnx_lowering: onnx.MaxPool "MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", @@ -2354,4 +2351,3 @@ "IndexTensorMultiInputOneDim_basic", "IndexTensorMultiInput_basic", } - diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py 
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 6d95d89681b5..6c163e3a46fa 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -498,6 +498,9 @@ def aten〇div〇Scalar〡shape(self: List[int], other: float) -> List[int]:
 def aten〇remainder〇Scalar〡shape(self: List[int], other: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇__and__〇Scalar〡shape(self: List[int], other: float) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇remainder〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -3123,6 +3126,15 @@ def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[
     self_rank, self_dtype = self_rank_dtype
     return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
+def aten〇__and__〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(other)]
+    return promote_dtypes(ranks, dtypes)
+
 @check_dtype_function(_check_two_tensor_op())
 def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
     other_rank, other_dtype = other_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 1a0a2bbf608a..6e2c756a3886 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -538,6 +538,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")
     emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::__and__.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
     emit("aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
     emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
     emit("aten::mean : (Tensor, int?) -> (Tensor)")
-> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index a83393851d32..5cc418f11c53 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -5293,3 +5293,4 @@ def forward(self, x): @register_test_case(module_factory=lambda: CloneModule()) def CloneModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 5)) + diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 7a2d93450530..4f2849612c4f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -3208,6 +3208,51 @@ def ElementwiseOrTensorStaticShapeModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAndscalarModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + return torch.ops.aten.__and__(x, 12) + + +@register_test_case(module_factory=lambda: ElementwiseAndscalarModule()) +def ElementwiseAndScalarModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-10, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAndScalarStaticShapeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int32, True) + ]) + def forward(self, x): + return torch.ops.aten.__and__(x, 12) + + +@register_test_case(module_factory=lambda: ElementwiseAndScalarStaticShapeModule()) +def ElementwiseAndScalarStaticShapeModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-10, high=10).to(torch.int32)) + +# ============================================================================== + + class ElementwiseBitwiseXorModule(torch.nn.Module): def __init__(self): diff --git a/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir b/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir new file mode 100644 index 000000000000..bb1821088d12 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir @@ -0,0 +1,25 @@ +// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s + + + +// CHECK-LABEL: func.func @test_lstm_basic( +// CHECK-SAME: %[[X:.*]]: !torch.vtensor<[15,2,4],f32>, +// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[1,12,4],f32>, +// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[1,12,3],f32>, +// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[1,24],f32>) +// CHECK: %[[LOOP_RESULT:.*]]:3 = torch.prim.Loop %[[MAX_TRIPS:.*]], %[[ENTER_LOOP:.*]], init(%[[Y:.*]], %[[INITIAL_H:.*]], %[[INITIAL_C:.*]]) { +// CHECK: ^bb0(%[[LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV:.*]]: !torch.vtensor<[15,2,3],f32>, %[[H_PREV:.*]]: !torch.vtensor<[2,3],f32>, %[[C_PREV:.*]]: !torch.vtensor<[2,3],f32>): +// CHECK-DAG: torch.aten.select.int +// CHECK-DAG: torch.aten.linear +// CHECK-DAG: torch.aten.sigmoid +// CHECK-DAG: torch.aten.tanh +// CHECK-DAG: torch.prim.Loop.condition +// CHECK-DAG: } +// CHECK: } +module { + func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, 
diff --git a/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir b/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir
new file mode 100644
index 000000000000..bb1821088d12
--- /dev/null
+++ b/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir
@@ -0,0 +1,25 @@
+// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
+
+
+// CHECK-LABEL: func.func @test_lstm_basic(
+// CHECK-SAME:      %[[X:.*]]: !torch.vtensor<[15,2,4],f32>,
+// CHECK-SAME:      %[[W:.*]]: !torch.vtensor<[1,12,4],f32>,
+// CHECK-SAME:      %[[R:.*]]: !torch.vtensor<[1,12,3],f32>,
+// CHECK-SAME:      %[[B:.*]]: !torch.vtensor<[1,24],f32>)
+// CHECK:           %[[LOOP_RESULT:.*]]:3 = torch.prim.Loop %[[MAX_TRIPS:.*]], %[[ENTER_LOOP:.*]], init(%[[Y:.*]], %[[INITIAL_H:.*]], %[[INITIAL_C:.*]]) {
+// CHECK:           ^bb0(%[[LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV:.*]]: !torch.vtensor<[15,2,3],f32>, %[[H_PREV:.*]]: !torch.vtensor<[2,3],f32>, %[[C_PREV:.*]]: !torch.vtensor<[2,3],f32>):
+// CHECK-DAG:         torch.aten.select.int
+// CHECK-DAG:         torch.aten.linear
+// CHECK-DAG:         torch.aten.sigmoid
+// CHECK-DAG:         torch.aten.tanh
+// CHECK-DAG:         torch.prim.Loop.condition
+// CHECK-DAG:       }
+// CHECK:         }
+
+module {
+  func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
+    %none = torch.constant.none
+    %0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>)
+    return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
+  }
+}
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index fab5d7fb75c0..2aaf2a6977a2 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -1743,3 +1743,51 @@ func.func @test_compress_neg_axis(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.
   %0 = torch.operator "onnx.Compress"(%arg0, %cst) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,2,4],f32>
   return %0 : !torch.vtensor<[2,2,4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_einsum_batch_diagonal
+func.func @test_einsum_batch_diagonal(%arg0: !torch.vtensor<[3,5,5],f64>) -> !torch.vtensor<[3,5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,5,5],f64>) -> !torch.list<vtensor>
+  // CHECK: %[[EQUATION:.*]] = torch.constant.str "...ii ->...i"
+  // CHECK: %[[PATH:.*]] = torch.constant.none
+  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[3,5],f64>
+  %0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "...ii ->...i"} : (!torch.vtensor<[3,5,5],f64>) -> !torch.vtensor<[3,5],f64>
+  return %0 : !torch.vtensor<[3,5],f64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_einsum_batch_matmul
+func.func @test_einsum_batch_matmul(%arg0: !torch.vtensor<[5,2,3],f64>, %arg1: !torch.vtensor<[5,3,4],f64>) -> !torch.vtensor<[5,2,4],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[5,2,3],f64>, !torch.vtensor<[5,3,4],f64>) -> !torch.list<vtensor>
+  // CHECK: %[[EQUATION:.*]] = torch.constant.str "bij, bjk -> bik"
+  // CHECK: %[[PATH:.*]] = torch.constant.none
+  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[5,2,4],f64>
+  %0 = torch.operator "onnx.Einsum"(%arg0, %arg1) {torch.onnx.equation = "bij, bjk -> bik"} : (!torch.vtensor<[5,2,3],f64>, !torch.vtensor<[5,3,4],f64>) -> !torch.vtensor<[5,2,4],f64>
+  return %0 : !torch.vtensor<[5,2,4],f64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_einsum_sum
+func.func @test_einsum_sum(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,4],f64>) -> !torch.list<vtensor>
+  // CHECK: %[[EQUATION:.*]] = torch.constant.str "ij->i"
+  // CHECK: %[[PATH:.*]] = torch.constant.none
+  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[3],f64>
+  %0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "ij->i"} : (!torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3],f64>
+  return %0 : !torch.vtensor<[3],f64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_einsum_transpose
+func.func @test_einsum_transpose(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[4,3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,4],f64>) -> !torch.list<vtensor>
+  // CHECK: %[[EQUATION:.*]] = torch.constant.str "ij->ji"
+  // CHECK: %[[PATH:.*]] = torch.constant.none
+  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[4,3],f64>
+  %0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "ij->ji"} : (!torch.vtensor<[3,4],f64>) -> !torch.vtensor<[4,3],f64>
+  return %0 : !torch.vtensor<[4,3],f64>
+}
diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py
index 93144daf9b15..bfacde280bdd 100644
--- a/test/python/fx_importer/sparse_test.py
+++ b/test/python/fx_importer/sparse_test.py
@@ -48,8 +48,8 @@ def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
             sparse_dim,
             dense_dim,
             blocksize,
-            a.indices().dtype,
-            a.indices().dtype,
+            a._indices().dtype,
+            a._indices().dtype,
         )
     elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
         if a.layout is torch.sparse_bsr:
@@ -373,3 +373,34 @@ def forward(self, x):
     print(res2)
     print("torch.mlir.batch")
     print(res3)
+
+
+@run
+# CHECK-LABEL: test_sparse_coo3
+# CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
+# CHECK: func.func @main(
+# CHECK-SAME:  %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#sparse>) -> !torch.vtensor<[10,20,30],f64> {
+# CHECK:         %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#sparse> -> !torch.vtensor<[10,20,30],f64>
+# CHECK:         return %[[R]] : !torch.vtensor<[10,20,30],f64>
+# CHECK:       }
+#
+# TODO: make sure sparsity propagates through relu into the output and test actual JIT output
+#
+def test_sparse_coo3():
+    class COO3Net(torch.nn.Module):
+        def __init__(self):
+            super(COO3Net, self).__init__()
+            self.relu = nn.ReLU()
+
+        def forward(self, x):
+            return self.relu(x)
+
+    net = COO3Net()
+
+    # Direct 3-dim COO construction.
+    idx = torch.tensor([[0, 1, 1, 4, 9, 9], [0, 1, 1, 5, 19, 19], [0, 1, 3, 6, 28, 29]])
+    val = torch.tensor([-1000.0, -1.0, 1.0, 2.0, 3.0, 1000.0], dtype=torch.float64)
+    sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20, 30])
+
+    m = export_and_import(net, sparse_input)
+    print(m)
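On the `indices()` to `_indices()` switch at the top of this file: the public accessor raises on an uncoalesced COO tensor, while the private `_indices()` also works on the uncoalesced tensors these tests construct directly. A small demonstration:

```python
import torch

idx = torch.tensor([[0, 1, 1], [2, 0, 0]])  # duplicate coordinate (1, 0)
val = torch.tensor([3.0, 4.0, 5.0])
a = torch.sparse_coo_tensor(idx, val, size=(2, 3))  # not coalesced

print(a._indices().dtype)  # torch.int64; works even when uncoalesced
try:
    a.indices()  # raises: requires a coalesced tensor
except RuntimeError as e:
    print("indices() failed:", e)
```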