diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index fb6936452e84..802dad692e31 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -13577,6 +13577,31 @@ def Torch_AtenSplitSizesOp : Torch_Op<"aten.split.sizes", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenTensorSplitSectionsOp : Torch_Op<"aten.tensor_split.sections", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$sections, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchListOfTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTensorSplitSectionsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenTensorSplitSectionsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; } def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 6d5c6d472e0d..314191a2c428 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3065,6 +3065,19 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenSplitSizesOp +//===----------------------------------------------------------------------===// + +void AtenSplitSizesOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenSplitSizesOp op, PatternRewriter &rewriter) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), 
op.getDim()); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenIsFloatingPointOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 05e2f42cad30..0edb9bc51f2d 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -862,18 +862,6 @@ class DecomposePrimTolistOp : public OpRewritePattern { }; } // namespace -namespace { -class DecomposeAtenSplitSizesOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenSplitSizesOp op, - PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim()); - return success(); - } -}; -} // namespace - namespace { class DecomposeAtenSplitWithSizesOp : public OpRewritePattern { @@ -8206,7 +8194,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp index d61323248021..d655d53412cf 100644 --- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp @@ -13,6 +13,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include using namespace mlir; using namespace mlir::torch; @@ -164,7 +165,7 @@ class RecomposeUnbindListUnpack : public OpRewritePattern { LogicalResult 
matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenUnbindOp + PrimListUnpackOp to select.int - auto unbindOp = dyn_cast(op.getOperand().getDefiningOp()); + auto unbindOp = op.getOperand().getDefiningOp(); if (!unbindOp) return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp"); if (isListPotentiallyMutated(unbindOp.getResult())) @@ -207,7 +208,7 @@ class RecomposeUnbindGetItem : public OpRewritePattern { LogicalResult matchAndRewrite(Aten__Getitem__TOp op, PatternRewriter &rewriter) const override { // recompose AtenUnbindIntOp + __getitem__t to select.int - auto unbind = dyn_cast(op.getList().getDefiningOp()); + auto unbind = op.getList().getDefiningOp(); if (!unbind) return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp"); if (isListPotentiallyMutated(unbind.getResult())) @@ -287,15 +288,14 @@ class RecomposeSplitTensorPrimListUnpackOp } }; -class RecomposeSplitTensorGetItemOp +class RecomposeSplitTensorGetItem : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten__Getitem__TOp op, PatternRewriter &rewriter) const override { // recompose AtenSplitTensorOp + __getitem__t to AtenSliceTensorOp - auto splitTensorOp = - dyn_cast(op.getList().getDefiningOp()); + auto splitTensorOp = op.getList().getDefiningOp(); if (!splitTensorOp) return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp"); if (isListPotentiallyMutated(splitTensorOp.getResult())) @@ -352,8 +352,7 @@ class RecomposeSplitTensorListUnpack LogicalResult matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenSplitTensorOp + PrimListUnpackOp to AtenSliceTensorOps - auto splitTensorOp = - dyn_cast(op.getOperand().getDefiningOp()); + auto splitTensorOp = op.getOperand().getDefiningOp(); if (!splitTensorOp) return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp"); if 
(isListPotentiallyMutated(splitTensorOp.getResult())) @@ -406,6 +405,78 @@ class RecomposeSplitTensorListUnpack } }; +class RecomposeSplitWithSizesGetItem + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten__Getitem__TOp op, + PatternRewriter &rewriter) const override { + // recompose AtenSplitWithSizes + __getitem__t to AtenSliceTensorOp + auto splitWithSizesOp = op.getList().getDefiningOp(); + if (!splitWithSizesOp) + return rewriter.notifyMatchFailure(op, + "Input is not AtenSplitWithSizesOp"); + if (isListPotentiallyMutated(splitWithSizesOp.getResult())) + return rewriter.notifyMatchFailure( + op, "AtenSplitWithSizesOp result is potentially mutated"); + if (isListPotentiallyMutated(splitWithSizesOp.getSplitSizes())) { + return rewriter.notifyMatchFailure( + op, "splitWithSizesOp's split_sizes is potentially mutated"); + } + + SmallVector splitSizes; + if (!matchPattern(splitWithSizesOp.getSplitSizes(), + m_TorchListOfConstantInts(splitSizes))) { + return rewriter.notifyMatchFailure( + op, "split_sizes must be list of constant int"); + } + + int64_t index; + if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index))) + return rewriter.notifyMatchFailure( + op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int"); + index = toPositiveDim(index, splitSizes.size()); + if (!isValidDim(index, splitSizes.size())) + return rewriter.notifyMatchFailure( + op, "Expected `idx` in range of split_sizes"); + + Location loc = op.getLoc(); + Value input = splitWithSizesOp.getSelf(); + Value dim = splitWithSizesOp.getDim(); + + // add runtime.assert to check dimension constraint + Value totalSize = rewriter.create(loc, input, dim); + int64_t sumSplitSize = + std::accumulate(splitSizes.begin(), splitSizes.end(), static_cast<int64_t>(0)); + Value cstSumSplitSize = rewriter.create( + loc, rewriter.getI64IntegerAttr(sumSplitSize)); + Value eqOrNot = + rewriter.create(loc, totalSize, cstSumSplitSize); + rewriter.create( + loc, 
eqOrNot, + rewriter.getStringAttr("split dim must be sum of split_sizes")); + + // replace with AtenSliceTensorOp + SmallVector boundaryOfSliceOp(splitSizes.size() + 1, 0); + for (size_t i = 1; i < boundaryOfSliceOp.size(); i++) { + boundaryOfSliceOp[i] = boundaryOfSliceOp[i - 1] + splitSizes[i - 1]; + } + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto start = rewriter.create( + loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[index])); + auto end = rewriter.create( + loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[index + 1])); + Value slice = rewriter.create( + loc, op.getType(), input, dim, start, end, /*step=*/cstOne); + rewriter.replaceOp(op, slice); + // erase splitOp if no user left + if (splitWithSizesOp.getResult().use_empty()) + rewriter.eraseOp(splitWithSizesOp); + return success(); + } +}; + class RecomposeSplitWithSizesListUnpack : public OpRewritePattern { public: @@ -413,8 +484,7 @@ class RecomposeSplitWithSizesListUnpack LogicalResult matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenSplitWithSizesOp + PrimListUnpackOp to AtenSliceTensorOps - auto splitOp = - dyn_cast(op.getOperand().getDefiningOp()); + auto splitOp = op.getOperand().getDefiningOp(); if (!splitOp) { return rewriter.notifyMatchFailure(op, "Input is not AtenSplitWithSizesOp"); @@ -434,20 +504,11 @@ class RecomposeSplitWithSizesListUnpack op, "split_sizes is not from PrimListConstructOp"); } - int64_t sumSplitSize = 0; SmallVector splitSizes; - for (auto operand : splitSizesConstruct.getOperands()) { - int64_t value = -1; - // TODO: support when split_sizes are not constant int - if (!matchPattern(operand, m_TorchConstantInt(&value))) { - return rewriter.notifyMatchFailure( - op, "one of split_sizes is not constant int"); - } - if (value < 0) { - return rewriter.notifyMatchFailure(op, "all of split_sizes must > 0"); - } - sumSplitSize += value; - splitSizes.push_back(value); + if 
(!matchPattern(splitOp.getSplitSizes(), + m_TorchListOfConstantInts(splitSizes))) { + return rewriter.notifyMatchFailure( + op, "split_sizes must be list of constant int"); } if (splitSizes.size() != op.getNumResults()) { return rewriter.notifyMatchFailure( @@ -460,6 +521,8 @@ class RecomposeSplitWithSizesListUnpack // add runtime.assert to check rank constraint Value totalSize = rewriter.create(loc, input, dim); + int64_t sumSplitSize = + std::accumulate(splitSizes.begin(), splitSizes.end(), static_cast<int64_t>(0)); Value cstSumSplitSize = rewriter.create( loc, rewriter.getI64IntegerAttr(sumSplitSize)); Value eqOrNot = @@ -494,13 +557,156 @@ class RecomposeSplitWithSizesListUnpack } }; +class RecomposeTensorSplitSectionsGetItem + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten__Getitem__TOp op, + PatternRewriter &rewriter) const override { + // recompose AtenTensorSplitSectionsOp + __getitem__t to AtenSliceTensorOp + auto splitOp = op.getList().getDefiningOp(); + if (!splitOp) + return rewriter.notifyMatchFailure( + op, "Input is not AtenTensorSplitSectionsOp"); + if (isListPotentiallyMutated(splitOp.getResult())) + return rewriter.notifyMatchFailure( + op, "AtenTensorSplitSectionsOp result is potentially mutated"); + + int64_t sections; + if (!matchPattern(splitOp.getSections(), m_TorchConstantInt(&sections))) + return rewriter.notifyMatchFailure( + op, "Expected `sections` of AtenTensorSplitSectionsOp to be a " + "constant int"); + + int64_t index; + if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index))) + return rewriter.notifyMatchFailure( + op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int"); + index = toPositiveDim(index, sections); + if (!isValidDim(index, sections)) + return rewriter.notifyMatchFailure( + op, "Expected `idx` in range of sections"); + + Location loc = op.getLoc(); + Value input = splitOp.getSelf(); + Value dim = splitOp.getDim(); + + // only recompose to slice when split 
dim size is static, otherwise we need + // control flow like prim.if + Value dimSizeValue = rewriter.createOrFold(loc, input, dim); + int64_t splitDimSize; + if (!matchPattern(dimSizeValue, m_TorchConstantInt(&splitDimSize))) + return rewriter.notifyMatchFailure(splitOp, + "split dim size must be static"); + + int64_t chunkSize = splitDimSize / sections; + int64_t remain = splitDimSize % sections; + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value result; + if (index < remain) { + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(index * (chunkSize + 1))); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr((index + 1) * (chunkSize + 1))); + result = rewriter.create(loc, op.getType(), input, dim, + start, end, + /*step=*/cstOne); + } else { + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(index * chunkSize + remain)); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr((index + 1) * chunkSize + remain)); + result = rewriter.create(loc, op.getType(), input, dim, + start, end, + /*step=*/cstOne); + } + rewriter.replaceOp(op, result); + // erase AtenTensorSplitSectionsOp if no user left + if (splitOp.getResult().use_empty()) + rewriter.eraseOp(splitOp); + return success(); + } +}; + +class RecomposeTensorSplitSectionsListUnpack + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimListUnpackOp op, + PatternRewriter &rewriter) const override { + // recompose AtenTensorSplitSectionsOp + PrimListUnpackOp to + // AtenSliceTensorOps + auto splitOp = op.getOperand().getDefiningOp(); + if (!splitOp) + return rewriter.notifyMatchFailure( + op, "Input is not AtenTensorSplitSectionsOp"); + if (isListPotentiallyMutated(splitOp.getResult())) + return rewriter.notifyMatchFailure( + op, "AtenTensorSplitSectionsOp result is potentially mutated"); + + int64_t sections; + if (!matchPattern(splitOp.getSections(), 
m_TorchConstantInt(&sections)) + return rewriter.notifyMatchFailure( + op, "Expected `sections` of AtenTensorSplitSectionsOp to be a " + "constant int"); + if (op->getNumResults() != sections) + return rewriter.notifyMatchFailure( + op, "`sections` must be same as ListUnpack's NumResults"); + + Location loc = op.getLoc(); + Value input = splitOp.getSelf(); + Value dim = splitOp.getDim(); + + // only recompose to slice when split dim size is static, otherwise we need + // control flow like prim.if + Value dimSizeValue = rewriter.createOrFold(loc, input, dim); + int64_t splitDimSize; + if (!matchPattern(dimSizeValue, m_TorchConstantInt(&splitDimSize))) + return rewriter.notifyMatchFailure(splitOp, + "split dim size must be static"); + + int64_t chunkSize = splitDimSize / sections; + int64_t remain = splitDimSize % sections; + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + SmallVector results; + for (int64_t i = 0; i < sections; i++) { + if (i < remain) { + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(i * (chunkSize + 1))); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr((i + 1) * (chunkSize + 1))); + Value slice = rewriter.create( + loc, op.getResult(i).getType(), input, dim, start, end, + /*step=*/cstOne); + results.push_back(slice); + } else { + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(i * chunkSize + remain)); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr((i + 1) * chunkSize + remain)); + Value slice = rewriter.create( + loc, op.getResult(i).getType(), input, dim, start, end, + /*step=*/cstOne); + results.push_back(slice); + } + } + rewriter.replaceOp(op, results); + // erase AtenTensorSplitSectionsOp if no user left + if (splitOp.getResult().use_empty()) + rewriter.eraseOp(splitOp); + return success(); + } +}; + class RecomposeChunkListUnpack : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult 
matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenChunkOp + PrimListUnpackOp to AtenSliceTensorOps - auto chunkOp = dyn_cast(op.getOperand().getDefiningOp()); + auto chunkOp = op.getOperand().getDefiningOp(); if (!chunkOp) return rewriter.notifyMatchFailure(op, "Input is not AtenChunkOp"); if (isListPotentiallyMutated(chunkOp.getResult())) @@ -514,10 +720,13 @@ class RecomposeChunkListUnpack : public OpRewritePattern { // chunkSize = floordiv(totalSize + chunks - 1, chunks) Value chunkSize = getIntCeilDiv(rewriter, loc, totalSize, chunks); - // add runtime.assert to check chunks == NumResults + // add runtime.assert to check floordiv(totalSize + chunkSize - 1, + // chunkSize) == NumResults Value cstNumResults = rewriter.create( loc, rewriter.getI64IntegerAttr(op.getNumResults())); - Value eqOrNot = rewriter.create(loc, chunks, cstNumResults); + Value realChunks = getIntCeilDiv(rewriter, loc, totalSize, chunkSize); + Value eqOrNot = + rewriter.create(loc, realChunks, cstNumResults); rewriter.create( loc, eqOrNot, rewriter.getStringAttr( @@ -608,9 +817,15 @@ class RecomposeComplexOpsPass // pattern.add calls go here patterns.add(context); patterns.add(context); - patterns.add(context); + + // TODO: could move these patterns to Decompose pass, but should handle + // shape and value semantics carefully + patterns.add(context); patterns.add(context); + patterns.add(context); patterns.add(context); + patterns.add(context); + patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 6fde0df51c89..184de282d85e 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -24,7 +24,6 @@ "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "IscloseStaticModule_basic", "IscloseStaticModuleTrue_basic", - "SplitWithSizes_Module_basic", # lowering 
to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec # these interpolate tests are added specifically to test onnx.Resize. "InterpolateDynamicModule_sizes_bilinear", @@ -847,6 +846,9 @@ } STABLEHLO_PASS_SET = { + "SplitWithSizes_Module_basic", + "TensorSplitSections_GetItemModule_basic", + "TensorSplitSections_ListUnpackModule_basic", "AtenLinear1D_basic", "AtenLinear2D_basic", "AtenLinear3DBias_basic", @@ -1509,6 +1511,8 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "TensorSplitSections_GetItemModule_basic", + "TensorSplitSections_ListUnpackModule_basic", "AtenLinear2D_basic", "AtenLinear3DBias_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", @@ -2785,6 +2789,10 @@ "_ConvolutionDeprecated2DDeterministicModule_basic", "_SoftmaxModule_basic", # Failure - onnx_import + # Failure - onnx_lowering: onnx.SplitToSequence + "ChunkListUnpackUneven_Module_basic", + "TensorSplitSections_GetItemModule_basic", + "TensorSplitSections_ListUnpackModule_basic", # Failure - onnx_lowering: onnx.AveragePool "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", # these diagonal modules are currently failing due to dynamic shape. 
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 3fc03e4d2f9d..d23ac9ac45d4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -978,7 +978,10 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True) emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])") emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])") - emit("aten::split.sizes : (Tensor, int[], int) -> (Tensor[])") + emit( + "aten::split.sizes : (Tensor, int[], int) -> (Tensor[])", has_canonicalizer=True + ) + emit("aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])") emit("aten::unbind.int : (Tensor, int) -> (Tensor[])") emit("aten::chunk : (Tensor, int, int) -> (Tensor[])") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py index 50043c35df01..c892499ea05e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -1159,8 +1159,8 @@ def __init__(self): ] ) def forward(self, x): - chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1) - return torch.ops.aten.add(chunk_0, chunk_1), chunk_2 + a0, a1, a2, a3, a4 = torch.chunk(x, 6, 1) + return a0, a1, a2, a3, a4 @register_test_case(module_factory=lambda: ChunkListUnpackUneven_Module()) @@ -1240,3 +1240,48 @@ def forward(self, x): @register_test_case(module_factory=lambda: SplitWithSizes_Module()) def SplitWithSizes_Module_basic(module, tu: TestUtils): module.forward(tu.rand(5, 2, 2)) + + +# ============================================================================== + + +class TensorSplitSections_GetItemModule(torch.nn.Module): + def 
__init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 5], torch.float32, True), + ] + ) + def forward(self, x): + split = torch.tensor_split(x, 3, dim=1) + return split[0], split[1], split[2] + + +@register_test_case(module_factory=lambda: TensorSplitSections_GetItemModule()) +def TensorSplitSections_GetItemModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 5)) + + +class TensorSplitSections_ListUnpackModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 5], torch.float32, True), + ] + ) + def forward(self, x): + a, b, c, d = torch.tensor_split(x, 4, dim=1) + return a, b, c, d + + +@register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule()) +def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 5))