diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 72217e5f4afd..3a5a5a7447c8 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -335,6 +335,10 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, if (auto dtypeFloat = dyn_cast(dtype)) { if (auto scalarFloat = dyn_cast(scalarType)) { + if (scalarFloat.getWidth() == 16 && dtypeFloat.getWidth() == 16) { + auto scalarF32 = b.create(loc, b.getF32Type(), scalar); + return b.create(loc, dtype, scalarF32); + } if (scalarFloat.getWidth() > dtypeFloat.getWidth()) return b.create(loc, dtype, scalar); // Only scalarFloat width < dtypeFloat width can reach here. diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir index aa2be74f5d7e..c8fdeded44df 100644 --- a/test/Conversion/TorchToLinalg/elementwise.mlir +++ b/test/Conversion/TorchToLinalg/elementwise.mlir @@ -102,3 +102,19 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3 %0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } + +// ----- + +// CHECK-LABEL: func.func @elementwise_todtype_bf162f16( +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK-SAME: bf16 to f32 +// CHECK: arith.truncf +// CHECK-SAME: f32 to f16 +func.func @elementwise_todtype_bf162f16(%arg0: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> { + %int5 = torch.constant.int 5 + %false = torch.constant.bool false + %none = torch.constant.none + %0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,?,32,128],f16> + return %0 : !torch.vtensor<[1,?,32,128],f16> +}