ROCm Support : Tile_Layout kernel #1201

Open · wants to merge 22 commits into base: main
41 changes: 36 additions & 5 deletions setup.py
@@ -68,11 +68,14 @@ def use_debug_mode():
from torch.utils.cpp_extension import (
CUDA_HOME,
IS_WINDOWS,
ROCM_HOME,
BuildExtension,
CppExtension,
CUDAExtension,
)

IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)

# Constant known variables used throughout this file
cwd = os.path.abspath(os.path.curdir)
third_party_path = os.path.join(cwd, "third_party")
@@ -201,13 +204,18 @@ def get_extensions():
print(
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
)
if CUDA_HOME is None and torch.cuda.is_available():
print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")

if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
print(
"CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
)
print(
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
)

use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
use_cuda = torch.cuda.is_available() and (
CUDA_HOME is not None or ROCM_HOME is not None
)
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = []
@@ -226,7 +234,8 @@ def get_extensions():

if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["nvcc"].append("-g")
if "nvcc" in extra_compile_args:
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])
else:
extra_compile_args["cxx"] = ["/O2" if not debug_mode else "/Od", "/permissive-"]
@@ -258,9 +267,31 @@ def get_extensions():
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)

if use_cuda:
extensions_hip_dir = os.path.join(
extensions_dir, "cuda", "tensor_core_tiled_layout"
)
hip_sources = list(
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
)

if not IS_ROCM and use_cuda:
sources += cuda_sources

# TODO: Remove this and use what CUDA has once we fix all the builds.
if IS_ROCM and use_cuda:
# Add ROCm GPU architecture check
gpu_arch = torch.cuda.get_device_properties(0).gcnArchName
if "gfx942" not in gpu_arch:
print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
print(
"Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
)
return None
sources += hip_sources

if len(sources) == 0:
return None

ext_modules = []
if len(sources) > 0:
ext_modules.append(
torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
@@ -1,4 +1,4 @@
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
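
For reference (not part of this diff): the rewritten guard above compiles the kernels in three cases: ROCm 6.2+ builds (USE_ROCM with ROCM_VERSION >= 60200), the host-side compilation pass (__CUDA_ARCH__ undefined), and NVIDIA devices of compute capability 8.0 or newer (Ampere). Assuming ROCM_VERSION follows the usual major*10000 + minor*100 + patch encoding, ROCm 6.2.0 compares as 60200:

// Assumed ROCM_VERSION encoding: major*10000 + minor*100 + patch
static_assert(6 * 10000 + 2 * 100 + 0 == 60200, "ROCm 6.2.0 encodes as 60200");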
@@ -7,13 +7,24 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

#if defined(USE_ROCM)
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#endif

template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
const uint64_t blocks = a / b + (a % b != 0);
return blocks;
}

#if defined(USE_ROCM)
constexpr int32_t kWarpSize = 64;
#else
constexpr int32_t kWarpSize = 32;
#endif

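// Usage sketch, not part of this diff (hypothetical kernel and launcher): divUp above is
// the usual ceiling division used to size grids, and kWarpSize switches lane arithmetic
// between the 64-wide wavefronts of gfx942 (ROCm) and the 32-wide warps of NVIDIA GPUs.
__global__ void example_lane_kernel(int* laneOut, int n) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    laneOut[idx] = threadIdx.x % kWarpSize; // 0..63 on ROCm, 0..31 on CUDA
  }
}

inline void launch_example(int* laneOut, int n, cudaStream_t stream) {
  constexpr int kThreads = 128;
  const auto blocks = divUp(n, kThreads); // e.g. divUp(1000, 128) == 8
  example_lane_kernel<<<blocks, kThreads, 0, stream>>>(laneOut, n);
}
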
//Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization
//https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180
@@ -30,38 +41,71 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
uint32_t const source_i4s = source;

// First, we extract the i4s and construct an intermediate fp16 number.
#if !defined(USE_ROCM)
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
#endif
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

// We don't have enough mantissa to remove as much shift overhead as FP16, so
// we must loop. No shift needed for first item.
uint32_t i4s = source_i4s;
// AMD MI300X ISA: v_and_or_b32 performs two bitwise operations in a single instruction,
// H[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM
// - First ANDs `i4s` with `MASK` (0x000f000f) to extract the 4-bit values
// - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to convert them to bfloat16
#if defined(USE_ROCM)
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(h[0])
: "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
#else
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
#endif

#pragma unroll
for (int ii = 1; ii < kElements / 2; ++ii) {
i4s >>= 4; // or is it 8?
// (i4s & 0x000f000f) | 0x43004300
#if defined(USE_ROCM)
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(h[ii])
: "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
#else
asm volatile(
"lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[ii])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
#endif
}

// This is the BF16 {-136, -136} represented as an integer.
static constexpr uint32_t BF16_BIAS = 0xC308C308;
static constexpr uint32_t BF16_ONE = 0x3F803F80;
#if defined(USE_ROCM)
#if ROCM_VERSION >= 60200
auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
#else
auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16{0xC308});
auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
#endif
#else
static constexpr uint32_t BF16_SCALE_FACTOR = 0xC308C308;
static constexpr uint32_t BF16_UNIT_VALUE = 0x3F803F80;
#endif

// Finally, we construct the output numbers.
#pragma unroll
for (int ii = 0; ii < kElements / 2; ++ii) {
// Since this section is for Ampere+, we use bf16 fma to do the bias
// subtraction
#if defined(USE_ROCM)
result.vals[ii] = __hfma2(result.vals[ii], BF16_UNIT_VALUE, BF16_SCALE_FACTOR);
#else
asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
: "=r"(h[ii])
: "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
: "r"(h[ii]), "r"(BF16_UNIT_VALUE), "r"(BF16_SCALE_FACTOR));
#endif
}

return result;
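
A worked host-side example (illustrative only, not part of this diff) of what the inline assembly above computes: OR-ing a 4-bit nibble v into the low mantissa bits of the bf16 pattern 0x4300 (the value 128.0) yields 128 + v, and the later fma with the -136 bias brings it to the signed int4 value v - 8.

#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t MASK  = 0x000f000f;
  constexpr uint32_t MAGIC = 0x43004300;       // bf16 {128.0, 128.0}
  const uint32_t source = 0x00050003;          // two packed nibbles: 5 (upper half), 3 (lower half)
  const uint32_t h = (source & MASK) | MAGIC;  // 0x43054303, i.e. bf16 {133.0, 131.0}
  // fma(x, 1.0, -136.0) per lane: 133 - 136 = -3 and 131 - 136 = -5,
  // exactly the signed int4 values (5 - 8) and (3 - 8).
  printf("packed bf16 pair: 0x%08x\n", h);
  return 0;
}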
@@ -123,11 +167,22 @@ __global__ void _dequantize_int4_kernel(
// All b values within a 16x16 tile should fall within the same q group
// Hence we load 1 scale and zero per loop
int qgroup = ks[0] / groupSize;
#if defined(USE_ROCM)
__nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
__nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));

if (scales_and_zeros) {
const auto& sz = *scales_and_zeros;
const __nv_bfloat16* pSZ = reinterpret_cast<const __nv_bfloat16*>(&sz[qgroup][n0][0]);

scale2 = __bfloat162bfloat162(pSZ[0]);
zero2 = __bfloat162bfloat162(pSZ[1]);
}
#else
const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);

// Vectorize scales and zeros
__nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
__nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
#endif

#pragma unroll
for (int i = 0; i < 4; i++) {
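
For reference, a minimal sketch (hypothetical helper, not the kernel's actual code) of the broadcast-then-fma pattern used above for the per-group scale and zero, assuming the cuda_bf16/hip_bf16 intrinsics already included in this file:

#include <cuda_bf16.h>

// Hypothetical helper: duplicate one per-group scale and zero into both lanes of a
// __nv_bfloat162 so packed pairs of dequantized values can be rescaled with one __hfma2.
__device__ __nv_bfloat162 apply_scale_zero(__nv_bfloat162 v,
                                           __nv_bfloat16 scale,
                                           __nv_bfloat16 zero) {
  const __nv_bfloat162 scale2 = __bfloat162bfloat162(scale); // {scale, scale}
  const __nv_bfloat162 zero2  = __bfloat162bfloat162(zero);  // {zero, zero}
  return __hfma2(v, scale2, zero2);                          // per-lane v * scale + zero
}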