Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[GPU] Fix Conflict by Replacing local_memory with shared_local_memory in sdpa_opt_finalization_stage #29083

Merged
merged 2 commits into from
Feb 21, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 50 additions & 153 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
Original file line number Diff line number Diff line change
Expand Up @@ -1581,6 +1581,8 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
// max_logits [batch, heads_num, q_len, partition_idx]
// tmp_out [batch, heads_num, q_len, partition_idx, head_size]

// Element count of the __local max_logits_u_exp_sum scratch array declared in
// sdpa_opt_finalization_stage; that array is indexed by partition index for
// every i < num_of_partitions.
// NOTE(review): no check that num_of_partitions <= MAX_PARTITIONS_NUM is visible
// in this kernel -- confirm the host side guarantees this bound.
#define MAX_PARTITIONS_NUM 128

REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
KERNEL(sdpa_opt_finalization_stage)(
OPTIONAL_SHAPE_INFO_ARG
Expand All @@ -1603,163 +1605,58 @@ KERNEL(sdpa_opt_finalization_stage)(
__global SOFTMAX_ACCUMULATOR_TYPE* cur_exp_sums = exp_sums + offset;
__global SOFTMAX_ACCUMULATOR_TYPE* cur_max_logits = max_logits + offset;
__local SOFTMAX_ACCUMULATOR_TYPE tmp_slm[SUBGROUP_SIZE];
__local SOFTMAX_ACCUMULATOR_TYPE max_logits_u_exp_sum[MAX_PARTITIONS_NUM];

if (num_of_partitions <= SUBGROUP_SIZE * REG_VERSION_MAX_VALUES_PER_WI_LOWER) {
/* Registers kernel version, can handle up to SEQ_LEN_PARTITION_SIZE(256) * SUBGROUP_SIZE(16) * REG_VERSION_MAX_VALUES_PER_WI_LOWER(8/16) = 32768/65536 tokens */

SOFTMAX_ACCUMULATOR_TYPE max_logits_u_exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_MIN};
SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;
const uint reduce_offset = HEAD_SIZE / SUBGROUP_SIZE > SUBGROUP_SIZE ? SUBGROUP_SIZE * SUBGROUP_SIZE : HEAD_SIZE;
for (uint i = local_id; i < num_of_partitions; i+= reduce_offset) {
max_logits_u_exp_sum[i] = cur_max_logits[i];
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits_u_exp_sum[i]);
}
local_max_logit = sub_group_reduce_max(local_max_logit);
if (sglid == 0) {
tmp_slm[sgid] = local_max_logit;
}
barrier(CLK_LOCAL_MEM_FENCE);

if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_max_logit = tmp_slm[sglid];
}
SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit);

// Update exp_sum with respect to the global maximum
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = local_id; i < num_of_partitions; i+= reduce_offset) {
SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = cur_exp_sums[i] * native_exp(max_logits_u_exp_sum[i] - global_max);
max_logits_u_exp_sum[i] = exp_sum_new;
local_exp_sum += exp_sum_new;
}
local_exp_sum = sub_group_reduce_add(local_exp_sum);
if (sglid == 0) {
tmp_slm[sgid] = local_exp_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
local_exp_sum = 0;
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_exp_sum = tmp_slm[sglid];
}

SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = sub_group_reduce_add(local_exp_sum);
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
target_seq_idx * (num_of_partitions * HEAD_SIZE) +
partition_idx * (HEAD_SIZE) + local_id;
OUTPUT_TYPE out_val = tmp_out[tmp_out_offset];
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * TO_SOFTMAX_ACCUMULATOR_TYPE(max_logits_u_exp_sum[partition_idx]);
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
local_id;

output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(global_exp_sum);
} else if (num_of_partitions <= SUBGROUP_SIZE * REG_VERSION_MAX_VALUES_PER_WI) {
/* Registers kernel version, can handle up to SEQ_LEN_PARTITION_SIZE(256) * SUBGROUP_SIZE(16) * REG_VERSION_MAX_VALUES_PER_WI(24/48) = 98304/196608 tokens */
SOFTMAX_ACCUMULATOR_TYPE max_logits_u_exp_sum[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_MIN};
SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;
const uint reduce_offset = HEAD_SIZE / SUBGROUP_SIZE > SUBGROUP_SIZE ? SUBGROUP_SIZE * SUBGROUP_SIZE : HEAD_SIZE;
for (uint i = local_id; i < num_of_partitions; i+= reduce_offset) {
max_logits_u_exp_sum[i] = cur_max_logits[i];
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits_u_exp_sum[i]);
}
local_max_logit = sub_group_reduce_max(local_max_logit);
if (sglid == 0) {
tmp_slm[sgid] = local_max_logit;
}
barrier(CLK_LOCAL_MEM_FENCE);

if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_max_logit = tmp_slm[sglid];
}
SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit);

// Update exp_sum with respect to the global maximum
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = local_id; i < num_of_partitions; i+= reduce_offset) {
SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = cur_exp_sums[i] * native_exp(max_logits_u_exp_sum[i] - global_max);
max_logits_u_exp_sum[i] = exp_sum_new;
local_exp_sum += exp_sum_new;
}
local_exp_sum = sub_group_reduce_add(local_exp_sum);
if (sglid == 0) {
tmp_slm[sgid] = local_exp_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
local_exp_sum = 0;
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_exp_sum = tmp_slm[sglid];
}

SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = sub_group_reduce_add(local_exp_sum);
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
target_seq_idx * (num_of_partitions * HEAD_SIZE) +
partition_idx * (HEAD_SIZE) + local_id;
OUTPUT_TYPE out_val = tmp_out[tmp_out_offset];
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * TO_SOFTMAX_ACCUMULATOR_TYPE(max_logits_u_exp_sum[partition_idx]);
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
local_id;

output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(global_exp_sum);
} else {
/* Global memory kernel version, can handle any number of tokens, but could be very slow. */
SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;
const uint reduce_offset = HEAD_SIZE / SUBGROUP_SIZE > SUBGROUP_SIZE ? SUBGROUP_SIZE * SUBGROUP_SIZE : HEAD_SIZE;
for (uint i = local_id; i < num_of_partitions; i+= reduce_offset) {
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, cur_max_logits[i]);
}
local_max_logit = sub_group_reduce_max(local_max_logit);
if (sglid == 0) {
tmp_slm[sgid] = local_max_logit;
}
barrier(CLK_LOCAL_MEM_FENCE);

if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_max_logit = tmp_slm[sglid];
}
SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit);
// Update exp_sum with respect to the global maximum
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = local_id; i < num_of_partitions; i+= reduce_offset) {
local_exp_sum += cur_exp_sums[i] * native_exp(cur_max_logits[i] - global_max);
}
local_exp_sum = sub_group_reduce_add(local_exp_sum);
if (sglid == 0) {
tmp_slm[sgid] = local_exp_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
local_exp_sum = 0;
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_exp_sum = tmp_slm[sglid];
}
SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;
const uint reduce_offset = HEAD_SIZE / SUBGROUP_SIZE > SUBGROUP_SIZE ? SUBGROUP_SIZE * SUBGROUP_SIZE : HEAD_SIZE;
for (uint i = local_id; i < num_of_partitions; i+= reduce_offset) {
max_logits_u_exp_sum[i] = cur_max_logits[i];
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits_u_exp_sum[i]);
}
local_max_logit = sub_group_reduce_max(local_max_logit);
if (sglid == 0) {
tmp_slm[sgid] = local_max_logit;
}
barrier(CLK_LOCAL_MEM_FENCE);

SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = sub_group_reduce_add(local_exp_sum);
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
target_seq_idx * (num_of_partitions * HEAD_SIZE) +
partition_idx * (HEAD_SIZE) + local_id;
OUTPUT_TYPE out_val = tmp_out[tmp_out_offset];
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * TO_SOFTMAX_ACCUMULATOR_TYPE(cur_exp_sums[partition_idx] * native_exp(cur_max_logits[partition_idx] - global_exp_sum));
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
local_id;
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_max_logit = tmp_slm[sglid];
}
SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit);

// Update exp_sum with respect to the global maximum
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = local_id; i < num_of_partitions; i+= reduce_offset) {
SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = cur_exp_sums[i] * native_exp(max_logits_u_exp_sum[i] - global_max);
max_logits_u_exp_sum[i] = exp_sum_new;
local_exp_sum += exp_sum_new;
}
local_exp_sum = sub_group_reduce_add(local_exp_sum);
if (sglid == 0) {
tmp_slm[sgid] = local_exp_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
local_exp_sum = 0;
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_exp_sum = tmp_slm[sglid];
}

output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(global_exp_sum);
SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = sub_group_reduce_add(local_exp_sum);
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
target_seq_idx * (num_of_partitions * HEAD_SIZE) +
partition_idx * (HEAD_SIZE) + local_id;
OUTPUT_TYPE out_val = tmp_out[tmp_out_offset];
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * TO_SOFTMAX_ACCUMULATOR_TYPE(max_logits_u_exp_sum[partition_idx]);
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
local_id;

output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(global_exp_sum);
}

#endif
Loading