From bf1e4ca14621095922f544dfffc7af1722965d85 Mon Sep 17 00:00:00 2001 From: "Li, Tingqian" Date: Sun, 9 Feb 2025 22:50:38 -0800 Subject: [PATCH] fix qkv/mlp win32 support --- .../src/nodes/kernels/x64/mlp_kernel.cpp | 54 +++++++++++++++---- .../src/nodes/kernels/x64/mlp_kernel.hpp | 31 ++++++++--- src/plugins/intel_cpu/src/nodes/qkv_proj.cpp | 7 +-- 3 files changed, 68 insertions(+), 24 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp index be92579674b957..5be1d5cbcd7cf9 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp @@ -22,8 +22,15 @@ void MKernel::generate_2x2() { Xbyak::Reg64 reg_A_addr = abi_param2; Xbyak::Reg64 reg_A_stride = abi_param3; Xbyak::Reg64 reg_B_addr = abi_param4; +#ifdef _WIN32 + Xbyak::Reg64 reg_C_addr = rdi; + Xbyak::Reg64 reg_C_stride = rsi; + push(rdi); + push(rsi); +#else Xbyak::Reg64 reg_C_addr = abi_param5; Xbyak::Reg64 reg_C_stride = abi_param6; +#endif Xbyak::Reg64 reg_ktiles = rax; Xbyak::Reg64 reg_B_stride = r10; @@ -140,6 +147,10 @@ void MKernel::generate_2x2() { tilestored(ptr[reg_C_addr + reg_C_stride + 64], tmmC11); pop(reg_prefetch); +#ifdef _WIN32 + pop(rsi); + pop(rdi); +#endif ret(); } @@ -174,8 +185,15 @@ void MKernel::generate_1x2() { Xbyak::Reg64 reg_A_addr = abi_param2; Xbyak::Reg64 reg_A_stride = abi_param3; Xbyak::Reg64 reg_B_addr = abi_param4; +#ifdef _WIN32 + Xbyak::Reg64 reg_C_addr = rdi; + Xbyak::Reg64 reg_C_stride = rsi; + push(rdi); + push(rsi); +#else Xbyak::Reg64 reg_C_addr = abi_param5; Xbyak::Reg64 reg_C_stride = abi_param6; +#endif Xbyak::Reg64 reg_ktiles = rax; Xbyak::Reg64 reg_B_stride = r10; @@ -254,7 +272,10 @@ void MKernel::generate_1x2() { tilestored(ptr[reg_C_addr + reg_C_stride + 64], tmmC01); } L(skip_store); - +#ifdef _WIN32 + pop(rsi); + pop(rdi); +#endif ret(); } @@ -600,11 +621,18 @@ void GateUpCombine::generate() { 
void ReduceAdd2bh::generate() { if (m_do_reduce2) { - Xbyak::Reg64 src0 = abi_param1; - Xbyak::Reg64 src1 = abi_param2; - Xbyak::Reg64 dst = abi_param3; - Xbyak::Reg64 prefetch_dst = abi_param4; - Xbyak::Reg64 BN = abi_param5; + Xbyak::Reg64 src0 = rdx; + Xbyak::Reg64 src1 = r8; + Xbyak::Reg64 dst = r9; + Xbyak::Reg64 prefetch_dst = r10; + Xbyak::Reg64 BN = r11; + + mov(src0, ptr[abi_param1 + offsetof(CallArgs, src0)]); + mov(src1, ptr[abi_param1 + offsetof(CallArgs, src1)]); + mov(dst, ptr[abi_param1 + offsetof(CallArgs, dst)]); + mov(prefetch_dst, ptr[abi_param1 + offsetof(CallArgs, prefetch_dst)]); + mov(BN, ptr[abi_param1 + offsetof(CallArgs, num_cols)]); + Xbyak::Reg64 loop_i = rax; Xbyak::Label loop_begin; @@ -636,10 +664,16 @@ void ReduceAdd2bh::generate() { ret(); } else { - Xbyak::Reg64 src0 = abi_param1; - Xbyak::Reg64 dst = abi_param2; - Xbyak::Reg64 prefetch_dst = abi_param3; - Xbyak::Reg64 BN = abi_param4; + Xbyak::Reg64 src0 = rdx; + Xbyak::Reg64 dst = r9; + Xbyak::Reg64 prefetch_dst = r10; + Xbyak::Reg64 BN = r11; + + mov(src0, ptr[abi_param1 + offsetof(CallArgs, src0)]); + mov(dst, ptr[abi_param1 + offsetof(CallArgs, dst)]); + mov(prefetch_dst, ptr[abi_param1 + offsetof(CallArgs, prefetch_dst)]); + mov(BN, ptr[abi_param1 + offsetof(CallArgs, num_cols)]); + Xbyak::Reg64 loop_i = rax; Xbyak::Label loop_begin; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp index fb8f5b152ff35a..fcc73f304f4225 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp @@ -527,26 +527,41 @@ class ReduceAdd2bh : public dnnl::impl::cpu::x64::jit_generator { void generate() override; + struct CallArgs { + float* src0; + float* src1; + int16_t * dst; + int16_t * prefetch_dst; + int64_t num_cols; + }; // add two float input eltwise and convert to bf16 : ConvertFP32toBF16(src0 + src1) void call(float* src0, float* 
src1, size_t src_stride, void* pf16_dst, size_t dst_stride, int num_rows, int num_cols) { - auto* dst = reinterpret_cast<int16_t*>(pf16_dst); - for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) { + CallArgs args; + args.src0 = src0; + args.src1 = src1; + args.dst = reinterpret_cast<int16_t*>(pf16_dst); + args.num_cols = num_cols; + for (int m = 0; m < num_rows; m++, args.src0 += src_stride, args.src1 += src_stride, args.dst += dst_stride) { // the prefetch distance is increased to ensure by the time store happens // prefetch has done and no HW prefetcher is triggered - auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst); - (*this)(src0, src1, dst, prefetch_dst, num_cols); + args.prefetch_dst = (m + 2 < num_rows) ? (args.dst + 2 * dst_stride) : (args.dst); + + (*this)(&args); } } // convert tensor to bf16: ConvertFP32toBF16(src0) void call(float* src0, size_t src_stride, void* pf16_dst, size_t dst_stride, int num_rows, int num_cols) { - auto* dst = reinterpret_cast<int16_t*>(pf16_dst); - for (int m = 0; m < num_rows; m++, src0 += src_stride, dst += dst_stride) { + CallArgs args; + args.src0 = src0; + args.dst = reinterpret_cast<int16_t*>(pf16_dst); + args.num_cols = num_cols; + for (int m = 0; m < num_rows; m++, args.src0 += src_stride, args.dst += dst_stride) { // the prefetch distance is increased to ensure by the time store happens // prefetch has done and no HW prefetcher is triggered - auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst); - (*this)(src0, dst, prefetch_dst, num_cols); + args.prefetch_dst = (m + 2 < num_rows) ? 
(args.dst + 2 * dst_stride) : (args.dst); + (*this)(&args); } } }; diff --git a/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp b/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp index 55dbc20c1bd946..64f1bd68d90f3f 100644 --- a/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp +++ b/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp @@ -293,12 +293,7 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase { asym); } // compress accumulation result into target - for (int mi = 0; mi < BM; mi++, src += stride_src, dst += stride_dst) { - // the prefetch distance is increased to ensure by the time store happens - // prefetch has done and no HW prefetcher is triggered - auto* prefetch_dst = (mi + 2 < BM) ? (dst + 2 * stride_dst) : (dst); - jit_cvt(src, dst, prefetch_dst, work.BN); - } + jit_cvt.call(src, stride_src, dst, stride_dst, BM, work.BN); } }); m += BM;