diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 5862549d3bfb..268ba574dd5c 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10816,8 +10816,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.to.prim_Device\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<Device>, %arg2: !torch.optional<int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    return %0#1 : !torch.int\n"
+"    %1 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      torch.prim.If.yield %0#1 : !torch.int\n"
+"    } else {\n"
+"      %3 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
+"      torch.prim.If.yield %3 : !torch.int\n"
+"    }\n"
+"    return %2 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.transpose.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 4f62518e62eb..2de81537aa15 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5475,6 +5475,33 @@ class DecomposeAtenToDtypeLayoutOp
 };
 } // namespace
 
+namespace {
+// Decompose `aten.to.prim_Device` op into `aten.to.dtype` op.
+class DecomposeAtenToPrimDeviceOp
+    : public OpRewritePattern<AtenToPrimDeviceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenToPrimDeviceOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // Device information isn't relevant to torch-mlir, so we can drop that info
+    // here.
+    auto loc = op.getLoc();
+    Value constNone = rewriter.create<Torch::ConstantNoneOp>(loc);
+
+    Value dtype = op.getDtype();
+    if (dtype.getType().template isa<Torch::NoneType>()) {
+      dtype = rewriter.create<Torch::PrimDtypeOp>(loc, op.getSelf());
+    }
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
+                                               dtype, op.getNonBlocking(),
+                                               op.getCopy(), constNone);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.to.device` op into `aten.to.dtype` op.
 class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
@@ -7700,6 +7727,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index b434976bf4fb..250577b132ee 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -475,6 +475,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenToPrimDeviceOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 1e9ce5fddc99..0b220d30a603 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2876,7 +2876,9 @@ def aten〇t〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
 @check_dtype_function(_check_tensors_with_the_same_dtype(1, tensor_device="meta", device=torch.device("meta")))
 def aten〇to〇prim_Device〡dtype(self_rank_dtype: Tuple[int, int], device: Optional[device], dtype: Optional[int] = None, non_blocking: bool = False, copy: bool = False) -> int:
     self_rank, self_dtype = self_rank_dtype
-    return self_dtype
+    if dtype is None:
+        return self_dtype
+    return dtype
 
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim0=0, dim1=1))
 def aten〇transpose〇int〡dtype(self_rank_dtype: Tuple[int, int], dim0: int, dim1: int) -> int:
diff --git a/pytorch-hash.txt b/pytorch-hash.txt
index 09efb313d6cf..285c7cb73ba2 100644
--- a/pytorch-hash.txt
+++ b/pytorch-hash.txt
@@ -1 +1 @@
-e1c8416590b718ab97e06089e57b650ae3909711
+a3fc530d821d9ebb6fc2fc9be4720932913e68c1
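
For context, a minimal sketch (not part of this patch, and not taken from the torch-mlir test suite) of PyTorch code that should produce the `aten.to.prim_Device` overload this decomposition targets; the module name and the use of `torch.jit.script` are illustrative assumptions.

import torch

# `Tensor.to(device)` with no explicit dtype is scripted as the
# `aten::to.prim_Device` overload; with this patch it decomposes to
# `aten.to.dtype`, falling back to `prim.dtype(self)` for the dtype operand,
# so the result keeps the input's dtype.
class MoveToDevice(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(torch.device("cpu"))

# Printing the TorchScript graph should show the `aten::to` call taking a
# Device argument with a None dtype.
print(torch.jit.script(MoveToDevice()).graph)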