Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] fix qkv_proj/mlp jit kernel's win32 support #28915

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,15 @@ void MKernel::generate_2x2() {
Xbyak::Reg64 reg_A_addr = abi_param2;
Xbyak::Reg64 reg_A_stride = abi_param3;
Xbyak::Reg64 reg_B_addr = abi_param4;
#ifdef _WIN32
Xbyak::Reg64 reg_C_addr = rdi;
Xbyak::Reg64 reg_C_stride = rsi;
push(rdi);
push(rsi);
#else
Xbyak::Reg64 reg_C_addr = abi_param5;
Xbyak::Reg64 reg_C_stride = abi_param6;
#endif

Xbyak::Reg64 reg_ktiles = rax;
Xbyak::Reg64 reg_B_stride = r10;
Expand Down Expand Up @@ -140,6 +147,10 @@ void MKernel::generate_2x2() {
tilestored(ptr[reg_C_addr + reg_C_stride + 64], tmmC11);

pop(reg_prefetch);
#ifdef _WIN32
pop(rsi);
pop(rdi);
#endif
ret();
}

Expand Down Expand Up @@ -174,8 +185,15 @@ void MKernel::generate_1x2() {
Xbyak::Reg64 reg_A_addr = abi_param2;
Xbyak::Reg64 reg_A_stride = abi_param3;
Xbyak::Reg64 reg_B_addr = abi_param4;
#ifdef _WIN32
Xbyak::Reg64 reg_C_addr = rdi;
Xbyak::Reg64 reg_C_stride = rsi;
push(rdi);
push(rsi);
#else
Xbyak::Reg64 reg_C_addr = abi_param5;
Xbyak::Reg64 reg_C_stride = abi_param6;
#endif

Xbyak::Reg64 reg_ktiles = rax;
Xbyak::Reg64 reg_B_stride = r10;
Expand Down Expand Up @@ -254,7 +272,10 @@ void MKernel::generate_1x2() {
tilestored(ptr[reg_C_addr + reg_C_stride + 64], tmmC01);
}
L(skip_store);

#ifdef _WIN32
pop(rsi);
pop(rdi);
#endif
ret();
}

Expand Down Expand Up @@ -600,11 +621,18 @@ void GateUpCombine::generate() {

void ReduceAdd2bh::generate() {
if (m_do_reduce2) {
Xbyak::Reg64 src0 = abi_param1;
Xbyak::Reg64 src1 = abi_param2;
Xbyak::Reg64 dst = abi_param3;
Xbyak::Reg64 prefetch_dst = abi_param4;
Xbyak::Reg64 BN = abi_param5;
Xbyak::Reg64 src0 = rdx;
Xbyak::Reg64 src1 = r8;
Xbyak::Reg64 dst = r9;
Xbyak::Reg64 prefetch_dst = r10;
Xbyak::Reg64 BN = r11;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note: according to the calling-convention docs, rax, rcx, rdx, and r8-r11 are volatile (caller-saved) registers, so they can be used in the callee without being saved and restored.


mov(src0, ptr[abi_param1 + offsetof(CallArgs, src0)]);
mov(src1, ptr[abi_param1 + offsetof(CallArgs, src1)]);
mov(dst, ptr[abi_param1 + offsetof(CallArgs, dst)]);
mov(prefetch_dst, ptr[abi_param1 + offsetof(CallArgs, prefetch_dst)]);
mov(BN, ptr[abi_param1 + offsetof(CallArgs, num_cols)]);

Xbyak::Reg64 loop_i = rax;

Xbyak::Label loop_begin;
Expand Down Expand Up @@ -636,10 +664,16 @@ void ReduceAdd2bh::generate() {

ret();
} else {
Xbyak::Reg64 src0 = abi_param1;
Xbyak::Reg64 dst = abi_param2;
Xbyak::Reg64 prefetch_dst = abi_param3;
Xbyak::Reg64 BN = abi_param4;
Xbyak::Reg64 src0 = rdx;
Xbyak::Reg64 dst = r9;
Xbyak::Reg64 prefetch_dst = r10;
Xbyak::Reg64 BN = r11;

mov(src0, ptr[abi_param1 + offsetof(CallArgs, src0)]);
mov(dst, ptr[abi_param1 + offsetof(CallArgs, dst)]);
mov(prefetch_dst, ptr[abi_param1 + offsetof(CallArgs, prefetch_dst)]);
mov(BN, ptr[abi_param1 + offsetof(CallArgs, num_cols)]);

Xbyak::Reg64 loop_i = rax;

Xbyak::Label loop_begin;
Expand Down
31 changes: 23 additions & 8 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,26 +527,41 @@ class ReduceAdd2bh : public dnnl::impl::cpu::x64::jit_generator {

void generate() override;

// Argument pack handed to the generated kernel as a single pointer
// (`(*this)(&args)`). generate() reads every field through
// `ptr[abi_param1 + offsetof(CallArgs, field)]`, so only one ABI register
// parameter is needed regardless of the platform calling convention.
//
// The struct must stay standard-layout because the kernel uses offsetof on it.
// All members are value-initialized so that a field a particular call()
// overload does not set (e.g. src1 in the single-input conversion) never
// carries an indeterminate value.
struct CallArgs {
    // First fp32 input row.
    float* src0 = nullptr;
    // Second fp32 input row; only loaded by the kernel when m_do_reduce2 is set.
    float* src1 = nullptr;
    // Output row, addressed as 16-bit elements.
    int16_t* dst = nullptr;
    // Address the kernel receives for prefetching (the callers pass a row
    // two dst strides ahead, or dst itself near the end).
    int16_t* prefetch_dst = nullptr;
    // Number of columns to process; read by the kernel with a 64-bit load.
    int64_t num_cols = 0;
};
// add two float inputs element-wise and convert the sum to bf16: ConvertFP32toBF16(src0 + src1)
void
call(float* src0, float* src1, size_t src_stride, void* pf16_dst, size_t dst_stride, int num_rows, int num_cols) {
auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
CallArgs args;
args.src0 = src0;
args.src1 = src1;
args.dst = reinterpret_cast<int16_t*>(pf16_dst);
args.num_cols = num_cols;
for (int m = 0; m < num_rows; m++, args.src0 += src_stride, args.src1 += src_stride, args.dst += dst_stride) {
// the prefetch distance is increased so that, by the time the store happens,
// the prefetch has completed and no HW prefetcher is triggered
auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
(*this)(src0, src1, dst, prefetch_dst, num_cols);
args.prefetch_dst = (m + 2 < num_rows) ? (args.dst + 2 * dst_stride) : (args.dst);

(*this)(&args);
}
}

// convert tensor to bf16: ConvertFP32toBF16(src0)
void call(float* src0, size_t src_stride, void* pf16_dst, size_t dst_stride, int num_rows, int num_cols) {
auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
for (int m = 0; m < num_rows; m++, src0 += src_stride, dst += dst_stride) {
CallArgs args;
args.src0 = src0;
args.dst = reinterpret_cast<int16_t*>(pf16_dst);
args.num_cols = num_cols;
for (int m = 0; m < num_rows; m++, args.src0 += src_stride, args.dst += dst_stride) {
// the prefetch distance is increased so that, by the time the store happens,
// the prefetch has completed and no HW prefetcher is triggered
auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
(*this)(src0, dst, prefetch_dst, num_cols);
args.prefetch_dst = (m + 2 < num_rows) ? (args.dst + 2 * dst_stride) : (args.dst);
(*this)(&args);
}
}
};
Expand Down
7 changes: 1 addition & 6 deletions src/plugins/intel_cpu/src/nodes/qkv_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,7 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase {
asym);
}
// compress accumulation result into target
for (int mi = 0; mi < BM; mi++, src += stride_src, dst += stride_dst) {
// the prefetch distance is increased to ensure by the time store happens
// prefetch has done and no HW prefetcher is triggered
auto* prefetch_dst = (mi + 2 < BM) ? (dst + 2 * stride_dst) : (dst);
jit_cvt(src, dst, prefetch_dst, work.BN);
}
jit_cvt.call(src, stride_src, dst, stride_dst, BM, work.BN);
}
});
m += BM;
Expand Down
Loading