Skip to content

Commit

Permalink
[WIP][Experimental] Kleidi add i8mm op level tests
Browse files Browse the repository at this point in the history
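Still debugging the i8mm GEMM tests that span multiple output tiles: in the log below, every i8mm GEMM case except the tiny one (m=2, n=2) fails, while all GEMV and dotprod cases pass.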

[----------] 59 tests from test_linear_8bit_act_xbit_weight
[ RUN      ] test_linear_8bit_act_xbit_weight.Standard
[       OK ] test_linear_8bit_act_xbit_weight.Standard (8 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.HasWeightZeros
[       OK ] test_linear_8bit_act_xbit_weight.HasWeightZeros (2 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.HasBias
[       OK ] test_linear_8bit_act_xbit_weight.HasBias (2 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.HasClamp
[       OK ] test_linear_8bit_act_xbit_weight.HasClamp (2 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.SmallDimension
[       OK ] test_linear_8bit_act_xbit_weight.SmallDimension (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KNotDivisibleByGroupSize
[       OK ] test_linear_8bit_act_xbit_weight.KNotDivisibleByGroupSize (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.GroupSizeNotDivisibleBy16
[       OK ] test_linear_8bit_act_xbit_weight.GroupSizeNotDivisibleBy16 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMV_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMV_dotprod_1x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMV_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMV_dotprod_1x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMV_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMV_dotprod_1x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMV_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMV_dotprod_1x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiWithBiasGEMV_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiWithBiasGEMV_dotprod_1x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMV_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMV_dotprod_1x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMM_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMM_dotprod_1x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMM_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMM_dotprod_1x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMM_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMM_dotprod_1x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMM_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMM_dotprod_1x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasBiasGEMM_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiHasBiasGEMM_dotprod_1x4x32 (2 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargeGEMM_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiLargeGEMM_dotprod_1x4x32 (205 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMM_dotprod_1x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMM_dotprod_1x4x32 (1 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMV_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMV_dotprod_1x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMV_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMV_dotprod_1x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMV_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMV_dotprod_1x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMV_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMV_dotprod_1x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiWithBiasGEMV_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiWithBiasGEMV_dotprod_1x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMV_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMV_dotprod_1x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMM_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMM_dotprod_1x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMM_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMM_dotprod_1x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMM_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMM_dotprod_1x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMM_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMM_dotprod_1x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasBiasGEMM_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiHasBiasGEMM_dotprod_1x8x32 (1 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargeGEMM_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiLargeGEMM_dotprod_1x8x32 (197 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMM_dotprod_1x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMM_dotprod_1x8x32 (1 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMV_i8mm_4x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMV_i8mm_4x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMV_i8mm_4x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMV_i8mm_4x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMV_i8mm_4x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMV_i8mm_4x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMV_i8mm_4x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMV_i8mm_4x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiWithBiasGEMV_i8mm_4x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiWithBiasGEMV_i8mm_4x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMV_i8mm_4x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMV_i8mm_4x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMM_i8mm_4x8x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMM_i8mm_4x8x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMM_i8mm_4x8x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMM_i8mm_4x8x32 (16 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMM_i8mm_4x8x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMM_i8mm_4x8x32 (55 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMM_i8mm_4x8x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMM_i8mm_4x8x32 (13 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasBiasGEMM_i8mm_4x8x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiHasBiasGEMM_i8mm_4x8x32 (16 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargeGEMM_i8mm_4x8x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiLargeGEMM_i8mm_4x8x32 (271 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMM_i8mm_4x8x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMM_i8mm_4x8x32 (45 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMV_i8mm_8x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMV_i8mm_8x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMV_i8mm_8x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMV_i8mm_8x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMV_i8mm_8x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMV_i8mm_8x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMV_i8mm_8x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMV_i8mm_8x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiWithBiasGEMV_i8mm_8x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiWithBiasGEMV_i8mm_8x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMV_i8mm_8x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMV_i8mm_8x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMM_i8mm_8x4x32
[       OK ] test_linear_8bit_act_xbit_weight.KleidiTinyGEMM_i8mm_8x4x32 (0 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMM_i8mm_8x4x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiSmallGEMM_i8mm_8x4x32 (9 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMM_i8mm_8x4x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiStandardGEMM_i8mm_8x4x32 (55 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMM_i8mm_8x4x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiHasClampGEMM_i8mm_8x4x32 (7 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiHasBiasGEMM_i8mm_8x4x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiHasBiasGEMM_i8mm_8x4x32 (16 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargeGEMM_i8mm_8x4x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiLargeGEMM_i8mm_8x4x32 (270 ms)
[ RUN      ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMM_i8mm_8x4x32
[  FAILED  ] test_linear_8bit_act_xbit_weight.KleidiLargerGroupGEMM_i8mm_8x4x32 (42 ms)
  • Loading branch information
digantdesai committed Nov 19, 2024
1 parent 357244c commit 8504177
Show file tree
Hide file tree
Showing 4 changed files with 869 additions and 41 deletions.
22 changes: 22 additions & 0 deletions torchao/experimental/ops/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,34 @@ if(TORCHAO_BUILD_KLEIDIAI)
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
endif()

# Define TORCHAO_ENABLE_ARM_I8MM for every target in this directory when the
# build was configured with TORCHAO_BUILD_ARM_I8MM; the generated i8mm tests
# are wrapped in `#if defined(TORCHAO_ENABLE_ARM_I8MM)`.
if(TORCHAO_BUILD_ARM_I8MM)
  add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM)
endif()

if(ANDROID_ABI)
  # We are cross compiling, delay test discovery till runtime
  set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST)
endif()

include_directories(${TORCHAO_INCLUDE_DIRS})

# Must be set before the kernels subdirectory is added below.
# NOTE(review): presumably consumed by the kernels/Utils build logic - confirm.
set(TORCHAO_PARALLEL_BACKEND "test_dummy")
add_subdirectory(
  "${TORCHAO_ROOT}/kernels/cpu/aarch64"
  "${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64"
)

include("${TORCHAO_ROOT}/Utils.cmake")

if(ANDROID_ABI)
  # Given where we are today this is sufficient. But needs to be revisited.
  # This is also needed for native builds, but keeping it only for cross builds
  # for now given the hacky nature.
  file(GLOB DOTPROD_SRC_FILES "${CMAKE_CURRENT_SOURCE_DIR}/test*.cpp")
  # Quote the expansion so the list is printed as one readable string instead
  # of being split into arguments that message() concatenates back to back.
  message(STATUS "SRC_FILES: ${DOTPROD_SRC_FILES}")
  set_property(
    SOURCE ${DOTPROD_SRC_FILES}
    APPEND_STRING
    PROPERTY COMPILE_FLAGS " -march=armv8.2-a+dotprod "
  )
endif()

add_executable(
test_linear_8bit_act_xbit_weight
test_linear_8bit_act_xbit_weight.cpp
Expand Down
41 changes: 39 additions & 2 deletions torchao/experimental/ops/tests/build_and_run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,57 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Usage: build_and_run_tests.sh [native|android]
#
# Configures and builds the op-level tests with CMake. For "native" (default)
# the resulting test binary is executed afterwards; for any other target the
# build is a cross compile and the run step is skipped.
target=${1:-"native"}
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)

# Single definition of the build directory. It must not be re-exported further
# down, otherwise the Android-specific directory chosen below would be lost.
export CMAKE_OUT=/tmp/cmake-out/torchao/tests

IS_ARM64=0
BUILD_ARM_I8MM=0
# Array (not a string) so that paths containing spaces survive expansion.
EXTRA_ARGS=()
if [[ "${target}" == "android" ]]; then
  if [[ -z "${ANDROID_NDK}" ]]; then
    echo "Need to set ANDROID_NDK env variable to build for Android"
    exit 1
  fi
  android_abi=arm64-v8a
  android_platform=28 # must be >=28 for aligned_alloc
  IS_ARM64=1
  BUILD_ARM_I8MM=1 # Hardcoded for now
  # Keep Android artifacts apart from native ones.
  CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android}
  toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake"
  # The path string is never empty, so test for the file itself.
  if [[ ! -f "${toolchain_file}" ]]; then
    echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}"
    exit 1
  fi
  EXTRA_ARGS=(
    "-DCMAKE_TOOLCHAIN_FILE=${toolchain_file}"
    "-DANDROID_ABI=${android_abi}"
    "-DANDROID_PLATFORM=${android_platform}"
  )
  echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}"
fi

# Enable the Kleidi build when the host itself reports arm64.
if command -v arch > /dev/null 2>&1 && [[ "$(arch)" == "arm64" ]]; then
  IS_ARM64=1
fi

# The source directory is the current working directory (-S .), so the script
# is expected to be invoked from the directory holding the tests' CMakeLists.txt.
cmake \
  "-DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES}" \
  "${EXTRA_ARGS[@]}" \
  -DCMAKE_BUILD_TYPE=Debug \
  "-DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64}" \
  "-DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM}" \
  -S . \
  -B "${CMAKE_OUT}" || { echo "CMake configure failed."; exit 1; }

cmake --build "${CMAKE_OUT}" || { echo "Failed to build tests."; exit 1; }

echo "Successfully built tests."

if [[ "${target}" != "native" ]]; then
  echo "Skip running tests when cross compiling."
  exit 0
fi

# Run
"${CMAKE_OUT}/test_linear_8bit_act_xbit_weight"
181 changes: 181 additions & 0 deletions torchao/experimental/ops/tests/generate_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Simple script to generate test cases for the torchao ops
from string import Template

# C++ source template for one Kleidi micro-kernel variant. It emits thirteen
# gtest cases (six GEMV cases with m=1, seven GEMM cases with m>1) that all call
# test_linear_8bit_act_xbit_weight with 4-bit weights and has_kleidi=true.
# Placeholders:
#   ${kernel}   - kernel variant name; used both in the test name suffix and as
#                 the template argument of get_ukernel_config_kleidi<>.
#   ${prologue} - text emitted before the tests (e.g. an opening #if guard).
#   ${epilogue} - text emitted after the tests (e.g. the matching #endif).
kleidi_template = Template("""
// ${kernel} tests
${prologue}
TEST(test_linear_8bit_act_xbit_weight, KleidiTinyGEMV_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/1, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiSmallGEMV_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/1, /*n=*/12, /*k=*/32, /*group_size=*/32, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiStandardGEMV_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/1, /*n=*/20, /*k=*/32, /*group_size=*/32, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiHasClampGEMV_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
true /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/1, /*n=*/10, /*k=*/32 * 2, /*group_size=*/32, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiWithBiasGEMV_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
true /*has_bias*/,
true /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/1, /*n=*/18, /*k=*/32 * 3, /*group_size=*/32, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiLargerGroupGEMV_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
true /*has_bias*/,
true /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/1, /*n=*/18, /*k=*/32 * 4, /*group_size=*/32 * 2, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiTinyGEMM_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/2, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiSmallGEMM_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/16, /*n=*/12, /*k=*/32, /*group_size=*/32, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiStandardGEMM_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/13, /*n=*/20, /*k=*/32, /*group_size=*/32, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiHasClampGEMM_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
true /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/17, /*n=*/10, /*k=*/32 * 2, /*group_size=*/32, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiHasBiasGEMM_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
true /*has_bias*/,
true /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/23, /*n=*/18, /*k=*/32 * 3, /*group_size=*/32, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiLargeGEMM_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
true /*has_bias*/,
true /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/231, /*n=*/188, /*k=*/32 * 13, /*group_size=*/32, &ukernel_config);
}
TEST(test_linear_8bit_act_xbit_weight, KleidiLargerGroupGEMM_${kernel}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
true /*has_bias*/,
true /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/16, /*n=*/18, /*k=*/32 * 4, /*group_size=*/32 * 2, &ukernel_config);
}
${epilogue}
""")

# Kleidi micro-kernel variants to generate tests for. Each name is substituted
# for ${kernel} in kleidi_template; names containing "i8mm" additionally get
# their tests wrapped in a TORCHAO_ENABLE_ARM_I8MM preprocessor guard.
kleidi_kernels = [
    "dotprod_1x4x32",
    "dotprod_1x8x32",
    "i8mm_4x8x32",
    "i8mm_8x4x32",
]

# Emit the generated C++ on stdout: a banner, then one instantiation of the
# template per kernel, all enclosed in the TORCHAO_ENABLE_KLEIDI guard.
print("/* Generated by generate_tests.py */")
print("/* Do not modify */")
print()
print("#if defined(TORCHAO_ENABLE_KLEIDI)")
for kernel in kleidi_kernels:
    prologue, epilogue = "", ""
    if "i8mm" in kernel:
        prologue = "#if defined(TORCHAO_ENABLE_ARM_I8MM)"
        epilogue = "#endif // TORCHAO_ENABLE_ARM_I8MM"

    d = {
        "prologue": prologue,
        "kernel": kernel,
        "epilogue": epilogue,
    }

    # substitute() rather than safe_substitute(): a missing or misspelled key
    # raises KeyError here instead of leaking a literal "${...}" placeholder
    # into the generated C++ where it would only fail at compile time.
    print(kleidi_template.substitute(d))
print("#endif // TORCHAO_ENABLE_KLEIDI")
Loading

0 comments on commit 8504177

Please sign in to comment.