diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp
index 72c02900f904e7..cd38fc5bd700e4 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp
@@ -249,12 +249,12 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_product_attention> {
         const auto value_shape = transpose_pshape(impl_param.get_input_layout(2).get_partial_shape(), desc->input_v_transpose_order);
 
         OPENVINO_ASSERT(key_shape == value_shape, "[GPU] The shapes of key and value inputs are expected to be equal");
-        for (size_t i = 0; i < query_shape.size(); ++i) {
-            if (query_shape[i].is_static() && key_shape[i].is_static() && value_shape[i].is_static()) {
-                if (query_shape[i].get_length() > key_shape[i].get_length()) {
-                    config.broadcast_axis = desc->input_k_transpose_order[i];
-                    config.group_size = query_shape[i].get_length() / key_shape[i].get_length();
-                }
-            }
+
+        const auto num_heads_dim = 1;
+        if (query_shape[num_heads_dim].is_static() && key_shape[num_heads_dim].is_static() && value_shape[num_heads_dim].is_static()) {
+            if (query_shape[num_heads_dim].get_length() > key_shape[num_heads_dim].get_length()) {
+                config.broadcast_axis = desc->input_k_transpose_order[num_heads_dim];
+                config.group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
+            }
         }
     }
diff --git a/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp
index 89b3d38f5051d3..6ad7efea69befc 100644
--- a/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp
+++ b/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp
@@ -310,7 +310,26 @@ const std::vector<std::vector<InputShape>> shapes{
         {ov::test::InputShape{ov::PartialShape{-1, 1, -1, -1},
             {ov::Shape{1, 1, 7, 7}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 10, 10}}}
         },
-    },
+    }
+};
+
+const std::vector<std::vector<int64_t>> disable_transpose{};
+const std::vector<std::vector<int64_t>> transpose_value{{0, 1, 2, 3}, {0, 1, 2, 3}, {0, 2, 1, 3}};
+const std::vector<std::vector<int64_t>> transpose_all{{0, 2, 1, 3}, {0, 2, 1, 3}, {0, 2, 1, 3}};
+
+const auto dynamic_shape_params = testing::Combine(testing::Values(ov::element::f16 /*, ov::element::f32 */),
+                                                   testing::ValuesIn(shapes),
+                                                   testing::Values(true, false),
+                                                   testing::Values(true, false),
+                                                   testing::Values(true, false),
+                                                   testing::ValuesIn({disable_transpose, transpose_value}));
+
+INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttn_GPU,
+                         ScaledAttnLayerGPUTest,
+                         dynamic_shape_params,
+                         ScaledAttnLayerGPUTest::getTestCaseName);
+
+const std::vector<std::vector<InputShape>> static_shapes{
     // static shapes
     {
         // q shape
@@ -326,21 +345,32 @@ const std::vector<std::vector<InputShape>> shapes{
             {ov::Shape{1, 1, 100, 100}}}
         },
     },
+    {
+        // q shape
+        {ov::test::InputShape{ov::PartialShape{1, 8, 64, 128},
+            {ov::Shape{1, 8, 64, 128}}}
+        },
+        // kv shape
+        {ov::test::InputShape{ov::PartialShape{1, 8, 13, 128},
+            {ov::Shape{1, 8, 13, 128}}}
+        },
+        // attn shape: [B, 1, -1, L0+L1]
+        {ov::test::InputShape{ov::PartialShape{1, 1, 64, 13},
+            {ov::Shape{1, 1, 64, 13}}}
+        },
+    },
 };
 
-const std::vector<std::vector<int64_t>> disable_transpose{};
-const std::vector<std::vector<int64_t>> enable_transpose{{0, 1, 2, 3}, {0, 1, 2, 3}, {0, 2, 1, 3}};
+const auto static_shape_params = testing::Combine(testing::Values(ov::element::f16),
+                                                  testing::ValuesIn(static_shapes),
+                                                  testing::Values(true, false),
+                                                  testing::Values(true, false),
+                                                  testing::Values(true, false),
+                                                  testing::ValuesIn({disable_transpose, transpose_all}));
 
-const auto params = testing::Combine(testing::Values(ov::element::f16 /*, ov::element::f32 */),
-                                     testing::ValuesIn(shapes),
-                                     testing::Values(true, false),
-                                     testing::Values(true, false),
-                                     testing::Values(true, false),
-                                     testing::ValuesIn({disable_transpose, enable_transpose}));
-
-INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttn_GPU,
+INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttnStatic_GPU,
                          ScaledAttnLayerGPUTest,
-                         params,
+                         static_shape_params,
                          ScaledAttnLayerGPUTest::getTestCaseName);
 
 } // namespace
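
Note on the change above (illustrative, not part of the patch): the old loop scanned every dimension and treated the first one where the query extent exceeded the key extent as the grouped-query broadcast axis. With a static KV-cache shape pair such as q = {1, 8, 64, 128} vs kv = {1, 8, 13, 128} (the new static test case), the sequence-length dimension fired that heuristic (64 > 13) even though both inputs have 8 heads, producing a bogus broadcast_axis and group_size. The patch pins the check to the num_heads dimension (index 1 in [batch, num_heads, seq_len, head_size] order). The standalone C++ sketch below reproduces both heuristics on that shape pair; it is not OpenVINO plugin code, and its names and main() harness are hypothetical.

// Standalone sketch: broadcast-axis detection before and after the patch.
// Shapes follow [batch, num_heads, seq_len, head_size] order.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Mirrors the new static test case: equal head counts (8), different
    // sequence lengths (64 query tokens vs 13 cached KV tokens).
    const std::vector<int64_t> query{1, 8, 64, 128};
    const std::vector<int64_t> key{1, 8, 13, 128};

    // Old heuristic: any dimension where query > key is taken as the
    // grouped-query broadcast axis. Here it misfires on seq_len (64 > 13).
    for (std::size_t i = 0; i < query.size(); ++i) {
        if (query[i] > key[i]) {
            std::printf("old: broadcast_axis=%zu group_size=%lld (wrong)\n",
                        i, static_cast<long long>(query[i] / key[i]));
        }
    }

    // Patched heuristic: only the num_heads dimension can be a grouped-query
    // broadcast axis; equal head counts mean no broadcast at all.
    const std::size_t num_heads_dim = 1;
    if (query[num_heads_dim] > key[num_heads_dim]) {
        std::printf("new: broadcast_axis=%zu group_size=%lld\n", num_heads_dim,
                    static_cast<long long>(query[num_heads_dim] / key[num_heads_dim]));
    } else {
        std::printf("new: no broadcast (head counts match)\n");
    }
    return 0;
}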