Introduce gemm_softmax_gemm to codegen #1542

Open · wants to merge 9 commits into base: develop
28 changes: 28 additions & 0 deletions CMakeLists.txt
@@ -125,6 +125,8 @@ rocm_setup_version(VERSION ${version})

list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}")

option(CK_BUILD_HOST_LIB "Only build the CK JIT Helper Library" OFF)

message("GPU_TARGETS= ${GPU_TARGETS}")
message("GPU_ARCHS= ${GPU_ARCHS}")
if(GPU_ARCHS)
@@ -137,6 +139,7 @@ if(GPU_TARGETS)
else()
set(USER_GPU_TARGETS 0)
endif()

find_package(hip)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
@@ -246,6 +249,7 @@ elseif(CK_PARALLEL_COMPILE_JOBS)
message(WARNING "Job pooling is only available with Ninja generators.")
endif()

if (NOT CK_BUILD_HOST_LIB)

option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF)
@@ -267,6 +271,8 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
link_libraries(Threads::Threads)

endif() # NOT CK_BUILD_HOST_LIB

## C++
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
@@ -283,6 +289,8 @@ if(USE_GLIBCXX_ASSERTIONS)
add_compile_options(-Wp,-D_GLIBCXX_ASSERTIONS)
endif()

if (NOT CK_BUILD_HOST_LIB)

## HIP
set(CMAKE_HIP_PLATFORM amd)
set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
@@ -338,6 +346,8 @@ else()
add_compile_definitions(__HIP_PLATFORM_HCC__=1)
endif()

endif() # NOT CK_BUILD_HOST_LIB

## tidy
include(EnableCompilerWarnings)
set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
@@ -491,13 +501,17 @@ include_directories(BEFORE
${HIP_INCLUDE_DIRS}
)

if (NOT CK_BUILD_HOST_LIB)

SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")

endif() # NOT CK_BUILD_HOST_LIB

if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
add_compile_options(-fcolor-diagnostics)
endif()
@@ -507,6 +521,8 @@ endif()

add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})

if (NOT CK_BUILD_HOST_LIB)

file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
set(CK_DEVICE_INSTANCES)
@@ -575,6 +591,18 @@ if(CK_USE_CODEGEN AND (GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
add_subdirectory(codegen)
endif()

else() # NOT CK_BUILD_HOST_LIB

if(GPU_TARGETS MATCHES "gfx9")
rocm_package_setup_component(ck_host
LIBRARY_NAME composablekernel
PACKAGE_NAME ck_host
)
add_subdirectory(codegen)
endif()

endif() # NOT CK_BUILD_HOST_LIB

#Create an interface target for the include only files and call it "composablekernels"
include(CMakePackageConfigHelpers)

2 changes: 1 addition & 1 deletion Config.cmake.in
@@ -1,6 +1,6 @@
@PACKAGE_INIT@

set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility)
set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility ck_host)

foreach(_comp ${composable_kernel_FIND_COMPONENTS})
if(NOT _comp IN_LIST _composable_kernel_supported_components)
2 changes: 1 addition & 1 deletion codegen/CMakeLists.txt
@@ -52,7 +52,7 @@ rocm_export_targets(
NAMESPACE composable_kernel::
)

if(BUILD_TESTING)
if(BUILD_TESTING AND NOT CK_BUILD_HOST_LIB)
add_subdirectory(test)
endif()

61 changes: 61 additions & 0 deletions codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
@@ -0,0 +1,61 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"

namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {

// defines all values needed for an instance of batched gemm + softmax + gemm
struct Operation_Xdl_CShuffle
{
// returns a vector of instances given only the fusion operators: uses the default problem spec
static std::vector<std::vector<Operation_Xdl_CShuffle>>
CreateOperations(const std::string& prologue, const std::string& epilogue);
// returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{};
TensorDesc B{};
TensorDesc B1{};
TensorDesc C{};
DataType acc = DataType::Float;
DataType cs_type = DataType::Half;
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string b1_elem_op = PassThrough;
std::string c_elem_op = PassThrough;
std::string acc_elem_op = Scale;
std::string prologue = "";
std::string epilogue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
// tuning parameters
operation::TileDescGemmGemm tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b0_block_transfer{};
operation::BlockTransferDesc b1_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};

bool mask_out_upper_triangle = false;

// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue);
/*constexpr*/ bool
IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_);
// returns a templated instance
Solution ToSolution() const;
};

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
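
For orientation, a minimal usage sketch of this header. Assumptions: empty prologue/epilogue strings stand in for "no extra fusion"; everything else is the API declared above.

```cpp
// Sketch only: enumerate default-problem instances and render each one.
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"

using ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle;

std::vector<ck::host::Solution> enumerate_default_instances()
{
    std::vector<ck::host::Solution> solutions;
    // CreateOperations with only fusion operators uses the default problem
    // spec; each inner vector is one family of tuning configurations.
    for(const auto& family : Operation_Xdl_CShuffle::CreateOperations("", ""))
        for(const auto& op : family)
            solutions.push_back(op.ToSolution());
    return solutions;
}
```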
47 changes: 47 additions & 0 deletions codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
@@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"

namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {

// defines the problem specification for a batched gemm + softmax + gemm operation
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
std::size_t O = 0;
bool TransA = false;
bool TransB = false;
bool TransB1 = false;
bool TransC = false;
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType B1DataType = DataType::Half;
DataType CDataType = DataType::Half;
std::string AElementOp = PassThrough;
std::string BElementOp = PassThrough;
std::string B1ElementOp = PassThrough;
std::string CElementOp = PassThrough;
std::string AccElementOp = Scale;

// returns the correct device op file for the operation
std::string GetIncludeHeader() const;

// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
};

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
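
A hedged sketch of driving the `Problem` interface end to end; the shapes and the arch string are illustrative assumptions, not values from this PR.

```cpp
// Sketch only: build a problem spec and ask for matching instances.
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"

std::vector<ck::host::Solution> solutions_for_attention_shape()
{
    ck::host::device_batched_gemm_softmax_gemm::Problem prob;
    prob.M = 1024; // rows of gemm0 (illustrative)
    prob.N = 1024; // cols of gemm0 / softmax axis (illustrative)
    prob.K = 64;   // reduction dim of gemm0 (illustrative)
    prob.O = 64;   // output cols of gemm1 (illustrative)

    // GetIncludeHeader() names the device op header each rendered
    // instance must compile against.
    const std::string header = prob.GetIncludeHeader();
    (void)header;

    // Empty prologue/epilogue: no fusion beyond the defaults.
    return prob.GetSolutions("gfx90a", /*prologue=*/"", /*epilogue=*/"");
}
```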
@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
operation::BlockTransferDesc b_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
LoopScheduler loop_scheduler{};
PipelineVersion pipeline_version{};

// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
20 changes: 20 additions & 0 deletions codegen/include/ck/host/operation/gemm.hpp
@@ -23,6 +23,26 @@ struct TileDesc
int n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};

struct TileDescGemmGemm
{
int block_size = 0;
int gemm01_m_per_block = 0;
int gemm0_n_per_block = 0;
int gemm0_k_per_block = 0;
int gemm1_n_per_block = 0;
int gemm1_k_per_block = 0;
int ak1 = 0;
int bk1 = 0;
int b1k1 = 0;
int m_per_XDL = 0;
int n_per_XDL = 0;
int gemm0_m_Xdl_per_wave = 0;
int gemm0_n_Xdl_per_wave = 0;
int gemm1_n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};

struct BlockTransferDesc
{
std::string thread_cluster_length = "";
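
To make the two-GEMM split in `TileDescGemmGemm` concrete, here is one hypothetical configuration; every number below is illustrative only (real values come from `CreateOperations`).

```cpp
// Sketch only: one hypothetical gemm_softmax_gemm tile configuration.
// All numbers are illustrative, not instances shipped by this PR.
#include "ck/host/operation/gemm.hpp"

ck::host::operation::TileDescGemmGemm example_tile()
{
    ck::host::operation::TileDescGemmGemm t{};
    t.block_size               = 256;
    t.gemm01_m_per_block       = 128; // M tile shared by gemm0 and gemm1
    t.gemm0_n_per_block        = 128; // softmax-axis tile of gemm0
    t.gemm0_k_per_block        = 32;
    t.gemm1_n_per_block        = 64;  // output-width tile of gemm1
    t.gemm1_k_per_block        = 32;
    t.ak1                      = 8;
    t.bk1                      = 8;
    t.b1k1                     = 2;
    t.m_per_XDL                = 32;
    t.n_per_XDL                = 32;
    t.gemm0_m_Xdl_per_wave     = 1;
    t.gemm0_n_Xdl_per_wave     = 4;
    t.gemm1_n_Xdl_per_wave     = 2;
    t.num_gemmk_prefetch_stage = 1;
    return t;
}
```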
15 changes: 15 additions & 0 deletions codegen/include/ck/host/types.hpp
@@ -66,6 +66,20 @@ enum class GemmType
};
std::string ToString(GemmType gt);

enum class LoopScheduler
{
Default,
Interwave,
};
std::string ToString(LoopScheduler ls);

enum class PipelineVersion
{
v1,
v2
};
std::string ToString(PipelineVersion pv);

struct TensorDesc
{
DataType element;
@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});

constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale";

} // namespace host
} // namespace ck
38 changes: 38 additions & 0 deletions codegen/src/device_batched_gemm_softmax_gemm.cpp
@@ -0,0 +1,38 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>

namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {

// returns the relevant device op header based on the operation
std::string Problem::GetIncludeHeader() const
{
return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp";
}

// returns templated instances when provided with a problem specification
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const
{
if(get_xdlop_archs().count(arch) == 0)
return {};
auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations(
*this, prologue, epilogue); // obtains vector of instances
std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution(); // template instance with correct values
});
return result;
}

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
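
A short illustration of the arch gate above: "gfx90a" is assumed to be in `get_xdlop_archs()` (consistent with the gfx9 checks in CMakeLists.txt), while "gfx1030" is assumed not to be.

```cpp
// Sketch only: GetSolutions short-circuits to an empty vector on
// unsupported architectures before enumerating any instances.
void demo_arch_gate()
{
    ck::host::device_batched_gemm_softmax_gemm::Problem prob; // default spec
    auto some = prob.GetSolutions("gfx90a", "", "");  // assumed XDL arch: instances returned
    auto none = prob.GetSolutions("gfx1030", "", ""); // assumed non-XDL arch: empty
    (void)some;
    (void)none;
}
```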