diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp new file mode 100644 index 0000000000..301df0a529 --- /dev/null +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// defines all values need for an instance of fwd conv +struct Operation_Xdl_CShuffle +{ + // returns a vector of instances, only given fusion operators: will use default problem spec + static std::vector> + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, given a problem spec and fusion operators + static std::vector + CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue); + TensorDesc A{}; + TensorDesc B{}; + TensorDesc B1{}; + TensorDesc C{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string b1_elem_op = PassThrough; + std::string c_elem_op = PassThrough; + std::string acc_elem_op = Scale; + std::string prologue = ""; + std::string epilogue = ""; + std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; + // tuning parameters + operation::TileDescGemmGemm tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b0_block_transfer{}; + operation::BlockTransferDesc b1_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + bool 
mask_out_upper_triangle = false; + + // functions to update fusion operators if provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + /**constexpr**/ bool + IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_); + // returns a templated instance + Solution ToSolution() const; +}; + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp new file mode 100644 index 0000000000..428034a3ba --- /dev/null +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// defines the problem specification for a GEMM operation +struct Problem +{ + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + std::size_t O = 0; + bool TransA = false; + bool TransB = false; + bool TransB1 = false; + bool TransC = false; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType B1DataType = DataType::Half; + DataType CDataType = DataType::Half; + std::string AElementOp = PassThrough; + std::string BElementOp = PassThrough; + std::string B1ElementOp = PassThrough; + std::string CElementOp = PassThrough; + std::string AccElementOp = Scale; + + // returns the correct device op file for the operation + std::string GetIncludeHeader() const; + + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; +}; + +} // 
namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp index 359da7d8cf..e5eeb6be15 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp @@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle operation::BlockTransferDesc b_block_transfer{}; operation::CShuffleDesc cshuffle{}; operation::CBlockTransferDesc c_block_transfer{}; + LoopScheduler loop_scheduler{}; + PipelineVersion pipeline_version{}; // functions to update fusion operators if provided void update_prologue(const std::string& prologue); diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp index 84ef92f0a0..5a51a0002e 100644 --- a/codegen/include/ck/host/operation/gemm.hpp +++ b/codegen/include/ck/host/operation/gemm.hpp @@ -23,6 +23,26 @@ struct TileDesc int n_Xdl_per_wave = 0; int num_gemmk_prefetch_stage = 0; }; + +struct TileDescGemmGemm +{ + int block_size = 0; + int gemm01_m_per_block = 0; + int gemm0_n_per_block = 0; + int gemm0_k_per_block = 0; + int gemm1_n_per_block = 0; + int gemm1_k_per_block = 0; + int ak1 = 0; + int bk1 = 0; + int b1k1 = 0; + int m_per_XDL = 0; + int n_per_XDL = 0; + int gemm0_m_Xdl_per_wave = 0; + int gemm0_n_Xdl_per_wave = 0; + int gemm1_n_Xdl_per_wave = 0; + int num_gemmk_prefetch_stage = 0; +}; + struct BlockTransferDesc { std::string thread_cluster_length = ""; diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp index 8bad7bf89c..b05e134176 100644 --- a/codegen/include/ck/host/types.hpp +++ b/codegen/include/ck/host/types.hpp @@ -66,6 +66,20 @@ enum class GemmType }; std::string ToString(GemmType gt); +enum class LoopScheduler +{ + Default, + Interwave, +}; +std::string ToString(LoopScheduler ls); + +enum class PipelineVersion +{ + v1, + v2 +}; 
+std::string ToString(PipelineVersion pv); + struct TensorDesc { DataType element; @@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...}); constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough"; constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear"; +constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale"; } // namespace host } // namespace ck diff --git a/codegen/src/device_batched_gemm_softmax_gemm.cpp b/codegen/src/device_batched_gemm_softmax_gemm.cpp new file mode 100644 index 0000000000..cf140ead1d --- /dev/null +++ b/codegen/src/device_batched_gemm_softmax_gemm.cpp @@ -0,0 +1,38 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// return the relevant device op file based on the operation +std::string Problem::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"; +} + +// returns templated instances when provided with a problem specification +std::vector Problem::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations( + *this, prologue, epilogue); // obtains vector of instances + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); // template instance with correct values + }); + return result; +} + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git 
a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp new file mode 100644 index 0000000000..b12c2e1a4a --- /dev/null +++ b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp @@ -0,0 +1,408 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include <cassert> + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// calculate appropriate Gemm Specification based on input tensor dimensions +std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t n1, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block, + const std::size_t n1_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0) + spec += "O"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +// function to update prologue/epilogue with user provided operation +void Operation_Xdl_CShuffle::update_prologue(const std::string& pro) +{ + if(!pro.empty()) + { + this->prologue = pro; + } + else + { + this->prologue = ""; + } +} + +void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) +{ + if(!epi.empty()) + { + this->epilogue = epi; + } + else + { + this->epilogue = ""; + } +} + +// accounts for all possible combinations of Row/Col major +static Layout 
ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } + +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Xdl_CShuffle::CreateOperations( + const Problem& prob, const std::string& prologue, const std::string& epilogue) +{ + std::vector result; + + std::vector tile_descriptions = { + // clang-format off +// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK| +// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch| +// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage| +// | | | | | | | | | | | Wave| Wave| Wave| | + { 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1}, + { 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1}, + { 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1}, + { 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1}, + { 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, + { 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, + { 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, + { 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, +// Padded fallback kernel + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1}, +// Irregular k + { 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1}, + { 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1}, + { 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1}, + { 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1}, + { 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1}, + // 
clang-format on + }; + + const std::vector a_block_descriptions = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, +// Padded fallback kernel + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, +// Irregular k + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + // clang-format on + }; + + const std::vector b1_block_descriptions = { + // clang-format off +// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<16, 16, 1>, S<0, 2, 1>, 
S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, +// Padded fallback kernel + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, +// Irregular k + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 8}, + { 1, 4}, + { 1, 8}, + { 1, 4}, +// Padded fallback kernel + { 1, 2}, + { 1, 2}, +// Irregular k + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// | + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 
8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1,16>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1,16>, 8}, + { S<1, 32, 1, 8>, 8}, +// Padded fallback kernel + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, +// Irregular k + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + // clang-format on + }; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b1_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b0_block_transfer = a_block_descriptions[i]; // b0 same as a + x.b1_block_transfer = b1_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)}; + x.C = TensorDesc{prob.CDataType, ToLayout(prob.TransC)}; + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.b1_elem_op = prob.B1ElementOp; + x.c_elem_op = prob.CElementOp; + x.acc_elem_op = prob.AccElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + prob.O, + x.tile_desc.gemm01_m_per_block, + x.tile_desc.gemm0_n_per_block, + x.tile_desc.gemm0_k_per_block, + x.tile_desc.gemm1_n_per_block); + x.update_prologue(prologue); + x.update_epilogue(epilogue); + x.mask_out_upper_triangle = true; + result.push_back(x); + + x.mask_out_upper_triangle = false; + result.push_back(x); + } + return result; +} + +// set up 
instances when not provided with a problem specification, use default operation values and +// all possible layout combinations +std::vector> +Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue) +{ + Problem prob; + prob.TransA = false; + prob.TransB = true; + prob.TransB1 = false; + prob.TransC = false; + + return {CreateOperations(prob, prologue, epilogue)}; +} + +static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate = + "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${LayoutA}, " + "${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, " + "${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, " + "${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, " + "${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, " + "${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, " + "${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, " + "${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, " + "${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, " + "${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, " + "${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, " + "${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, " + "${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, " + "${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, " + "${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, " + "${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, " + "${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, " + "${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, " + "${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, " + 
"${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " + "${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, " + "${CBlockTransferScalarPerVector_NWaveNPerXdl}, ${MaskOutUpperTriangle}>"; + +// use hardcoded instances from vector of operations to substitute values into instance template +Solution Operation_Xdl_CShuffle::ToSolution() const +{ + std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.gemm01_m_per_block) + "_" + + std::to_string(this->tile_desc.gemm0_n_per_block) + "_" + + std::to_string(this->tile_desc.gemm0_k_per_block) + "_" + + std::to_string(this->tile_desc.gemm1_n_per_block) + "_" + + std::to_string(this->tile_desc.gemm1_k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.b1k1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB0", ToString(this->B.layout)}, + {"LayoutB1", ToString(this->B1.layout)}, + {"LayoutC", ToString(this->C.layout)}, + {"ADataType", ToString(this->A.element)}, + {"B0DataType", ToString(this->B.element)}, + {"B1DataType", ToString(this->B1.element)}, + {"CDataType", ToString(this->C.element)}, + {"AccDataType", ToString(this->acc)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"AElementwiseOperation", this->a_elem_op}, + {"B0ElementwiseOperation", this->b_elem_op}, + {"Acc0ElementwiseOperation", this->acc_elem_op}, + {"B1ElementwiseOperation", this->b1_elem_op}, + {"CElementwiseOperation", this->c_elem_op}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", 
std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", std::to_string(this->tile_desc.block_size)}, + {"Gemm01MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)}, + {"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)}, + {"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)}, + {"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)}, + {"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"B1K1", std::to_string(this->tile_desc.b1k1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, + {"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)}, + {"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)}, + {"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + {"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"B0BlockTransferThreadClusterLengths_BK0_N_BK1", + this->b0_block_transfer.thread_cluster_length}, + {"B0BlockTransferThreadClusterArrangeOrder", + this->b0_block_transfer.thread_cluster_arrange_order}, + {"B0BlockTransferSrcAccessOrder", this->b0_block_transfer.src_access_order}, + {"B0BlockTransferSrcVectorDim", 
std::to_string(this->b0_block_transfer.src_vec_dim)}, + {"B0BlockTransferSrcScalarPerVector", + std::to_string(this->b0_block_transfer.src_scalar_per_vector)}, + {"B0BlockTransferDstScalarPerVector_BK1", + std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)}, + {"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)}, + {"B1BlockTransferThreadClusterLengths_BK0_N_BK1", + this->b1_block_transfer.thread_cluster_length}, + {"B1BlockTransferThreadClusterArrangeOrder", + this->b1_block_transfer.thread_cluster_arrange_order}, + {"B1BlockTransferSrcAccessOrder", this->b1_block_transfer.src_access_order}, + {"B1BlockTransferSrcVectorDim", std::to_string(this->b1_block_transfer.src_vec_dim)}, + {"B1BlockTransferSrcScalarPerVector", + std::to_string(this->b1_block_transfer.src_scalar_per_vector)}, + {"B1BlockTransferDstScalarPerVector_BK1", + std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)}, + {"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CBlockTransferScalarPerVector_NWaveNPerXdl", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + {"MaskOutUpperTriangle", std::to_string(this->mask_out_upper_triangle)}, + }; + + return Solution{InterpolateString(DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate, values), + std::move(values)}; +} + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp index fff75c1962..f4b61ee99a 100644 --- 
a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -62,6 +62,13 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) // accounts for all possible combinations of Row/Col major static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } + + +// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, + +// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + + // Hard-code tuning parameters in modularized fashion, string them together into a vector of // instances std::vector Operation_Xdl_CShuffle::CreateOperations( @@ -83,6 +90,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1}, { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1}, { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1}, +// Irregular tile + { 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1}, // clang-format on }; @@ -100,6 +109,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, // clang-format on }; @@ -109,15 +120,17 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( // ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| AddExtraM| // Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | // | | | | | | | + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, // clang-format on - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, }; std::vector b_block_descriptions_rowmajor = { @@ -134,6 +147,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, // clang-format on }; @@ -151,6 +166,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, // clang-format on }; @@ -167,6 +184,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { 1, 1}, { 1, 1}, { 1, 1}, + { 1, 1}, { 1, 1}, // clang-format on }; @@ -185,6 +203,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<1, 16, 1, 8>, 8}, { S<1, 
32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8}, +// Irregular tile + { S<1, 16, 1, 4>, 1}, // clang-format on }; @@ -199,33 +219,44 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( assert(tile_descriptions.size() == cshuffle_descriptions.size()); assert(tile_descriptions.size() == c_block_descriptions.size()); - // Put all values together into a single operation > store into the result vector - for(std::size_t i = 0; i < tile_descriptions.size(); i++) + const std::vector> scheduler_pipeline_descriptions = + { + {LoopScheduler::Default, PipelineVersion::v1}, + {LoopScheduler::Interwave, PipelineVersion::v1}, + {LoopScheduler::Default, PipelineVersion::v2}, + }; + for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions) { - Operation_Xdl_CShuffle x; - x.tile_desc = tile_descriptions[i]; - x.a_block_transfer = a_block_descriptions[i]; - x.b_block_transfer = b_block_descriptions[i]; - x.cshuffle = cshuffle_descriptions[i]; - x.c_block_transfer = c_block_descriptions[i]; - x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; - x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; - x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; - x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { - return TensorDesc{dt, ToLayout(trans)}; - }); - x.a_elem_op = prob.AElementOp; - x.b_elem_op = prob.BElementOp; - x.cde_elem_op = prob.CDEElementOp; - x.gemm_specialization = GetGemmSpec(prob.M, - prob.N, - prob.K, - x.tile_desc.m_per_block, - x.tile_desc.n_per_block, - x.tile_desc.k_per_block); - x.update_prologue(prologue); - x.update_epilogue(epilogue); - result.push_back(x); + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + 
x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; + x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { + return TensorDesc{dt, ToLayout(trans)}; + }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + x.tile_desc.m_per_block, + x.tile_desc.n_per_block, + x.tile_desc.k_per_block); + x.loop_scheduler = loop_scheduler; + x.pipeline_version = pipeline_version; + x.update_prologue(prologue); + x.update_epilogue(epilogue); + result.push_back(x); + } } return result; } @@ -263,7 +294,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, " "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " - "${CDEBlockTransferScalarPerVector_NPerBlock}>"; + "${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>"; // use hardcoded instances from vector of operations to substitute values into instance template Solution Operation_Xdl_CShuffle::ToSolution() const @@ -336,6 +367,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, {"CDEBlockTransferScalarPerVector_NPerBlock", std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + {"LoopScheduler", ToString(this->loop_scheduler)}, + {"PipelineVersion", ToString(this->pipeline_version)}, }; return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values), diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp index a8a8b10c04..4757cab536 100644 --- 
a/codegen/src/types.cpp +++ b/codegen/src/types.cpp @@ -56,6 +56,26 @@ std::string ToString(GemmType gt) throw std::runtime_error("Incorrect gemm type"); } +std::string ToString(LoopScheduler ls) +{ + switch(ls) + { + case LoopScheduler::Default: return "ck::LoopScheduler::Default"; + case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave"; + } + throw std::runtime_error("Incorrect LoopScheduler type"); +} + +std::string ToString(PipelineVersion pv) +{ + switch(pv) + { + case PipelineVersion::v1: return "ck::PipelineVersion::v1"; + case PipelineVersion::v2: return "ck::PipelineVersion::v2"; + } + throw std::runtime_error("Incorrect PipelineVersion type"); +} + std::string SequenceStr(const std::vector& v) { return "ck::Sequence<" + diff --git a/codegen/test/rtc/include/rtc/hip.hpp b/codegen/test/rtc/include/rtc/hip.hpp index 6b523382dc..e962d4cd3e 100644 --- a/codegen/test/rtc/include/rtc/hip.hpp +++ b/codegen/test/rtc/include/rtc/hip.hpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace rtc { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 9af1a44781..ce0583800c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -611,6 +611,96 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return true; } + static constexpr bool + IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_) + { + // check vector load/store + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v) + { + if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(MRaw_ % 
ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B + if constexpr(is_same_v) + { + if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B1 + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of C + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + + return true; + } + static bool IsSupportedArgument(const Argument& arg) { if(!ck::is_xdl_supported()) @@ -625,29 +715,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; - // Check scalar per vector requirement - const auto a_extent_lowest = - is_same_v ? KRaw : MRaw; - const auto b_extent_lowest = - is_same_v ? NRaw : KRaw; - const auto b1_extent_lowest = - is_same_v ? Gemm1NRaw : NRaw; - const auto c_extent_lowest = - is_same_v ? 
Gemm1NRaw : MRaw; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - return false; - } - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + arg.block_2_ctile_map_) and + IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); } // polymorphic @@ -765,6 +838,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return str.str(); } + + template + struct Descriptor + { + template + static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc) + { + const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc) + { + const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc) + { + const auto b1_grid_desc_n_k = 
DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc); + + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc) + { + return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc); + } + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using B1GridDesc_BK0_N_BK1 = + remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + 
B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + matrix_padder.PadN, + MaskOutUpperTriangle>; + + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1; + CGridDesc_M_N c_grid_desc_m_n; + C0MatrixMask c0_matrix_mask; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_descriptor_mblock_mperblock_nblock_nperblock; + + // element-wise op + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + B1ElementwiseOperation b1_element_op; + CElementwiseOperation c_element_op; + + bool has_main_k_block_loop = true; + bool is_valid = false; + + constexpr Descriptor(ADesc a, + BDesc b, + B1Desc b1, + CDesc c, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + B1ElementwiseOperation b1_element_op_, + CElementwiseOperation c_element_op_) + : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)}, + b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)}, + b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)}, + c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)}, + block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)}, + c_grid_descriptor_mblock_mperblock_nblock_nperblock{ + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n)}, + has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( + a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, + 
c0_matrix_mask{c.GetLength(I1)}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + b1_element_op{b1_element_op_}, + c_element_op{c_element_op_}, + is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_m_n, + block_2_ctile_map) and + IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), + b_grid_desc_bk0_n_bk1.GetLength(I1), + a_grid_desc_ak0_m_ak1.GetLength(I0) * + a_grid_desc_ak0_m_ak1.GetLength(I2), + b1_grid_desc_bk0_n_bk1.GetLength(I1))} + { + } + + constexpr bool IsValid() const { return is_valid; } + }; + + template + static constexpr auto + make_descriptor(ADesc a, + BDesc b, + B1Desc b1, + CDesc c, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{}, + CElementwiseOperation c_element_op = CElementwiseOperation{}) + { + return Descriptor( + a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op); + } + + template + __device__ static void Run(const Desc& desc, + const float scale, + const ADataType* __restrict__ p_a_grid, + const ADataType* __restrict__ p_b_grid, + const ADataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid) + { +#ifndef __HIPCC_RTC__ + assert(desc.is_valid); +#endif + __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()]; + AccElementwiseOperation acc_element_op{scale}; + + if(desc.has_main_k_block_loop) + { + Desc::GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + acc_element_op, + desc.b1_element_op, + desc.c_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.b1_grid_desc_bk0_n_bk1, + desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, + desc.block_2_ctile_map, + desc.c0_matrix_mask); + } + else + { + Desc::GridwiseGemm::template Run( + 
p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + acc_element_op, + desc.b1_element_op, + desc.c_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.b1_grid_desc_bk0_n_bk1, + desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, + desc.block_2_ctile_map, + desc.c0_matrix_mask); + } + } }; } // namespace device