From d6763013a7ea555c835a831ea0a40e0a0640370a Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Tue, 12 Nov 2024 13:39:50 -0600
Subject: [PATCH 1/5] Update git ignore

---
 .gitignore | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 5c201d1b34..5fa7064cbe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -371,4 +371,7 @@ venv/
 sweep/
 
 # Model checkpoints
-checkpoints/
\ No newline at end of file
+checkpoints/
+
+# Experimental
+torchao/experimental/cmake-out

From 45206f4f9d048e64d89e9ab9a97c009205170f6f Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Wed, 23 Oct 2024 21:19:33 -0500
Subject: [PATCH 2/5] [experimental] Add Kleidi compile def at the top level

---
 torchao/experimental/CMakeLists.txt | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt
index b641c07519..a90cc5884a 100644
--- a/torchao/experimental/CMakeLists.txt
+++ b/torchao/experimental/CMakeLists.txt
@@ -23,6 +23,10 @@
 endif()
 
 option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
+if(TORCHAO_BUILD_KLEIDIAI)
+  message(STATUS "Building with Arm KleidiAI library")
+  add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
+endif()
 
 include(CMakePrintHelpers)
 add_compile_options("-Wall" "-Werror" "-Wno-deprecated")
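For context on what the new compile definition is for: when `TORCHAO_BUILD_KLEIDIAI` is ON, `TORCHAO_ENABLE_KLEIDI=1` is defined project-wide, and downstream C++ gates its KleidiAI-specific paths on that macro (the tests added later in this series do exactly this). A minimal sketch of the guard pattern follows; the include path shown is one of the existing dotprod wrappers referenced in a later patch and is used purely for illustration:

```cpp
// Sketch only: gate KleidiAI-backed code on the definition added above.
// TORCHAO_ENABLE_KLEIDI=1 is set when CMake is configured with -DTORCHAO_BUILD_KLEIDIAI=ON.
#ifdef TORCHAO_ENABLE_KLEIDI
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>
// ... KleidiAI-backed wrappers are usable here ...
#endif // TORCHAO_ENABLE_KLEIDI
```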
From b339c12f2657ccc8d038185bff6ac8024f7fbcb4 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Wed, 13 Nov 2024 00:34:58 -0600
Subject: [PATCH 3/5] [Experimental] Add Kleidi i8mm gemm kernels

Add kernel level tests, with basic cross compilation support.
Tested with S24 + r26c

```
[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs_32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs_32 (0 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.large_k_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.large_k_n_gs32 (79 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.even_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.even_n_gs32 (28 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.clamp_k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.m_clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.m_clamp_k_eq_gs128 (5 ms)
[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm (121 ms total)

[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs_32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs_32 (0 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.large_k_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.large_k_n_gs32 (79 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.even_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.even_n_gs32 (28 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.clamp_k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.m_clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.m_clamp_k_eq_gs128 (5 ms)
[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm (121 ms total)
```
---
 .../kernels/cpu/aarch64/CMakeLists.txt        |   5 +-
 ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h | 120 +++++++++
 ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h | 122 +++++++++
 .../kernels/cpu/aarch64/tests/CMakeLists.txt  |  25 +-
 .../cpu/aarch64/tests/build_and_run_tests.sh  |  39 ++-
 .../kernels/cpu/aarch64/tests/test_linear.cpp | 236 +++++++++++++++++-
 6 files changed, 540 insertions(+), 7 deletions(-)
 create mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h
 create mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h

diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
index 6073425183..8751c38c81 100644
--- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
+++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
@@ -4,8 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-
-if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64"))
   add_library(
     torchao_kernels_aarch64
     ${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp
@@ -25,7 +24,7 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
 
     # Temporarily exposing this to the parent scope until we wire
     # this up properly from the top level
-    set(TORCHAO_ENABLE_KLEIDI ON PARENT_SCOPE)
+    set(TORCHAO_BUILD_KLEIDI ON PARENT_SCOPE)
     target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
   endif()
 endif()
diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h
new file mode 100644
index 0000000000..699a8ce4f9
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h
@@ -0,0 +1,120 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include
+#include
+
+namespace torchao::kernels::cpu::aarch64::kleidi {
+namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
+namespace neon_i8mm_8x4x32 {
+
+const Ukernel get_ukernel() {
+  return Ukernel{
+      .get_m_step =
+          kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_n_step =
+          kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_mr =
+          kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_nr =
+          kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_kr =
+          kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_sr =
+          kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_lhs_packed_offset =
+          kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_rhs_packed_offset =
+          kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_dst_offset =
+          kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_dst_size =
+          kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .run_matmul =
+          kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm};
+}
+
+size_t activation_data_size(int m, int k, int group_size) {
+  (void)group_size; // unused
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
+      get_ukernel(), m, k);
+}
+
+void prepare_activation_data(
+    void* activation_data,
+    int m,
+    int k,
+    int group_size,
+    const float* activations) {
+  (void)group_size; // unused
+  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
+      get_ukernel(), activation_data, m, k, activations);
+}
+
+size_t weight_data_size(int n, int k, int group_size) {
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
+      get_ukernel(), n, k, group_size);
+}
+
+void prepare_weight_data(
+    void* weight_data,
+    int n,
+    int k,
+    int group_size,
+    const int8_t* weight_qvals,
+    const float* weight_scales,
+    const int8_t* weight_zeros,
+    const float* bias) {
+  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
+      get_ukernel(),
+      weight_data,
+      n,
+      k,
+      group_size,
+      weight_qvals,
+      weight_scales,
+      weight_zeros,
+      bias);
+}
+
+void kernel(
+    float32_t* output,
+    int output_m_stride,
+    int m,
+    int n,
+    int k,
+    int group_size,
+    const void* weight_data,
+    const void* activation_data,
+    float clamp_min,
+    float clamp_max) {
+  if (clamp_min == 0 && clamp_max == 0) {
+    clamp_min = std::numeric_limits<float>::lowest();
+    clamp_max = std::numeric_limits<float>::max();
+  }
+
+  auto ukernel = get_ukernel();
+  ukernel.run_matmul(
+      m,
+      n,
+      k,
+      group_size,
+      activation_data,
+      weight_data,
+      output,
+      /*dst_stride_row=*/n * sizeof(float),
+      /*dst_stride_col=*/sizeof(float),
+      clamp_min,
+      clamp_max);
+}
+
+size_t get_preferred_alignement() {
+  return 16;
+}
+} // namespace neon_i8mm_8x4x32
+} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
+} // namespace torchao::kernels::cpu::aarch64::kleidi
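The wrapper above exposes the same prepare/pack/run surface as the existing dotprod wrappers: size the packed buffers, pack activations and weights, then call `kernel`. A condensed usage sketch follows; the byte-buffer element type, the function name `run_groupwise_lowbit_linear`, and the caller-supplied pointers are assumptions for illustration only (the tests added later in this patch show the exact calls):

```cpp
// Sketch of the intended call sequence for the neon_i8mm_8x4x32 wrapper above.
// Assumes the wrapper header from this patch has already been included.
#include <cstdint>
#include <vector>

using namespace torchao::kernels::cpu::aarch64::kleidi::
    kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32;

void run_groupwise_lowbit_linear(
    float* out, const float* activations,
    const int8_t* weight_qvals, const float* weight_scales,
    const int8_t* weight_zeros, const float* bias,
    int m, int n, int k, int group_size) {
  // Pack the dynamically quantized activations (LHS).
  std::vector<char> packed_lhs(activation_data_size(m, k, group_size));
  prepare_activation_data(packed_lhs.data(), m, k, group_size, activations);

  // Pack the 4-bit groupwise weights (RHS) together with scales, zeros, and bias.
  std::vector<char> packed_rhs(weight_data_size(n, k, group_size));
  prepare_weight_data(
      packed_rhs.data(), n, k, group_size,
      weight_qvals, weight_scales, weight_zeros, bias);

  // Run the i8mm matmul; a (0, 0) clamp pair means "no clamp" (see kernel() above).
  kernel(
      out, /*output_m_stride=*/n, m, n, k, group_size,
      packed_rhs.data(), packed_lhs.data(),
      /*clamp_min=*/0.0f, /*clamp_max=*/0.0f);
}
```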
diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h
new file mode 100644
index 0000000000..b61ce8bc3a
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h
@@ -0,0 +1,122 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include
+
+#include
+
+namespace torchao::kernels::cpu::aarch64::kleidi {
+namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
+namespace neon_i8mm_4x8x32 {
+
+const Ukernel get_ukernel() {
+  return Ukernel{
+      .get_m_step =
+          kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_n_step =
+          kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_mr =
+          kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_nr =
+          kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_kr =
+          kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_sr =
+          kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_lhs_packed_offset =
+          kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_rhs_packed_offset =
+          kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_dst_offset =
+          kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_dst_size =
+          kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .run_matmul =
+          kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm};
+}
+
+size_t activation_data_size(int m, int k, int group_size) {
+  (void)group_size; // unused
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
+      get_ukernel(), m, k);
+}
+
+void prepare_activation_data(
+    void* activation_data,
+    int m,
+    int k,
+    int group_size,
+    const float* activations) {
+  (void)group_size; // unused
+  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
+      get_ukernel(), activation_data, m, k, activations);
+}
+
+size_t weight_data_size(int n, int k, int group_size) {
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
+      get_ukernel(), n, k, group_size);
+}
+
+void prepare_weight_data(
+    void* weight_data,
+    int n,
+    int k,
+    int group_size,
+    const int8_t* weight_qvals,
+    const float* weight_scales,
+    const int8_t* weight_zeros,
+    const float* bias) {
+  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
+      get_ukernel(),
+      weight_data,
+      n,
+      k,
+      group_size,
+      weight_qvals,
+      weight_scales,
+      weight_zeros,
+      bias);
+}
+
+void kernel(
+    float32_t* output,
+    int output_m_stride,
+    int m,
+    int n,
+    int k,
+    int group_size,
+    const void* weight_data,
+    const void* activation_data,
+    float clamp_min,
+    float clamp_max) {
+  if (clamp_min == 0 && clamp_max == 0) {
+    clamp_min = std::numeric_limits<float>::lowest();
+    clamp_max = std::numeric_limits<float>::max();
+  }
+
+  auto ukernel = get_ukernel();
+  ukernel.run_matmul(
+      m,
+      n,
+      k,
+      group_size,
+      activation_data,
+      weight_data,
+      output,
+      /*dst_stride_row=*/n * sizeof(float),
+      /*dst_stride_col=*/sizeof(float),
+      clamp_min,
+      clamp_max);
+}
+
+size_t get_preferred_alignement() {
+  return 16;
+}
+
+} // namespace neon_i8mm_4x8x32
+} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
+} // namespace torchao::kernels::cpu::aarch64::kleidi
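Both wrappers above treat a (0, 0) clamp pair as a sentinel for "no clamp" and widen it to the full float range before dispatching to the ukernel, which also means a caller cannot request an actual [0, 0] clamp through this interface. A minimal restatement of that check, for reference (`resolve_clamp` is a hypothetical helper name used only for this illustration):

```cpp
#include <limits>

// Mirrors the sentinel handling at the top of kernel() in both wrappers above:
// clamp_min == 0 && clamp_max == 0 means "no clamping requested".
inline void resolve_clamp(float& clamp_min, float& clamp_max) {
  if (clamp_min == 0 && clamp_max == 0) {
    clamp_min = std::numeric_limits<float>::lowest();
    clamp_max = std::numeric_limits<float>::max();
  }
}
```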
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
index 3712a36250..e4cafdc97a 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
@@ -15,6 +15,11 @@ FetchContent_Declare(
 )
 FetchContent_MakeAvailable(googletest)
 
+if (ANDROID_ABI)
+  # We are cross compiling, delay test discovery until runtime
+  set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST)
+endif()
+
 add_compile_options("-Wall" "-Werror")
 
 include(CMakePrintHelpers)
@@ -35,13 +40,29 @@
 endif()
 add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)
 
-# The TORCHAO_ENABLE_KLEIDI cmake variable should be set by `torchao_kernels_aarch64"
-if(TORCHAO_ENABLE_KLEIDI)
+# The TORCHAO_BUILD_KLEIDI cmake variable should be set by `torchao_kernels_aarch64"
+if(TORCHAO_BUILD_KLEIDI)
   add_compile_definitions(TORCHAO_ENABLE_KLEIDI)
 endif()
 
+if(TORCHAO_BUILD_ARM_I8MM)
+  add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM)
+endif()
+
 enable_testing()
 
+if (ANDROID_ABI)
+  # Given where we are today this is sufficient. But needs to be revisited.
+  # This is also needed for native builds, but keeping it only for cross builds
+  # for now given the hacky nature.
+  file(GLOB DOTPROD_SRC_FILES test*.cpp)
+  message(SRC_FILES: ${DOTPROD_SRC_FILES})
+  set_property(SOURCE
+    ${DOTPROD_SRC_FILES}
+    APPEND_STRING PROPERTY
+    COMPILE_FLAGS " -march=armv8.2-a+dotprod ")
+endif()
+
 add_executable(test_quantization test_quantization.cpp)
 target_link_libraries(
   test_quantization
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
index 4394b02ece..767dca96ff 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
@@ -10,20 +10,57 @@ SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
 export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
 export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests
 
+target=${1:-"native"}
+
 IS_ARM64=0
+BUILD_ARM_I8MM=0
+EXTRA_ARGS=""
+if [[ "${target}" == "android" ]]; then
+    if [[ -z ${ANDROID_NDK} ]]; then
+        echo "Need to set ANDROID_NDK env variable to build for Android";
+        exit 1;
+    fi
+    android_abi=arm64-v8a
+    android_platform=28 # must be >=28 for aligned_alloc
+    IS_ARM64=1
+    BUILD_ARM_I8MM=1 # Hardcoded for now
+    CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android}
+    toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake"
+    if [[ -z ${toolchain_file} ]]; then
+        echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}"
+        exit 1;
+    fi
+    EXTRA_ARGS="\
+        -DCMAKE_TOOLCHAIN_FILE=${toolchain_file} \
+        -DANDROID_ABI=${android_abi} \
+        -DANDROID_PLATFORM=${android_platform}
+    "
+    echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}"
+fi
+
 hash arch; retval=$?
 if [[ ${retval} -eq 0 && $(arch) == "arm64" ]]; then
     IS_ARM64=1
 fi
 
-cmake -DCMAKE_BUILD_TYPE=Debug \
+cmake \
+    ${EXTRA_ARGS} \
+    -DCMAKE_BUILD_TYPE=Debug \
    -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
    -DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \
+    -DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \
    -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests \
    -B ${CMAKE_OUT}
 
 cmake --build ${CMAKE_OUT}
+echo "Successfully built tests."
+
+if [[ "${target}" != "native" ]]; then
+    echo "Skip running tests when cross compiling.";
+    exit 0;
+fi
+
 # Run
 ${CMAKE_OUT}/test_quantization
 ${CMAKE_OUT}/test_reduction
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
index b28e3bfdc4..f68106c7e8 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
@@ -17,7 +17,11 @@
 #ifdef TORCHAO_ENABLE_KLEIDI
 #include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>
 #include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
-#endif
+#ifdef TORCHAO_ENABLE_ARM_I8MM
+#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h>
+#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h>
+#endif // TORCHAO_ENABLE_ARM_I8MM
+#endif // TORCHAO_ENABLE_KLEIDI
 
 float kTol = 0.0001;
 
@@ -587,5 +591,235 @@ TEST(
     true /*has_clamp*/>(
     /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128);
 }
+
+#ifdef TORCHAO_ENABLE_ARM_I8MM
+template <bool has_bias, bool has_clamp>
+void test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(
+    int m,
+    int k,
+    int n,
+    int group_size) {
+  auto test_case = torchao::
+      channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate(
+          m,
+          k,
+          n,
+          group_size,
+          /*weight_nbit=*/4,
+          /*has_weight_zeros=*/false,
+          has_bias,
+          has_clamp,
+          /*round_weight_scales_to_bf16=*/true);
+
+  using namespace torchao::kernels::cpu::aarch64::kleidi::
+      kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32;
+
+  std::vector<char> activation_data(activation_data_size(m, k, group_size));
+
+  prepare_activation_data(
+      (void*)activation_data.data(),
+      m,
+      k,
+      group_size,
+      test_case.activations.data());
+
+  std::vector<char> weight_data(weight_data_size(n, k, group_size));
+
+  prepare_weight_data(
+      (void*)weight_data.data(),
+      n,
+      k,
+      group_size,
+      test_case.weight_qvals.data(),
+      test_case.weight_scales.data(),
+      /*weight_zeros=*/test_case.weight_zeros.data(),
+      /*bias=*/test_case.bias.data());
+
+  std::vector<float> output(m * n);
+  kernel(
+      output.data(),
+      /*output_m_stride=*/n,
+      m,
+      n,
+      k,
+      group_size,
+      weight_data.data(),
+      activation_data.data(),
+      /*clamp_min=*/test_case.clamp_min,
+      /*clamp_max=*/test_case.clamp_max);
+
+  for (int i = 0; i < m * n; i++) {
+    EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
+  }
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+    k_eq_gs_32) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm<
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+    large_k_n_gs32) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm<
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+    even_n_gs32) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm<
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+    k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm<
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+    clamp_k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm<
+      false /*has_bias*/,
+      true /*has_clamp*/>(
+      /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+    m_clamp_k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm<
+      false /*has_bias*/,
+      true /*has_clamp*/>(
+      /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
+
+template <bool has_bias, bool has_clamp>
+void test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
+    int m,
+    int k,
+    int n,
+    int group_size) {
+  auto test_case = torchao::
+      channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate(
+          m,
+          k,
+          n,
+          group_size,
+          /*weight_nbit=*/4,
+          /*has_weight_zeros=*/false,
+          has_bias,
+          has_clamp,
+          /*round_weight_scales_to_bf16=*/true);
+
+  using namespace torchao::kernels::cpu::aarch64::kleidi::
+      kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32;
+
+  std::vector<char> activation_data(activation_data_size(m, k, group_size));
+
+  prepare_activation_data(
+      (void*)activation_data.data(),
+      m,
+      k,
+      group_size,
+      test_case.activations.data());
+
+  std::vector<char> weight_data(weight_data_size(n, k, group_size));
+
+  prepare_weight_data(
+      (void*)weight_data.data(),
+      n,
+      k,
+      group_size,
+      test_case.weight_qvals.data(),
+      test_case.weight_scales.data(),
+      /*weight_zeros=*/test_case.weight_zeros.data(),
+      /*bias=*/test_case.bias.data());
+
+  std::vector<float> output(m * n);
+  kernel(
+      output.data(),
+      /*output_m_stride=*/n,
+      m,
+      n,
+      k,
+      group_size,
+      weight_data.data(),
+      activation_data.data(),
+      /*clamp_min=*/test_case.clamp_min,
+      /*clamp_max=*/test_case.clamp_max);
+
+  for (int i = 0; i < m * n; i++) {
+    EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
+  }
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+    k_eq_gs_32) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm<
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+    large_k_n_gs32) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm<
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+    even_n_gs32) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm<
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+    k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm<
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+    clamp_k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm<
+      false /*has_bias*/,
+      true /*has_clamp*/>(
+      /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+    m_clamp_k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm<
+      false /*has_bias*/,
+      true /*has_clamp*/>(
+      /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
+#endif // TORCHAO_ENABLE_ARM_I8MM
 #endif // TORCHAO_ENABLE_KLEIDI
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
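The two new wrappers differ only in their i8mm tile shape (8x4x32 vs. 4x8x32) and are selected by namespace; both expose the same free functions, so a namespace alias is one way a caller could switch between them. The sketch below is hypothetical and not part of the patch:

```cpp
// Hypothetical dispatch between the two i8mm tile shapes added in this patch.
#if defined(TORCHAO_ENABLE_KLEIDI) && defined(TORCHAO_ENABLE_ARM_I8MM)
namespace kleidi_i8mm = torchao::kernels::cpu::aarch64::kleidi::
    kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32; // or neon_i8mm_4x8x32

// Example: size the packed-weight buffer through whichever variant was aliased.
size_t packed_weight_bytes(int n, int k, int group_size) {
  return kleidi_i8mm::weight_data_size(n, k, group_size);
}
#endif
```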
From 357244c732ce7faa5dd05651129cc07dd6d64203 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Mon, 18 Nov 2024 13:32:19 -0600
Subject: [PATCH 4/5] [Experimental] Kleidi: rename arg name for packing functions

---
 ...clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 8 ++++----
 ...clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 8 ++++----
 ...ul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h | 8 ++++----
 ...ul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h | 8 ++++----
 4 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
index cdac5829ec..dbda036efd 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
@@ -47,14 +47,14 @@ size_t activation_data_size(int m, int k, int group_size) {
 }
 
 void prepare_activation_data(
-    void* activation_data,
+    void* prepared_activation_data,
     int m,
     int k,
     int group_size,
     const float* activations) {
   (void)group_size; // unused
   kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
-      get_ukernel(), activation_data, m, k, activations);
+      get_ukernel(), prepared_activation_data, m, k, activations);
 }
 
 size_t weight_data_size(int n, int k, int group_size) {
@@ -63,7 +63,7 @@ size_t weight_data_size(int n, int k, int group_size) {
 }
 
 void prepare_weight_data(
-    void* weight_data,
+    void* prepared_weight_data,
     int n,
     int k,
     int group_size,
@@ -73,7 +73,7 @@ void prepare_weight_data(
     const float* bias) {
   kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
       get_ukernel(),
-      weight_data,
+      prepared_weight_data,
       n,
       k,
       group_size,

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
index a739dc4c8b..d3d7bd55d9 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
@@ -45,7 +45,7 @@ size_t activation_data_size(int m, int k, int group_size) {
 }
 
 void prepare_activation_data(
-    void* activation_data,
+    void* prepared_activation_data,
     int m,
     int k,
     int group_size,
@@ -53,7 +53,7 @@ void prepare_activation_data(
   (void) group_size; // unused
   kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
       get_ukernel(),
-      activation_data,
+      prepared_activation_data,
       m,
       k,
       activations);
@@ -64,7 +64,7 @@ size_t weight_data_size(int n, int k, int group_size) {
 }
 
 void prepare_weight_data(
-    void* weight_data,
+    void* prepared_weight_data,
     int n,
     int k,
     int group_size,
@@ -74,7 +74,7 @@ void prepare_weight_data(
     const float* bias) {
   kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
       get_ukernel(),
-      weight_data,
+      prepared_weight_data,
       n,
       k,
       group_size,
diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h
index 699a8ce4f9..4ef499d72c 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h
@@ -45,14 +45,14 @@ size_t activation_data_size(int m, int k, int group_size) {
 }
 
 void prepare_activation_data(
-    void* activation_data,
+    void* prepared_activation_data,
     int m,
     int k,
     int group_size,
     const float* activations) {
   (void)group_size; // unused
   kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
-      get_ukernel(), activation_data, m, k, activations);
+      get_ukernel(), prepared_activation_data, m, k, activations);
 }
 
 size_t weight_data_size(int n, int k, int group_size) {
@@ -61,7 +61,7 @@ size_t weight_data_size(int n, int k, int group_size) {
 }
 
 void prepare_weight_data(
-    void* weight_data,
+    void* prepared_weight_data,
     int n,
     int k,
     int group_size,
@@ -71,7 +71,7 @@ void prepare_weight_data(
     const float* bias) {
   kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
       get_ukernel(),
-      weight_data,
+      prepared_weight_data,
       n,
       k,
      group_size,

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h
index b61ce8bc3a..d898cf3e5b 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h
@@ -46,14 +46,14 @@ size_t activation_data_size(int m, int k, int group_size) {
 }
 
 void prepare_activation_data(
-    void* activation_data,
+    void* prepared_activation_data,
     int m,
     int k,
     int group_size,
     const float* activations) {
   (void)group_size; // unused
   kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
-      get_ukernel(), activation_data, m, k, activations);
+      get_ukernel(), prepared_activation_data, m, k, activations);
 }
 
 size_t weight_data_size(int n, int k, int group_size) {
@@ -62,7 +62,7 @@ size_t weight_data_size(int n, int k, int group_size) {
 }
 
 void prepare_weight_data(
-    void* weight_data,
+    void* prepared_weight_data,
     int n,
     int k,
     int group_size,
@@ -72,7 +72,7 @@ void prepare_weight_data(
     const float* bias) {
   kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
       get_ukernel(),
-      weight_data,
+      prepared_weight_data,
       n,
       k,
       group_size,

From c0ce31169947b4d09e2b841df23de029a5b4ba84 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Tue, 19 Nov 2024 22:49:53 -0600
Subject: [PATCH 5/5] [Experimental] Change kernel cmake_out dir to avoid conflict

---
 .../kernels/cpu/aarch64/tests/build_and_run_tests.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
index 767dca96ff..5c12d7184e 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
@@ -8,7 +8,7 @@ set -eu
 SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
 
 export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
-export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests
+export CMAKE_OUT=/tmp/cmake-out/torch_ao/kernel_tests
 
 target=${1:-"native"}