From d3fe084e63c4356eb4751e209117396451fda37c Mon Sep 17 00:00:00 2001
From: AmosLewis
Date: Thu, 16 Jan 2025 13:29:44 -0800
Subject: [PATCH 1/3] [Linalg] Add convert between bf16 and f16

---
 lib/Conversion/Utils/Utils.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index 72217e5f4afd..3a5a5a7447c8 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -335,6 +335,10 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
 
   if (auto dtypeFloat = dyn_cast<mlir::FloatType>(dtype)) {
     if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType)) {
+      if (scalarFloat.getWidth() == 16 && dtypeFloat.getWidth() == 16) {
+        auto scalarF32 = b.create<arith::ExtFOp>(loc, b.getF32Type(), scalar);
+        return b.create<arith::TruncFOp>(loc, dtype, scalarF32);
+      }
       if (scalarFloat.getWidth() > dtypeFloat.getWidth())
         return b.create<arith::TruncFOp>(loc, dtype, scalar);
       // Only scalarFloat width < dtypeFloat width can reach here.

From 3a1153a5c73b2663ba749d11c6271adab357201e Mon Sep 17 00:00:00 2001
From: AmosLewis
Date: Thu, 16 Jan 2025 17:52:52 -0800
Subject: [PATCH 2/3] Add bf16 to f16 lit test

---
 .../Conversion/TorchToLinalg/elementwise.mlir | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir
index aa2be74f5d7e..d126f700c9a2 100644
--- a/test/Conversion/TorchToLinalg/elementwise.mlir
+++ b/test/Conversion/TorchToLinalg/elementwise.mlir
@@ -102,3 +102,37 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
   %0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
   return %0 : !torch.vtensor<[3],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @elementwise_todtype_bf162f16(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
+// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,?,32,128],bf16> -> tensor<1x?x32x128xbf16>
+// CHECK: %[[INT5:.*]] = torch.constant.int 5
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[CONSTANT1_1:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[INPUT]], %[[CONSTANT1]] : tensor<1x?x32x128xbf16>
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2 : index
+// CHECK: %[[CONSTANT_32:.*]] = arith.constant 32 : index
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant 3 : index
+// CHECK: %[[CONSTANT_128:.*]] = arith.constant 128 : index
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM]]) : tensor<1x?x32x128xf16>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[INPUT]] : tensor<1x?x32x128xbf16>) outs(%[[EMPTY]] : tensor<1x?x32x128xf16>) {
+// CHECK: ^bb0(%[[LHS:.*]]: bf16, %[[RHS:.*]]: f16):
+// CHECK: %[[EXTF:.*]] = arith.extf %[[LHS]] : bf16 to f32
+// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[EXTF]] : f32 to f16
+// CHECK: linalg.yield %[[TRUNCF]] : f16
+// CHECK: } -> tensor<1x?x32x128xf16>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[GENERIC]] : tensor<1x?x32x128xf16> to tensor<1x?x32x128xf16>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<1x?x32x128xf16> -> !torch.vtensor<[1,?,32,128],f16>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[1,?,32,128],f16>
+// CHECK: }
+func.func @elementwise_todtype_bf162f16(%arg0: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
+  %int5 = torch.constant.int 5
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,?,32,128],f16>
+  return %0 : !torch.vtensor<[1,?,32,128],f16>
+}

From 644448170605072d4fb6d072208559351b6b53b7 Mon Sep 17 00:00:00 2001
From: AmosLewis
Date: Fri, 17 Jan 2025 13:45:02 -0800
Subject: [PATCH 3/3] Structure the lit test same as elementwise_sinh

---
 .../Conversion/TorchToLinalg/elementwise.mlir | 28 ++++-------------------
 1 file changed, 5 insertions(+), 23 deletions(-)

diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir
index d126f700c9a2..c8fdeded44df 100644
--- a/test/Conversion/TorchToLinalg/elementwise.mlir
+++ b/test/Conversion/TorchToLinalg/elementwise.mlir
@@ -106,29 +106,11 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
 // -----
 
 // CHECK-LABEL: func.func @elementwise_todtype_bf162f16(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
-// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,?,32,128],bf16> -> tensor<1x?x32x128xbf16>
-// CHECK: %[[INT5:.*]] = torch.constant.int 5
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[CONSTANT1_1:.*]] = arith.constant 1 : index
-// CHECK: %[[CONSTANT1:.*]] = arith.constant 1 : index
-// CHECK: %[[DIM:.*]] = tensor.dim %[[INPUT]], %[[CONSTANT1]] : tensor<1x?x32x128xbf16>
-// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2 : index
-// CHECK: %[[CONSTANT_32:.*]] = arith.constant 32 : index
-// CHECK: %[[CONSTANT_3:.*]] = arith.constant 3 : index
-// CHECK: %[[CONSTANT_128:.*]] = arith.constant 128 : index
-// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM]]) : tensor<1x?x32x128xf16>
-// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[INPUT]] : tensor<1x?x32x128xbf16>) outs(%[[EMPTY]] : tensor<1x?x32x128xf16>) {
-// CHECK: ^bb0(%[[LHS:.*]]: bf16, %[[RHS:.*]]: f16):
-// CHECK: %[[EXTF:.*]] = arith.extf %[[LHS]] : bf16 to f32
-// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[EXTF]] : f32 to f16
-// CHECK: linalg.yield %[[TRUNCF]] : f16
-// CHECK: } -> tensor<1x?x32x128xf16>
-// CHECK: %[[CAST:.*]] = tensor.cast %[[GENERIC]] : tensor<1x?x32x128xf16> to tensor<1x?x32x128xf16>
-// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<1x?x32x128xf16> -> !torch.vtensor<[1,?,32,128],f16>
-// CHECK: return %[[RESULT]] : !torch.vtensor<[1,?,32,128],f16>
-// CHECK: }
+// CHECK: linalg.generic
+// CHECK: arith.extf
+// CHECK-SAME: bf16 to f32
+// CHECK: arith.truncf
+// CHECK-SAME: f32 to f16
 func.func @elementwise_todtype_bf162f16(%arg0: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
   %int5 = torch.constant.int 5
   %false = torch.constant.bool false