Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce gemm_softmax_gemm to codegen #1542

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"

namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {

// defines all values need for an instance of fwd conv
struct Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
static std::vector<std::vector<Operation_Xdl_CShuffle>>
CreateOperations(const std::string& prologue, const std::string& epilogue);
// returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{};
TensorDesc B{};
TensorDesc B1{};
TensorDesc C{};
DataType acc = DataType::Float;
DataType cs_type = DataType::Half;
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string b1_elem_op = PassThrough;
std::string c_elem_op = PassThrough;
std::string acc_elem_op = Scale;
std::string prologue = "";
std::string epilogue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
// tuning parameters
operation::TileDescGemmGemm tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b0_block_transfer{};
operation::BlockTransferDesc b1_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};

bool mask_out_upper_triangle = false;

// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue);
/**constexpr**/ bool
IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_);
// returns a templated instance
Solution ToSolution() const;
};

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"

namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {

// defines the problem specification for a GEMM operation
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
std::size_t O = 0;
bool TransA = false;
bool TransB = false;
bool TransB1 = false;
bool TransC = false;
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType B1DataType = DataType::Half;
DataType CDataType = DataType::Half;
std::string AElementOp = PassThrough;
std::string BElementOp = PassThrough;
std::string B1ElementOp = PassThrough;
std::string CElementOp = PassThrough;
std::string AccElementOp = Scale;

// returns the correct device op file for the operation
std::string GetIncludeHeader() const;

// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
};

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
operation::BlockTransferDesc b_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
LoopScheduler loop_scheduler{};
PipelineVersion pipeline_version{};

// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
Expand Down
20 changes: 20 additions & 0 deletions codegen/include/ck/host/operation/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@ struct TileDesc
int n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};

struct TileDescGemmGemm
{
int block_size = 0;
int gemm01_m_per_block = 0;
int gemm0_n_per_block = 0;
int gemm0_k_per_block = 0;
int gemm1_n_per_block = 0;
int gemm1_k_per_block = 0;
int ak1 = 0;
int bk1 = 0;
int b1k1 = 0;
int m_per_XDL = 0;
int n_per_XDL = 0;
int gemm0_m_Xdl_per_wave = 0;
int gemm0_n_Xdl_per_wave = 0;
int gemm1_n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};

struct BlockTransferDesc
{
std::string thread_cluster_length = "";
Expand Down
15 changes: 15 additions & 0 deletions codegen/include/ck/host/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,20 @@ enum class GemmType
};
std::string ToString(GemmType gt);

enum class LoopScheduler
{
Default,
Interwave,
};
std::string ToString(LoopScheduler ls);

enum class PipelineVersion
{
v1,
v2
};
std::string ToString(PipelineVersion pv);

struct TensorDesc
{
DataType element;
Expand All @@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});

constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale";

} // namespace host
} // namespace ck
38 changes: 38 additions & 0 deletions codegen/src/device_batched_gemm_softmax_gemm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>

namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {

// return the relevant device op file based on the operation
std::string Problem::GetIncludeHeader() const
{
return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp";
}

// returns templated instances when provided with a problem specification
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const
{
if(get_xdlop_archs().count(arch) == 0)
return {};
auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations(
*this, prologue, epilogue); // obtains vector of instances
std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution(); // template instance with correct values
});
return result;
}

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
Loading