diff --git a/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp b/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp index 8f26a96ea82..6aeed66c926 100644 --- a/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp +++ b/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp @@ -573,31 +573,16 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch, if (dt.isInt4()) return "[FO]"; return nullptr; }); - - if (fpmath_bf16 - && (utils::one_of(Type::f32, problem_.Ta, problem_.Tb) - || (problem_.Ta.isF8() || problem_.Tb.isF8())) - && (problem_.Ta.isInteger() || problem_.Tb.isInteger())) { - if (problem_.Ta.isInt8() || problem_.Ta.isInt4()) { - match_params.emplace_back(match_params[0]); - match_params.back().selector.precisions[1] = "B"; - } else { - match_params.emplace_back(match_params[0]); - match_params.back().selector.precisions[0] = "B"; - } - } - - if (fpmath_f16 - && (utils::one_of(Type::f32, problem_.Ta, problem_.Tb) - || (problem_.Ta.isF8() || problem_.Tb.isF8())) + if (((problem_.Ta.isF8() || problem_.Tb.isF8())) && (problem_.Ta.isInteger() || problem_.Tb.isInteger())) { - if (problem_.Ta.isInt8() || problem_.Ta.isInt4()) { - match_params.emplace_back(match_params[0]); - match_params.back().selector.precisions[1] = "H"; - } else { - match_params.emplace_back(match_params[0]); - match_params.back().selector.precisions[0] = "H"; - } + add_mode_matches(fpmath_bf16, [](Type dt) -> const char * { + if (dt.isF8()) return "B"; + return nullptr; + }); + add_mode_matches(fpmath_f16, [](Type dt) -> const char * { + if (dt.isF8()) return "H"; + return nullptr; + }); } if (fpmath_strict) { diff --git a/src/gpu/intel/jit/gemm/selector/db/kernel.db b/src/gpu/intel/jit/gemm/selector/db/kernel.db index ea3a1572adb..462acec40fe 100644 --- a/src/gpu/intel/jit/gemm/selector/db/kernel.db +++ b/src/gpu/intel/jit/gemm/selector/db/kernel.db @@ -881,6 +881,7 @@ auto _CATALOG_ = kcatalog::toFlatCatalog({ {{'F', "gemm", {"O", "S", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {16, -1, -1}, {1, 1, 1}, ""}, "am8x2+m32@8 aB8x2+m8@8 aB wg 1x4x16 kr kc8 nse li pt sr sb256 bk0 kv afb l2d", {16, (LoopType) 255, 128, {(LoopType) 224, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 32}, {16, 16, 8}, {1, 4, 16}, 1, (WGType) 1, 413, 0, 4096, {4, 4, 4}, {true, true, true}}, {'E', 17, {1.13268e+06, -103657, 246.583, 142575, 3.21126e+06, 0, 1.60336, 0.933443, 0.504957, 0.918636, 0.0712742, 0.0670609, 0.015172, 0.992002, 1.14631, 0.0718492, 1.8422e-11}}}, {{'F', "gemm", {"O", "S", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, 4, -1}, {1, 1, 1}, ""}, "at16x2+m16@32 at16+m32@32 aB wg 16x1x2 kr kc16 nse nmk li pt sr sb256 bk0 sm grf256 kv afb l4 l2d", {16, (LoopType) 255, 256, {(LoopType) 225, (LoopType) 255, (LoopType) 2}, {262144, 65536, 16777216}, {262144, 65536, 32}, {16, 4, 16}, {16, 1, 2}, 1, (WGType) 1, 413, 0, 4096, {4, 4, 4}, {true, true, true}}, {'E', 17, {1.1734e+06, -264907, -109870, 485064, 2.21266e+06, 0, 0.856653, 15.807, 1.98085, 3.89882, 0.125049, 0.0139237, 0.143865, 1, 1.34573, 0.978713, 4.7619e-12}}}, {{'F', "gemm", {"O", "S", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, 16, -1}, {1, 1, 1}, ""}, "at8x2+m16@24 at8x2+m32@8 aB wg 16x1x4 kr kc8 nse nmk li pt sr sb256 bk0 sm sn kv afb l2d", {16, (LoopType) 255, 128, {(LoopType) 225, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 32}, {16, 16, 8}, {16, 1, 4}, 1, (WGType) 1, 413, 0, 16384, {4, 4, 4}, {true, true, true}}, {'E', 17, {1.18993e+06, -230103, -26635.9, 388995, 2.2528e+06, 0, 0.900793, 5.78162, 0.552809, 1.28255, 0.0627307, 0.0602325, 0.0232779, 1, 1.21284, 0.921396, 2.8065e-12}}}, +{{'F', "gemm", {"O", "S", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "xy"}, "sB4 sB4 aB wg 4x8 kc4 cab4 ks8 nse bo sr bk0 sm sn l4", {8, (LoopType) 0, 128, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {524288, 262144, 16777216}, {524288, 262144, 16777216}, {32, 16, 8}, {4, 8, 1}, 1, (WGType) 1, 257, 32768, 0, {1, 2, 4}, {false, false, true}}, {'W', 1, {512}}}, {{'F', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {8, 8, 1}, "ABIp"}, "av32+m32@48 am32+S32@64 aB wg 4x8 xaf st vav hi pt sb32 bk0 sn grf256 sys sr br kv afb rr", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 262144, 16777216}, {524288, 262144, 32}, {32, 16, 32}, {4, 8, 1}, 1, (WGType) 1, 441, 0, 0, {8, 8, 4}, {true, true, true}}, {'E', 17, {869760, 741196, 0, 0, 8.192e+06, 1.05431e+07, 0.731287, 0.777012, 0.881104, 1.51408, 0.00403024, 0.00403024, 0, 0.998184, 1.76752, 1.28618, 2.08947e-12}}}, {{'F', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ip"}, "aB32 aB32 aB wg 8x4 cab3 ks32 af vav hi pt sr br bk0 sn nb 8x4 dm grf256 sys kv afb l4", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {262144, 524288, 16777216}, {262144, 524288, 32}, {16, 32, 32}, {8, 4, 1}, 1, (WGType) 1, 441, 49152, 0, {1, 1, 4}, {true, true, true}}, {'E', 17, {1.07153e+06, 922406, 0, 0, 5.48536e+06, 9.18323e+06, 0.85293, 1.19559, 1.04485, 1.64281, 0.00471518, 0.00471518, 0, 0.961495, 1.72318, 1.24713, 3.69593e-12}}}, {{'F', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {8, 8, 1}, "ABI"}, "av32 am16+m16@16 aB wg 2x4x4 kr ca3x2 ks32 af vav hi pt sr br bk0 sn grf256 kv afb sys", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 2}, {524288, 262144, 16777216}, {524288, 262144, 32}, {32, 16, 32}, {2, 4, 4}, 1, (WGType) 1, 445, 6144, 16384, {8, 8, 4}, {true, true, true}}, {'E', 17, {1.27954e+06, -187821, -42333.9, 291644, 3.34234e+06, 2.63782e+06, 0.670967, 0.826166, 0.942564, 1.64083, 0.0148244, 0.00555253, 0.00975056, 0.806514, 1.26716, 0.788997, 1.48059e-11}}},