diff --git a/Jenkinsfile b/Jenkinsfile index 132257ad80..48b4c805cd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -735,11 +735,11 @@ def process_results(Map conf=[:]){ //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true - 0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true + 0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true;RUN_CODEGEN_TESTS=true 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false - 0 13 * * * % BUILD_LEGACY_OS=true ''' : "" + 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false + 0 13 * * * % BUILD_LEGACY_OS=true''' : "" pipeline { agent none @@ -806,6 +806,10 @@ pipeline { name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS", defaultValue: false, description: "Run the grouped conv large cases tests (default: OFF)") + booleanParam( + name: "RUN_CODEGEN_TESTS", + defaultValue: false, + description: "Run codegen tests (default: OFF)") booleanParam( name: "RUN_CK_TILE_FMHA_TESTS", defaultValue: false, @@ -926,7 +930,30 @@ pipeline { execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ make -j64 test_grouped_convnd_fwd_large_cases_xdl && \ ./bin/test_grouped_convnd_fwd_large_cases_xdl""" - } + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } + stage("Run Codegen Tests") + { + parallel + { + stage("Run Codegen Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CODEGEN_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a")} + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake ../codegen && \ + make -j64 check""" + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -951,7 +978,7 @@ pipeline { make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ cd ../ && example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ - } + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -970,7 +997,7 @@ pipeline { make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ cd ../ && example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ - } + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -995,7 +1022,7 @@ pipeline { make -j64 tile_example_gemm_basic && \ cd ../ && example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ - } + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -1014,7 +1041,7 @@ pipeline { make -j64 tile_example_gemm_basic && \ cd ../ && example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ - } + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -1040,7 +1067,7 @@ pipeline { -DCMAKE_CXX_FLAGS=" -O3 " \ -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ execute_args = " " - } + } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name) cleanWs() @@ -1059,7 +1086,7 @@ pipeline { -DCMAKE_CXX_FLAGS=" -O3 " \ -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ execute_args = " " - } + } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name) cleanWs() @@ -1140,7 +1167,7 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \ -D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """ - } + } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 2492804f28..1ca0d12821 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -1,3 +1,6 @@ +cmake_minimum_required(VERSION 3.16) +project(composable_kernel_host) + set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) @@ -5,56 +8,51 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) -add_compile_options(-std=c++17) -find_package(hip) -add_custom_target(codegen) +find_package(ROCM) +include(ROCMInstallTargets) +include(ROCMTest) -# add include directories -include_directories(BEFORE - ${PROJECT_BINARY_DIR}/include - ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/library/include - ${HIP_INCLUDE_DIRS} - ) +rocm_setup_version(VERSION 1.0) list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) include(Embed) file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS - ${CK_ROOT}/include/ck/*.hpp) -#printouts fot debug purposes -#message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") -#message(STATUS "RELATIVE: ${CK_ROOT}/include") + ${CK_ROOT}/include/ck/*.hpp) +# printouts fot debug purposes +# message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") +# message(STATUS "RELATIVE: ${CK_ROOT}/include") add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) -file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) +add_compile_options(-std=c++17) -##message(STATUS "SOURCE_FILES: ${SOURCES}") +file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) target_link_libraries(ck_host PRIVATE ck_headers) -set_target_properties(ck_host PROPERTIES - LINKER_LANGUAGE CXX - POSITION_INDEPENDENT_CODE ON) +set_target_properties(ck_host PROPERTIES + LINKER_LANGUAGE CXX + POSITION_INDEPENDENT_CODE ON) -target_include_directories(ck_host PUBLIC - $ - $ -) +# target_include_directories(ck_host PUBLIC +# $ +# ) add_executable(ck-template-driver driver/main.cpp) target_link_libraries(ck-template-driver ck_host) -rocm_install( +rocm_install_targets( TARGETS ck_host ck_headers - EXPORT ck_hostTargets + EXPORT ck_host_targets + INCLUDE include + PRIVATE +) +rocm_export_targets( + EXPORT ck_host_targets + NAMESPACE composable_kernel:: ) -rocm_install(EXPORT ck_hostTargets - FILE composable_kernelck_hostTargets.cmake - NAMESPACE composable_kernel:: - DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel) -rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) if(BUILD_TESTING) - add_subdirectory(test) + add_subdirectory(test) endif() + diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt index 1de612e49a..48fde531da 100644 --- a/codegen/test/CMakeLists.txt +++ b/codegen/test/CMakeLists.txt @@ -1,23 +1,25 @@ list(APPEND CMAKE_PREFIX_PATH /opt/rocm) add_subdirectory(rtc) file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) -# do not build the tests when we build the library for various targets -if(NOT GPU_ARCHS) - foreach(TEST_SRC ${TEST_SRCS}) - set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP) - get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) - add_executable(codegen_test_${BASE_NAME} ${TEST_SRC}) - if(CK_USE_ALTERNATIVE_PYTHON) - target_link_options(codegen_test_${BASE_NAME} PRIVATE -lstdc++fs) - endif() - add_dependencies(codegen codegen_test_${BASE_NAME}) - add_dependencies(tests codegen_test_${BASE_NAME}) - add_dependencies(check codegen_test_${BASE_NAME}) - add_test(NAME codegen_test_${BASE_NAME} COMMAND codegen_test_${BASE_NAME}) - message("adding test codegen_test_${BASE_NAME}") - target_link_libraries(codegen_test_${BASE_NAME} ck_rtc ck_host) - target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/codegen/test/include) + +# TODO: These tests need to be refactored to remove dependency on main ck +# headers and device compilation. +set(TESTS_REQUIRE_DEVICE_COMPILE + grouped_conv_fwd_multiple_d_v1 + grouped_conv_fwd_multiple_d_v2 + grouped_conv_fwd_multiple_d_v3 + grouped_conv_fwd_multiple_d_v4 +) +find_package(hip) + +foreach(TEST_SRC ${TEST_SRCS}) + get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) + rocm_add_test_executable(codegen_test_${BASE_NAME} ${TEST_SRC}) + target_link_libraries(codegen_test_${BASE_NAME} ck_rtc ck_host) + target_include_directories(codegen_test_${BASE_NAME} PUBLIC include) + if(BASE_NAME IN_LIST TESTS_REQUIRE_DEVICE_COMPILE) + target_link_libraries(codegen_test_${BASE_NAME} hip::device) target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/include) target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/library/include) - endforeach() -endif() + endif() +endforeach() diff --git a/codegen/test/common.hpp b/codegen/test/include/common.hpp similarity index 100% rename from codegen/test/common.hpp rename to codegen/test/include/common.hpp diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt index 39497f1a21..68bfc2467b 100644 --- a/codegen/test/rtc/CMakeLists.txt +++ b/codegen/test/rtc/CMakeLists.txt @@ -1,4 +1,6 @@ +find_package(hip) file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) add_library(ck_rtc ${RTC_SOURCES}) target_include_directories(ck_rtc PUBLIC include) target_link_libraries(ck_rtc PUBLIC hip::host) +target_link_libraries(ck_rtc PUBLIC -lstdc++fs) diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp index 71db7be249..c4413b47be 100644 --- a/codegen/test/rtc/include/rtc/compile_kernel.hpp +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -2,14 +2,14 @@ #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL #include -#include +#include #include namespace rtc { struct src_file { - CK::fs::path path; + fs::path path; std::string_view content; }; diff --git a/codegen/test/rtc/include/rtc/filesystem.hpp b/codegen/test/rtc/include/rtc/filesystem.hpp new file mode 100644 index 0000000000..3b94b84b9f --- /dev/null +++ b/codegen/test/rtc/include/rtc/filesystem.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef GUARD_TEST_HOST_RTC_FILESYSTEM_HPP +#define GUARD_TEST_HOST_RTC_FILESYSTEM_HPP + +#include +#include + +// clang-format off +#if defined(CPPCHECK) + #define RTC_HAS_FILESYSTEM 1 + #define RTC_HAS_FILESYSTEM_TS 1 +#elif defined(_WIN32) + #if _MSC_VER >= 1920 + #define RTC_HAS_FILESYSTEM 1 + #define RTC_HAS_FILESYSTEM_TS 0 + #elif _MSC_VER >= 1900 + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 1 + #else + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 0 + #endif +#elif defined(__has_include) + #if __has_include() && __cplusplus >= 201703L + #define RTC_HAS_FILESYSTEM 1 + #else + #define RTC_HAS_FILESYSTEM 0 + #endif + #if __has_include() && __cplusplus >= 201103L + #define RTC_HAS_FILESYSTEM_TS 1 + #else + #define RTC_HAS_FILESYSTEM_TS 0 + #endif +#else + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 0 +#endif +// clang-format on + +#if RTC_HAS_FILESYSTEM +#include +#elif RTC_HAS_FILESYSTEM_TS +#include +#else +#error "No filesystem include available" +#endif + +namespace rtc { + +#if RTC_HAS_FILESYSTEM +namespace fs = ::std::filesystem; +#elif RTC_HAS_FILESYSTEM_TS +namespace fs = ::std::experimental::filesystem; +#endif + +} // namespace rtc + +#endif // GUARD_RTC_FILESYSTEM_HPP_ diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp index 0b4bf002c1..a0a2cb9b77 100644 --- a/codegen/test/rtc/include/rtc/tmp_dir.hpp +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -2,13 +2,13 @@ #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR #include -#include +#include namespace rtc { struct tmp_dir { - CK::fs::path path; + fs::path path; tmp_dir(const std::string& prefix = ""); void execute(const std::string& cmd) const; diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index cc1bb80c31..8cb71b9043 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -1,4 +1,4 @@ -#include "rtc/hip.hpp" +#include #include #include #include @@ -70,9 +70,9 @@ kernel compile_kernel(const std::vector& srcs, compile_options options for(const auto& src : srcs) { - CK::fs::path full_path = td.path / src.path; - CK::fs::path parent_path = full_path.parent_path(); - CK::fs::create_directories(parent_path); + fs::path full_path = td.path / src.path; + fs::path parent_path = full_path.parent_path(); + fs::create_directories(parent_path); write_string(full_path.string(), src.content); if(src.path.extension().string() == ".cpp") { @@ -86,7 +86,7 @@ kernel compile_kernel(const std::vector& srcs, compile_options options td.execute(compiler() + options.flags); auto out_path = td.path / out; - if(not CK::fs::exists(out_path)) + if(not fs::exists(out_path)) throw std::runtime_error("Output file missing: " + out); auto obj = read_buffer(out_path.string()); diff --git a/codegen/test/rtc/src/tmp_dir.cpp b/codegen/test/rtc/src/tmp_dir.cpp index 659bbbe13f..4e89bc3539 100644 --- a/codegen/test/rtc/src/tmp_dir.cpp +++ b/codegen/test/rtc/src/tmp_dir.cpp @@ -31,10 +31,10 @@ std::string unique_string(const std::string& prefix) } tmp_dir::tmp_dir(const std::string& prefix) - : path(CK::fs::temp_directory_path() / + : path(fs::temp_directory_path() / unique_string(prefix.empty() ? "ck-rtc" : "ck-rtc-" + prefix)) { - CK::fs::create_directories(this->path); + fs::create_directories(this->path); } void tmp_dir::execute(const std::string& cmd) const @@ -43,6 +43,6 @@ void tmp_dir::execute(const std::string& cmd) const std::system(s.c_str()); } -tmp_dir::~tmp_dir() { CK::fs::remove_all(this->path); } +tmp_dir::~tmp_dir() { fs::remove_all(this->path); } } // namespace rtc diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index e3c8d72590..569afed256 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,4 +1,3 @@ - // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. @@ -282,7 +281,11 @@ int main(int argc, char* argv[]) using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using CodegenGemmPolicy = ck_tile:: + UniversalGemmPipelineAgBgCrPolicy; + + using CodegenGemmPipeline = + ck_tile::GemmPipelineAGmemBGmemCRegV1; invoke_gemm= 41133 +#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) || \ + (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 3 && HIP_VERSION_PATCH >= 42131) || \ + (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR > 3) #define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1 #else #define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0 diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 131729992b..8a13c0b060 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); auto k_lds_write_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); auto k_lds_read_window = make_tile_window(k_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), k_lds_write_window.get_window_origin(), - Policy::template MakeKRegSliceBlockDescriptor()); + Policy::template MakeKRegBlockDescriptor()); auto k_reg_tensor = make_static_distributed_tensor( Policy::template MakeKRegBlockDescriptor()); @@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); auto v_lds_write_window = - make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); auto v_lds_read_window = make_tile_window(v_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), v_lds_write_window.get_window_origin(), - Policy::template MakeVRegSliceBlockDescriptor()); - - auto v_reg_tensor = make_static_distributed_tensor( - Policy::template MakeVRegBlockDescriptor()); + Policy::template MakeVRegBlockDescriptor()); //------------------------------------------------------------------ // KT, Reg ->LDS ->Reg @@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor()); auto shuffled_k_lds_write_window = make_tile_window( - shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); auto kt_lds_read = make_tensor_view( kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); @@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR block_sync_lds(); - v_reg_tensor = load_tile(v_lds_read_window); + auto v_reg_tensor = load_tile(v_lds_read_window); block_sync_lds(); //---------------------------- Loop Load in ----------------------------// // Q: HBM ->Reg ->LDS @@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); auto q_lds_read_window = make_tile_window(q_lds_window.get_bottom_tensor_view(), @@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor()); auto shuffled_q_lds_write_window = make_tile_window( - shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); auto qt_lds_read = make_tensor_view( qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor()); @@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); auto do_lds_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); auto do_lds_read_window = make_tile_window(do_lds_window.get_bottom_tensor_view(), @@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor()); auto shuffled_do_lds_write_window = make_tile_window( - shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); auto dot_read_lds = make_tensor_view( dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor()); @@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR index_t i_total_loops = 0; index_t seqlen_q_step = seqlen_q_start; - static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); static_assert(kM0 == kK1, "kM0 should equal to kK1"); - static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); static_assert(kM0 == kK3, "kM0 should equal to kK3"); constexpr index_t k4_loops = kN0 / kK4; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 3156e4a356..d1b6e6f85b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); auto k_lds_write_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); auto k_lds_read_window = make_tile_window(k_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), k_lds_write_window.get_window_origin(), - Policy::template MakeKRegSliceBlockDescriptor()); + Policy::template MakeKRegBlockDescriptor()); auto k_reg_tensor = make_static_distributed_tensor( Policy::template MakeKRegBlockDescriptor()); @@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); auto v_lds_write_window = - make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); auto v_lds_read_window = make_tile_window(v_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), v_lds_write_window.get_window_origin(), - Policy::template MakeVRegSliceBlockDescriptor()); - - auto v_reg_tensor = make_static_distributed_tensor( - Policy::template MakeVRegBlockDescriptor()); + Policy::template MakeVRegBlockDescriptor()); //------------------------------------------------------------------ // KT, Reg ->LDS ->Reg @@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor()); auto shuffled_k_lds_write_window = make_tile_window( - shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); auto kt_lds_read = make_tensor_view( kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); @@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP block_sync_lds(); - v_reg_tensor = load_tile(v_lds_read_window); + auto v_reg_tensor = load_tile(v_lds_read_window); //---------------------------- Loop Load in ----------------------------// // Q: HBM ->Reg ->LDS auto q_dram_window = @@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); auto q_lds_read_window = make_tile_window(q_lds_window.get_bottom_tensor_view(), @@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor()); auto shuffled_q_lds_write_window = make_tile_window( - shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); auto qt_lds_read = make_tensor_view( qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor()); @@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); auto do_lds_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); auto do_lds_read_window = make_tile_window(do_lds_window.get_bottom_tensor_view(), @@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor()); auto shuffled_do_lds_write_window = make_tile_window( - shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); auto dot_read_lds = make_tensor_view( dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor()); @@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP index_t i_total_loops = 0; index_t seqlen_q_step = seqlen_q_start; - static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); static_assert(kM0 == kK1, "kM0 should equal to kK1"); - static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); static_assert(kM0 == kK3, "kM0 should equal to kK3"); constexpr index_t k4_loops = kN0 / kK4; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 0afad0446c..d353203e0e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -196,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using QDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType); constexpr index_t kMinVecLoad = 4 / sizeof(QDataType); @@ -215,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using KDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType); constexpr index_t kMinVecLoad = 4 / sizeof(KDataType); @@ -234,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using VDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType); constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; @@ -254,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using OGradDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType); constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType); @@ -315,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; @@ -327,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; return total_pixels / GetAlignmentK(); @@ -338,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; @@ -376,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentK(); constexpr index_t K0 = kKPerBlock / K1; @@ -399,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t K1 = GetAlignmentV(); constexpr index_t K0 = kKPerBlock / K1; @@ -422,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentQ(); constexpr index_t K0 = kKPerBlock / K1; @@ -445,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t K1 = GetAlignmentOGrad(); constexpr index_t K0 = kKPerBlock / K1; @@ -816,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor() { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); return MakeXLdsBlockDescriptor(); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - constexpr auto k_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); - - return k_block_dstr; - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor() { @@ -865,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; @@ -890,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor() { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kVPack = GetSmemKPackV(); return MakeXLdsBlockDescriptor(); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - constexpr auto v_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); - - return v_block_dstr; - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor() { @@ -940,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; @@ -966,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentK(); constexpr index_t K0 = kKPerBlock / K1; @@ -1048,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() { constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackQ(); @@ -1092,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentQ(); constexpr index_t K0 = kKPerBlock / K1; @@ -1255,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { // Hold full block data constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPack = GetSmemKPackOGrad(); @@ -1299,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t K1 = GetAlignmentOGrad(); constexpr index_t K0 = kKPerBlock / K1; @@ -1859,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0; static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim; static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim; + static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0; + static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2; static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4; static constexpr index_t WarpGemmM = @@ -1873,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy // Compute static constexpr index_t Gemm0MFMA = - kM0 * kN0 * kQKHeaddim / - (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm1MFMA = - kM0 * kN0 * kVHeaddim / - (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); - static constexpr index_t Gemm2MFMA = kN0 * kVHeaddim * kM0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm2MFMA = + kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm3MFMA = kN0 * kQKHeaddim * kM0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); @@ -1903,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ(); static constexpr index_t SGradT_LDS_READ_P1 = kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); - static constexpr index_t Q_LDS_READ = - kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ(); + static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ(); static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t SGradT_LDS_READ_P2 = kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); static constexpr index_t OGrad_LDS_READ = - kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + kM0 * kK2 / kBlockSize / GetAlignmentOGrad(); static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); // LDS Write diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index dc5983e4d1..436d964c37 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -23,6 +23,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..7044a53140 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,424 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +// UniversalGemm Policy +template +struct UniversalGemmPipelineAgBgCrPolicy +{ + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr bool TransposeC = true; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + using ADataType = remove_cvref_t; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + + if constexpr(std::is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple(K0 * number{}, number{}, K1), + make_tuple(K1, number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(K1)), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_kMLdsLayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(K0, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( + a_lds_block_desc_ak0_kMLdsLayer_m_ak1, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return a_lds_block_desc_m_k; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I0); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = Problem::kBlockSize / M0; + constexpr auto K0PerThreadWrite = K0 / KThreadWrite; + constexpr auto KThreadRead = 64 / WarpGemm::kM; + constexpr auto K0PerThreadRead = K0 / KThreadRead; + + constexpr auto kfold = + (K1 * M0 * sizeof(ADataType) > 128) ? 1 : 128 / (K1 * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=kN0 + constexpr auto mpair = (K1 * WarpGemm::kM * sizeof(ADataType) > 128) + ? 1 + : ((128 / (K1 * WarpGemm::kM * sizeof(ADataType))) > M0 + ? M0 + : 128 / (K1 * WarpGemm::kM * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + K1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return a_lds_block_desc_m_k; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + using BDataType = remove_cvref_t; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + + if constexpr(std::is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple(K0 * number{}, number{}, K1), + make_tuple(K1, number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(K1)), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_kNLdsLayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(K0, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( + b_lds_block_desc_bk0_kNLdsLayer_n_bk1, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return b_lds_block_desc_n_k; + } + else // RowMajor B + { + constexpr auto N0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = Problem::kBlockSize / N0; + constexpr auto K0PerThreadWrite = K0 / KThreadWrite; + constexpr auto KThreadRead = 64 / WarpGemm::kN; + constexpr auto K0PerThreadRead = K0 / KThreadRead; + + constexpr auto kfold = + (K1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (K1 * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=kN0 + constexpr auto npair = (K1 * WarpGemm::kN * sizeof(BDataType) > 128) + ? 1 + : ((128 / (K1 * WarpGemm::kN * sizeof(BDataType))) > N0 + ? N0 + : 128 / (K1 * WarpGemm::kN * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + K1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return b_lds_block_desc_n_k; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() + { + constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * + MakeBLdsBlockDescriptor().get_element_space_size(); + return smem_size_b; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + constexpr index_t smem_size_b = GetSmemSizeB(); + index_t smem_size = 0; + smem_size += smem_size_a + smem_size_b; + + return smem_size; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = BlockSize / get_warp_size(); + static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = NPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + return BlockGemmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile