diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
index 02853b14072a..4f57263077c3 100644
--- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
@@ -497,8 +497,7 @@ class ConvertAtenEmptyMemoryFormatOp
         cast<RankedTensorType>(typeConverter->convertType(op.getType()));
     Type resultElementType;
     if (isa<Torch::NoneType>(op.getDtype().getType())) {
-      resultElementType = getDefaultDtypeForTorchScalar(
-          Torch::FloatType::get(op->getContext()));
+      resultElementType = resultType.getElementType();
     } else {
       int64_t dtypeInt;
       if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index 1b61f75703f6..cef844658258 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -354,3 +354,20 @@ func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
   %0 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[4,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
   return %0 : !torch.vtensor<[3,4],f32>
 }
+
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.empty.memory_format$noneDtype()
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<200x200x26xf64>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[EMPTY]] : tensor<200x200x26xf64> to tensor<200x200x26xf64>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<200x200x26xf64> -> !torch.vtensor<[200,200,26],f64>
+// CHECK: return %[[RES]] : !torch.vtensor<[200,200,26],f64>
+func.func @torch.aten.empty.memory_format$noneDtype() -> !torch.vtensor<[200,200,26],f64> attributes {torch.assume_strict_symbolic_shapes} {
+  %int200 = torch.constant.int 200
+  %int26 = torch.constant.int 26
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int200, %int200, %int26 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.empty.memory_format %0, %none, %none, %none, %false, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
+  return %1 : !torch.vtensor<[200,200,26],f64>
+}