diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h index 39045c6dc8d..0f67eaedd5e 100644 --- a/src/layer/x86/gemm_int8.h +++ b/src/layer/x86/gemm_int8.h @@ -2371,13 +2371,15 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; __m512 _scales = _mm512_loadu_ps((const float*)scales + i + ii); +#if __AVX512VNNI__ + __m512i _v127 = _mm512_set1_epi8(127); +#endif // __AVX512VNNI__ if (elempack == 16) { int kk = 0; #if __AVX512VNNI__ __m512i _w_shift = _mm512_setzero_epi32(); - __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 3 < max_kk; kk += 4) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -2424,11 +2426,11 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i __m128i _pp1 = float2int8_avx512(_p1); // transpose16x2_epi8 - __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); - __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); - _mm_storeu_si128((__m128i*)pp, _tt0); - _mm_storeu_si128((__m128i*)(pp + 16), _tt1); + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 16), _t1); pp += 32; p0 += 32; @@ -2452,7 +2454,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int kk = 0; #if __AVX512VNNI__ __m512i _w_shift = _mm512_setzero_epi32(); - __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 3 < max_kk; kk += 4) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -2537,7 +2538,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int kk = 0; #if __AVX512VNNI__ __m512i _w_shift = _mm512_setzero_epi32(); - __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 3 < max_kk; kk += 4) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -2634,7 +2634,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int kk = 0; #if __AVX512VNNI__ __m512i _w_shift = _mm512_setzero_epi32(); - __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 3 < max_kk; kk += 4) { __m128 _p0 = _mm_loadu_ps(p0); @@ -4295,13 +4294,15 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; __m512 _scales = _mm512_loadu_ps((const float*)scales + i + ii); +#if __AVX512VNNI__ + __m512i _v127 = _mm512_set1_epi8(127); +#endif if (elempack == 16) { int kk = 0; #if __AVX512VNNI__ __m512i _w_shift = _mm512_setzero_epi32(); - __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 15 < max_kk; kk += 16) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -4490,7 +4491,6 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int int kk = 0; #if __AVX512VNNI__ __m512i _w_shift = _mm512_setzero_epi32(); - __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 7 < max_kk; kk += 8) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -4603,7 +4603,6 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int int kk = 0; #if __AVX512VNNI__ __m512i _w_shift = _mm512_setzero_epi32(); - __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 3 < max_kk; kk += 4) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -4676,7 +4675,6 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int int kk = 0; #if __AVX512VNNI__ __m512i _w_shift = _mm512_setzero_epi32(); - __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 3 < max_kk; kk += 4) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -4723,11 +4721,11 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int __m128i _pp1 = float2int8_avx512(_p1); // transpose16x2_epi8 - __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); - __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); - _mm_storeu_si128((__m128i*)pp, _tt0); - _mm_storeu_si128((__m128i*)(pp + 16), _tt1); + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 16), _t1); pp += 32; p0 += A_hstep * 2; @@ -6936,135 +6934,75 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i { const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; + __m512 _scale = _mm512_set1_ps(scale); +#if __AVX512VNNI__ + __m512i _v127 = _mm512_set1_epi8(127); +#endif // __AVX512VNNI__ + if (elempack == 16) { int kk = 0; #if __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[16] * scale) + 127; - pp[2] = float2int8(p0[32] * scale) + 127; - pp[3] = float2int8(p0[48] * scale) + 127; - pp[4] = float2int8(p0[1] * scale) + 127; - pp[5] = float2int8(p0[17] * scale) + 127; - pp[6] = float2int8(p0[33] * scale) + 127; - pp[7] = float2int8(p0[49] * scale) + 127; - pp[8] = float2int8(p0[2] * scale) + 127; - pp[9] = float2int8(p0[18] * scale) + 127; - pp[10] = float2int8(p0[34] * scale) + 127; - pp[11] = float2int8(p0[50] * scale) + 127; - pp[12] = float2int8(p0[3] * scale) + 127; - pp[13] = float2int8(p0[19] * scale) + 127; - pp[14] = float2int8(p0[35] * scale) + 127; - pp[15] = float2int8(p0[51] * scale) + 127; - pp[16] = float2int8(p0[4] * scale) + 127; - pp[17] = float2int8(p0[20] * scale) + 127; - pp[18] = float2int8(p0[36] * scale) + 127; - pp[19] = float2int8(p0[52] * scale) + 127; - pp[20] = float2int8(p0[5] * scale) + 127; - pp[21] = float2int8(p0[21] * scale) + 127; - pp[22] = float2int8(p0[37] * scale) + 127; - pp[23] = float2int8(p0[53] * scale) + 127; - pp[24] = float2int8(p0[6] * scale) + 127; - pp[25] = float2int8(p0[22] * scale) + 127; - pp[26] = float2int8(p0[38] * scale) + 127; - pp[27] = float2int8(p0[54] * scale) + 127; - pp[28] = float2int8(p0[7] * scale) + 127; - pp[29] = float2int8(p0[23] * scale) + 127; - pp[30] = float2int8(p0[39] * scale) + 127; - pp[31] = float2int8(p0[55] * scale) + 127; - pp[32] = float2int8(p0[8] * scale) + 127; - pp[33] = float2int8(p0[24] * scale) + 127; - pp[34] = float2int8(p0[40] * scale) + 127; - pp[35] = float2int8(p0[56] * scale) + 127; - pp[36] = float2int8(p0[9] * scale) + 127; - pp[37] = float2int8(p0[25] * scale) + 127; - pp[38] = float2int8(p0[41] * scale) + 127; - pp[39] = float2int8(p0[57] * scale) + 127; - pp[40] = float2int8(p0[10] * scale) + 127; - pp[41] = float2int8(p0[26] * scale) + 127; - pp[42] = float2int8(p0[42] * scale) + 127; - pp[43] = float2int8(p0[58] * scale) + 127; - pp[44] = float2int8(p0[11] * scale) + 127; - pp[45] = float2int8(p0[27] * scale) + 127; - pp[46] = float2int8(p0[43] * scale) + 127; - pp[47] = float2int8(p0[59] * scale) + 127; - pp[48] = float2int8(p0[12] * scale) + 127; - pp[49] = float2int8(p0[28] * scale) + 127; - pp[50] = float2int8(p0[44] * scale) + 127; - pp[51] = float2int8(p0[60] * scale) + 127; - pp[52] = float2int8(p0[13] * scale) + 127; - pp[53] = float2int8(p0[29] * scale) + 127; - pp[54] = float2int8(p0[45] * scale) + 127; - pp[55] = float2int8(p0[61] * scale) + 127; - pp[56] = float2int8(p0[14] * scale) + 127; - pp[57] = float2int8(p0[30] * scale) + 127; - pp[58] = float2int8(p0[46] * scale) + 127; - pp[59] = float2int8(p0[62] * scale) + 127; - pp[60] = float2int8(p0[15] * scale) + 127; - pp[61] = float2int8(p0[31] * scale) + 127; - pp[62] = float2int8(p0[47] * scale) + 127; - pp[63] = float2int8(p0[63] * scale) + 127; + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp = _mm512_add_epi8(_pp, _v127); + + _mm512_storeu_si512((__m512i*)pp, _pp); + pp += 64; p0 += 64; } #endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[16] * scale); - pp[2] = float2int8(p0[1] * scale); - pp[3] = float2int8(p0[17] * scale); - pp[4] = float2int8(p0[2] * scale); - pp[5] = float2int8(p0[18] * scale); - pp[6] = float2int8(p0[3] * scale); - pp[7] = float2int8(p0[19] * scale); - pp[8] = float2int8(p0[4] * scale); - pp[9] = float2int8(p0[20] * scale); - pp[10] = float2int8(p0[5] * scale); - pp[11] = float2int8(p0[21] * scale); - pp[12] = float2int8(p0[6] * scale); - pp[13] = float2int8(p0[22] * scale); - pp[14] = float2int8(p0[7] * scale); - pp[15] = float2int8(p0[23] * scale); - pp[16 + 0] = float2int8(p0[8] * scale); - pp[16 + 1] = float2int8(p0[24] * scale); - pp[16 + 2] = float2int8(p0[9] * scale); - pp[16 + 3] = float2int8(p0[25] * scale); - pp[16 + 4] = float2int8(p0[10] * scale); - pp[16 + 5] = float2int8(p0[26] * scale); - pp[16 + 6] = float2int8(p0[11] * scale); - pp[16 + 7] = float2int8(p0[27] * scale); - pp[16 + 8] = float2int8(p0[12] * scale); - pp[16 + 9] = float2int8(p0[28] * scale); - pp[16 + 10] = float2int8(p0[13] * scale); - pp[16 + 11] = float2int8(p0[29] * scale); - pp[16 + 12] = float2int8(p0[14] * scale); - pp[16 + 13] = float2int8(p0[30] * scale); - pp[16 + 14] = float2int8(p0[15] * scale); - pp[16 + 15] = float2int8(p0[31] * scale); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + // transpose16x2_epi8 + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 16), _t1); + pp += 32; p0 += 32; } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[2] * scale); - pp[3] = float2int8(p0[3] * scale); - pp[4] = float2int8(p0[4] * scale); - pp[5] = float2int8(p0[5] * scale); - pp[6] = float2int8(p0[6] * scale); - pp[7] = float2int8(p0[7] * scale); - pp[8] = float2int8(p0[8] * scale); - pp[9] = float2int8(p0[9] * scale); - pp[10] = float2int8(p0[10] * scale); - pp[11] = float2int8(p0[11] * scale); - pp[12] = float2int8(p0[12] * scale); - pp[13] = float2int8(p0[13] * scale); - pp[14] = float2int8(p0[14] * scale); - pp[15] = float2int8(p0[15] * scale); + __m512 _p = _mm512_loadu_ps(p0); + + _p = _mm512_mul_ps(_p, _scale); + + __m128i _pp = float2int8_avx512(_p); + + _mm_storeu_si128((__m128i*)pp, _pp); + pp += 16; p0 += 16; } @@ -7075,131 +7013,73 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[8] * scale) + 127; - pp[2] = float2int8(p0[16] * scale) + 127; - pp[3] = float2int8(p0[24] * scale) + 127; - pp[4] = float2int8(p0[1] * scale) + 127; - pp[5] = float2int8(p0[9] * scale) + 127; - pp[6] = float2int8(p0[17] * scale) + 127; - pp[7] = float2int8(p0[25] * scale) + 127; - pp[8] = float2int8(p0[2] * scale) + 127; - pp[9] = float2int8(p0[10] * scale) + 127; - pp[10] = float2int8(p0[18] * scale) + 127; - pp[11] = float2int8(p0[26] * scale) + 127; - pp[12] = float2int8(p0[3] * scale) + 127; - pp[13] = float2int8(p0[11] * scale) + 127; - pp[14] = float2int8(p0[19] * scale) + 127; - pp[15] = float2int8(p0[27] * scale) + 127; - pp[16] = float2int8(p0[4] * scale) + 127; - pp[17] = float2int8(p0[12] * scale) + 127; - pp[18] = float2int8(p0[20] * scale) + 127; - pp[19] = float2int8(p0[28] * scale) + 127; - pp[20] = float2int8(p0[5] * scale) + 127; - pp[21] = float2int8(p0[13] * scale) + 127; - pp[22] = float2int8(p0[21] * scale) + 127; - pp[23] = float2int8(p0[29] * scale) + 127; - pp[24] = float2int8(p0[6] * scale) + 127; - pp[25] = float2int8(p0[14] * scale) + 127; - pp[26] = float2int8(p0[22] * scale) + 127; - pp[27] = float2int8(p0[30] * scale) + 127; - pp[28] = float2int8(p0[7] * scale) + 127; - pp[29] = float2int8(p0[15] * scale) + 127; - pp[30] = float2int8(p0[23] * scale) + 127; - pp[31] = float2int8(p0[31] * scale) + 127; + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + B_hstep * 8); + __m512 _p3 = _mm512_loadu_ps(p0 + B_hstep * 8 + 16); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + __m128i _t2 = _mm_unpacklo_epi8(_pp2, _pp3); + __m128i _t3 = _mm_unpackhi_epi8(_pp2, _pp3); + _pp0 = _mm_unpacklo_epi8(_t0, _t1); + _pp1 = _mm_unpackhi_epi8(_t0, _t1); + _pp2 = _mm_unpacklo_epi8(_t2, _t3); + _pp3 = _mm_unpackhi_epi8(_t2, _t3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp = _mm512_add_epi8(_pp, _v127); + + _mm512_storeu_si512((__m512i*)pp, _pp); - pp[32 + 0] = float2int8(p0[B_hstep * 8 + 0] * scale) + 127; - pp[32 + 1] = float2int8(p0[B_hstep * 8 + 8] * scale) + 127; - pp[32 + 2] = float2int8(p0[B_hstep * 8 + 16] * scale) + 127; - pp[32 + 3] = float2int8(p0[B_hstep * 8 + 24] * scale) + 127; - pp[32 + 4] = float2int8(p0[B_hstep * 8 + 1] * scale) + 127; - pp[32 + 5] = float2int8(p0[B_hstep * 8 + 9] * scale) + 127; - pp[32 + 6] = float2int8(p0[B_hstep * 8 + 17] * scale) + 127; - pp[32 + 7] = float2int8(p0[B_hstep * 8 + 25] * scale) + 127; - pp[32 + 8] = float2int8(p0[B_hstep * 8 + 2] * scale) + 127; - pp[32 + 9] = float2int8(p0[B_hstep * 8 + 10] * scale) + 127; - pp[32 + 10] = float2int8(p0[B_hstep * 8 + 18] * scale) + 127; - pp[32 + 11] = float2int8(p0[B_hstep * 8 + 26] * scale) + 127; - pp[32 + 12] = float2int8(p0[B_hstep * 8 + 3] * scale) + 127; - pp[32 + 13] = float2int8(p0[B_hstep * 8 + 11] * scale) + 127; - pp[32 + 14] = float2int8(p0[B_hstep * 8 + 19] * scale) + 127; - pp[32 + 15] = float2int8(p0[B_hstep * 8 + 27] * scale) + 127; - pp[32 + 16] = float2int8(p0[B_hstep * 8 + 4] * scale) + 127; - pp[32 + 17] = float2int8(p0[B_hstep * 8 + 12] * scale) + 127; - pp[32 + 18] = float2int8(p0[B_hstep * 8 + 20] * scale) + 127; - pp[32 + 19] = float2int8(p0[B_hstep * 8 + 28] * scale) + 127; - pp[32 + 20] = float2int8(p0[B_hstep * 8 + 5] * scale) + 127; - pp[32 + 21] = float2int8(p0[B_hstep * 8 + 13] * scale) + 127; - pp[32 + 22] = float2int8(p0[B_hstep * 8 + 21] * scale) + 127; - pp[32 + 23] = float2int8(p0[B_hstep * 8 + 29] * scale) + 127; - pp[32 + 24] = float2int8(p0[B_hstep * 8 + 6] * scale) + 127; - pp[32 + 25] = float2int8(p0[B_hstep * 8 + 14] * scale) + 127; - pp[32 + 26] = float2int8(p0[B_hstep * 8 + 22] * scale) + 127; - pp[32 + 27] = float2int8(p0[B_hstep * 8 + 30] * scale) + 127; - pp[32 + 28] = float2int8(p0[B_hstep * 8 + 7] * scale) + 127; - pp[32 + 29] = float2int8(p0[B_hstep * 8 + 15] * scale) + 127; - pp[32 + 30] = float2int8(p0[B_hstep * 8 + 23] * scale) + 127; - pp[32 + 31] = float2int8(p0[B_hstep * 8 + 31] * scale) + 127; pp += 64; p0 += 32; } #endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[8] * scale); - pp[2] = float2int8(p0[1] * scale); - pp[3] = float2int8(p0[9] * scale); - pp[4] = float2int8(p0[2] * scale); - pp[5] = float2int8(p0[10] * scale); - pp[6] = float2int8(p0[3] * scale); - pp[7] = float2int8(p0[11] * scale); - pp[8] = float2int8(p0[4] * scale); - pp[9] = float2int8(p0[12] * scale); - pp[10] = float2int8(p0[5] * scale); - pp[11] = float2int8(p0[13] * scale); - pp[12] = float2int8(p0[6] * scale); - pp[13] = float2int8(p0[14] * scale); - pp[14] = float2int8(p0[7] * scale); - pp[15] = float2int8(p0[15] * scale); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + B_hstep * 8); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15); + _pp0 = _mm_shuffle_epi8(_pp0, _si); + _pp1 = _mm_shuffle_epi8(_pp1, _si); + + _mm_storeu_si128((__m128i*)pp, _pp0); + _mm_storeu_si128((__m128i*)(pp + 16), _pp1); - pp[16 + 0] = float2int8(p0[B_hstep * 8 + 0] * scale); - pp[16 + 1] = float2int8(p0[B_hstep * 8 + 8] * scale); - pp[16 + 2] = float2int8(p0[B_hstep * 8 + 1] * scale); - pp[16 + 3] = float2int8(p0[B_hstep * 8 + 9] * scale); - pp[16 + 4] = float2int8(p0[B_hstep * 8 + 2] * scale); - pp[16 + 5] = float2int8(p0[B_hstep * 8 + 10] * scale); - pp[16 + 6] = float2int8(p0[B_hstep * 8 + 3] * scale); - pp[16 + 7] = float2int8(p0[B_hstep * 8 + 11] * scale); - pp[16 + 8] = float2int8(p0[B_hstep * 8 + 4] * scale); - pp[16 + 9] = float2int8(p0[B_hstep * 8 + 12] * scale); - pp[16 + 10] = float2int8(p0[B_hstep * 8 + 5] * scale); - pp[16 + 11] = float2int8(p0[B_hstep * 8 + 13] * scale); - pp[16 + 12] = float2int8(p0[B_hstep * 8 + 6] * scale); - pp[16 + 13] = float2int8(p0[B_hstep * 8 + 14] * scale); - pp[16 + 14] = float2int8(p0[B_hstep * 8 + 7] * scale); - pp[16 + 15] = float2int8(p0[B_hstep * 8 + 15] * scale); pp += 32; p0 += 16; } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[2] * scale); - pp[3] = float2int8(p0[3] * scale); - pp[4] = float2int8(p0[4] * scale); - pp[5] = float2int8(p0[5] * scale); - pp[6] = float2int8(p0[6] * scale); - pp[7] = float2int8(p0[7] * scale); - pp[8] = float2int8(p0[B_hstep * 8 + 0] * scale); - pp[9] = float2int8(p0[B_hstep * 8 + 1] * scale); - pp[10] = float2int8(p0[B_hstep * 8 + 2] * scale); - pp[11] = float2int8(p0[B_hstep * 8 + 3] * scale); - pp[12] = float2int8(p0[B_hstep * 8 + 4] * scale); - pp[13] = float2int8(p0[B_hstep * 8 + 5] * scale); - pp[14] = float2int8(p0[B_hstep * 8 + 6] * scale); - pp[15] = float2int8(p0[B_hstep * 8 + 7] * scale); + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + B_hstep * 8); + + __m512 _p = combine8x2_ps(_p0, _p1); + _p = _mm512_mul_ps(_p, _scale); + + __m128i _pp = float2int8_avx512(_p); + + _mm_storeu_si128((__m128i*)pp, _pp); + pp += 16; p0 += 8; } @@ -7210,72 +7090,29 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[4] * scale) + 127; - pp[2] = float2int8(p0[8] * scale) + 127; - pp[3] = float2int8(p0[12] * scale) + 127; - pp[4] = float2int8(p0[1] * scale) + 127; - pp[5] = float2int8(p0[5] * scale) + 127; - pp[6] = float2int8(p0[9] * scale) + 127; - pp[7] = float2int8(p0[13] * scale) + 127; - pp[8] = float2int8(p0[2] * scale) + 127; - pp[9] = float2int8(p0[6] * scale) + 127; - pp[10] = float2int8(p0[10] * scale) + 127; - pp[11] = float2int8(p0[14] * scale) + 127; - pp[12] = float2int8(p0[3] * scale) + 127; - pp[13] = float2int8(p0[7] * scale) + 127; - pp[14] = float2int8(p0[11] * scale) + 127; - pp[15] = float2int8(p0[15] * scale) + 127; - pp[16 + 0] = float2int8(p0[B_hstep * 4 + 0] * scale) + 127; - pp[16 + 1] = float2int8(p0[B_hstep * 4 + 4] * scale) + 127; - pp[16 + 2] = float2int8(p0[B_hstep * 4 + 8] * scale) + 127; - pp[16 + 3] = float2int8(p0[B_hstep * 4 + 12] * scale) + 127; - pp[16 + 4] = float2int8(p0[B_hstep * 4 + 1] * scale) + 127; - pp[16 + 5] = float2int8(p0[B_hstep * 4 + 5] * scale) + 127; - pp[16 + 6] = float2int8(p0[B_hstep * 4 + 9] * scale) + 127; - pp[16 + 7] = float2int8(p0[B_hstep * 4 + 13] * scale) + 127; - pp[16 + 8] = float2int8(p0[B_hstep * 4 + 2] * scale) + 127; - pp[16 + 9] = float2int8(p0[B_hstep * 4 + 6] * scale) + 127; - pp[16 + 10] = float2int8(p0[B_hstep * 4 + 10] * scale) + 127; - pp[16 + 11] = float2int8(p0[B_hstep * 4 + 14] * scale) + 127; - pp[16 + 12] = float2int8(p0[B_hstep * 4 + 3] * scale) + 127; - pp[16 + 13] = float2int8(p0[B_hstep * 4 + 7] * scale) + 127; - pp[16 + 14] = float2int8(p0[B_hstep * 4 + 11] * scale) + 127; - pp[16 + 15] = float2int8(p0[B_hstep * 4 + 15] * scale) + 127; - - pp[32 + 0] = float2int8(p0[B_hstep * 8 + 0] * scale) + 127; - pp[32 + 1] = float2int8(p0[B_hstep * 8 + 4] * scale) + 127; - pp[32 + 2] = float2int8(p0[B_hstep * 8 + 8] * scale) + 127; - pp[32 + 3] = float2int8(p0[B_hstep * 8 + 12] * scale) + 127; - pp[32 + 4] = float2int8(p0[B_hstep * 8 + 1] * scale) + 127; - pp[32 + 5] = float2int8(p0[B_hstep * 8 + 5] * scale) + 127; - pp[32 + 6] = float2int8(p0[B_hstep * 8 + 9] * scale) + 127; - pp[32 + 7] = float2int8(p0[B_hstep * 8 + 13] * scale) + 127; - pp[32 + 8] = float2int8(p0[B_hstep * 8 + 2] * scale) + 127; - pp[32 + 9] = float2int8(p0[B_hstep * 8 + 6] * scale) + 127; - pp[32 + 10] = float2int8(p0[B_hstep * 8 + 10] * scale) + 127; - pp[32 + 11] = float2int8(p0[B_hstep * 8 + 14] * scale) + 127; - pp[32 + 12] = float2int8(p0[B_hstep * 8 + 3] * scale) + 127; - pp[32 + 13] = float2int8(p0[B_hstep * 8 + 7] * scale) + 127; - pp[32 + 14] = float2int8(p0[B_hstep * 8 + 11] * scale) + 127; - pp[32 + 15] = float2int8(p0[B_hstep * 8 + 15] * scale) + 127; - - pp[48 + 0] = float2int8(p0[B_hstep * 12 + 0] * scale) + 127; - pp[48 + 1] = float2int8(p0[B_hstep * 12 + 4] * scale) + 127; - pp[48 + 2] = float2int8(p0[B_hstep * 12 + 8] * scale) + 127; - pp[48 + 3] = float2int8(p0[B_hstep * 12 + 12] * scale) + 127; - pp[48 + 4] = float2int8(p0[B_hstep * 12 + 1] * scale) + 127; - pp[48 + 5] = float2int8(p0[B_hstep * 12 + 5] * scale) + 127; - pp[48 + 6] = float2int8(p0[B_hstep * 12 + 9] * scale) + 127; - pp[48 + 7] = float2int8(p0[B_hstep * 12 + 13] * scale) + 127; - pp[48 + 8] = float2int8(p0[B_hstep * 12 + 2] * scale) + 127; - pp[48 + 9] = float2int8(p0[B_hstep * 12 + 6] * scale) + 127; - pp[48 + 10] = float2int8(p0[B_hstep * 12 + 10] * scale) + 127; - pp[48 + 11] = float2int8(p0[B_hstep * 12 + 14] * scale) + 127; - pp[48 + 12] = float2int8(p0[B_hstep * 12 + 3] * scale) + 127; - pp[48 + 13] = float2int8(p0[B_hstep * 12 + 7] * scale) + 127; - pp[48 + 14] = float2int8(p0[B_hstep * 12 + 11] * scale) + 127; - pp[48 + 15] = float2int8(p0[B_hstep * 12 + 15] * scale) + 127; + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + B_hstep * 4); + __m512 _p2 = _mm512_loadu_ps(p0 + B_hstep * 8); + __m512 _p3 = _mm512_loadu_ps(p0 + B_hstep * 12); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp = _mm512_add_epi8(_pp, _v127); + + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm512_shuffle_epi8(_pp, _mm512_broadcast_i32x4(_si)); + + _mm512_storeu_si512((__m512i*)pp, _pp); pp += 64; p0 += 16; @@ -7283,62 +7120,44 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[4] * scale); - pp[2] = float2int8(p0[1] * scale); - pp[3] = float2int8(p0[5] * scale); - pp[4] = float2int8(p0[2] * scale); - pp[5] = float2int8(p0[6] * scale); - pp[6] = float2int8(p0[3] * scale); - pp[7] = float2int8(p0[7] * scale); - pp[8] = float2int8(p0[B_hstep * 4 + 0] * scale); - pp[9] = float2int8(p0[B_hstep * 4 + 4] * scale); - pp[10] = float2int8(p0[B_hstep * 4 + 1] * scale); - pp[11] = float2int8(p0[B_hstep * 4 + 5] * scale); - pp[12] = float2int8(p0[B_hstep * 4 + 2] * scale); - pp[13] = float2int8(p0[B_hstep * 4 + 6] * scale); - pp[14] = float2int8(p0[B_hstep * 4 + 3] * scale); - pp[15] = float2int8(p0[B_hstep * 4 + 7] * scale); + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + B_hstep * 4); + __m256 _p2 = _mm256_loadu_ps(p0 + B_hstep * 8); + __m256 _p3 = _mm256_loadu_ps(p0 + B_hstep * 12); + + __m512 _p01 = combine8x2_ps(_p0, _p1); + __m512 _p23 = combine8x2_ps(_p2, _p3); + + _p01 = _mm512_mul_ps(_p01, _scale); + _p23 = _mm512_mul_ps(_p23, _scale); + + __m128i _pp0 = float2int8_avx512(_p01); + __m128i _pp1 = float2int8_avx512(_p23); + + __m128i _si = _mm_setr_epi8(0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15); + _pp0 = _mm_shuffle_epi8(_pp0, _si); + _pp1 = _mm_shuffle_epi8(_pp1, _si); - pp[16 + 0] = float2int8(p0[B_hstep * 8 + 0] * scale); - pp[16 + 1] = float2int8(p0[B_hstep * 8 + 4] * scale); - pp[16 + 2] = float2int8(p0[B_hstep * 8 + 1] * scale); - pp[16 + 3] = float2int8(p0[B_hstep * 8 + 5] * scale); - pp[16 + 4] = float2int8(p0[B_hstep * 8 + 2] * scale); - pp[16 + 5] = float2int8(p0[B_hstep * 8 + 6] * scale); - pp[16 + 6] = float2int8(p0[B_hstep * 8 + 3] * scale); - pp[16 + 7] = float2int8(p0[B_hstep * 8 + 7] * scale); - - pp[16 + 8] = float2int8(p0[B_hstep * 12 + 0] * scale); - pp[16 + 9] = float2int8(p0[B_hstep * 12 + 4] * scale); - pp[16 + 10] = float2int8(p0[B_hstep * 12 + 1] * scale); - pp[16 + 11] = float2int8(p0[B_hstep * 12 + 5] * scale); - pp[16 + 12] = float2int8(p0[B_hstep * 12 + 2] * scale); - pp[16 + 13] = float2int8(p0[B_hstep * 12 + 6] * scale); - pp[16 + 14] = float2int8(p0[B_hstep * 12 + 3] * scale); - pp[16 + 15] = float2int8(p0[B_hstep * 12 + 7] * scale); + _mm_storeu_si128((__m128i*)pp, _pp0); + _mm_storeu_si128((__m128i*)(pp + 16), _pp1); pp += 32; p0 += 8; } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[2] * scale); - pp[3] = float2int8(p0[3] * scale); - pp[4] = float2int8(p0[B_hstep * 4] * scale); - pp[5] = float2int8(p0[B_hstep * 4 + 1] * scale); - pp[6] = float2int8(p0[B_hstep * 4 + 2] * scale); - pp[7] = float2int8(p0[B_hstep * 4 + 3] * scale); - pp[8] = float2int8(p0[B_hstep * 8] * scale); - pp[9] = float2int8(p0[B_hstep * 8 + 1] * scale); - pp[10] = float2int8(p0[B_hstep * 8 + 2] * scale); - pp[11] = float2int8(p0[B_hstep * 8 + 3] * scale); - pp[12] = float2int8(p0[B_hstep * 12] * scale); - pp[13] = float2int8(p0[B_hstep * 12 + 1] * scale); - pp[14] = float2int8(p0[B_hstep * 12 + 2] * scale); - pp[15] = float2int8(p0[B_hstep * 12 + 3] * scale); + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + B_hstep * 4); + __m128 _p2 = _mm_loadu_ps(p0 + B_hstep * 8); + __m128 _p3 = _mm_loadu_ps(p0 + B_hstep * 12); + + __m512 _p = combine4x4_ps(_p0, _p1, _p2, _p3); + _p = _mm512_mul_ps(_p, _scale); + + __m128i _pp = float2int8_avx512(_p); + + _mm_storeu_si128((__m128i*)pp, _pp); + pp += 16; p0 += 4; } @@ -7349,71 +7168,43 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[1] * scale) + 127; - pp[2] = float2int8(p0[2] * scale) + 127; - pp[3] = float2int8(p0[3] * scale) + 127; - pp[4] = float2int8(p0[B_hstep] * scale) + 127; - pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127; - pp[6] = float2int8(p0[B_hstep + 2] * scale) + 127; - pp[7] = float2int8(p0[B_hstep + 3] * scale) + 127; - pp[8] = float2int8(p0[B_hstep * 2] * scale) + 127; - pp[9] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127; - pp[10] = float2int8(p0[B_hstep * 2 + 2] * scale) + 127; - pp[11] = float2int8(p0[B_hstep * 2 + 3] * scale) + 127; - pp[12] = float2int8(p0[B_hstep * 3] * scale) + 127; - pp[13] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127; - pp[14] = float2int8(p0[B_hstep * 3 + 2] * scale) + 127; - pp[15] = float2int8(p0[B_hstep * 3 + 3] * scale) + 127; - pp[16] = float2int8(p0[B_hstep * 4] * scale) + 127; - pp[17] = float2int8(p0[B_hstep * 4 + 1] * scale) + 127; - pp[18] = float2int8(p0[B_hstep * 4 + 2] * scale) + 127; - pp[19] = float2int8(p0[B_hstep * 4 + 3] * scale) + 127; - pp[20] = float2int8(p0[B_hstep * 5] * scale) + 127; - pp[21] = float2int8(p0[B_hstep * 5 + 1] * scale) + 127; - pp[22] = float2int8(p0[B_hstep * 5 + 2] * scale) + 127; - pp[23] = float2int8(p0[B_hstep * 5 + 3] * scale) + 127; - pp[24] = float2int8(p0[B_hstep * 6] * scale) + 127; - pp[25] = float2int8(p0[B_hstep * 6 + 1] * scale) + 127; - pp[26] = float2int8(p0[B_hstep * 6 + 2] * scale) + 127; - pp[27] = float2int8(p0[B_hstep * 6 + 3] * scale) + 127; - pp[28] = float2int8(p0[B_hstep * 7] * scale) + 127; - pp[29] = float2int8(p0[B_hstep * 7 + 1] * scale) + 127; - pp[30] = float2int8(p0[B_hstep * 7 + 2] * scale) + 127; - pp[31] = float2int8(p0[B_hstep * 7 + 3] * scale) + 127; + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + B_hstep); + __m128 _p2 = _mm_loadu_ps(p0 + B_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + B_hstep * 3); + __m128 _p4 = _mm_loadu_ps(p0 + B_hstep * 4); + __m128 _p5 = _mm_loadu_ps(p0 + B_hstep * 5); + __m128 _p6 = _mm_loadu_ps(p0 + B_hstep * 6); + __m128 _p7 = _mm_loadu_ps(p0 + B_hstep * 7); + __m128 _p8 = _mm_loadu_ps(p0 + B_hstep * 8); + __m128 _p9 = _mm_loadu_ps(p0 + B_hstep * 9); + __m128 _pa = _mm_loadu_ps(p0 + B_hstep * 10); + __m128 _pb = _mm_loadu_ps(p0 + B_hstep * 11); + __m128 _pc = _mm_loadu_ps(p0 + B_hstep * 12); + __m128 _pd = _mm_loadu_ps(p0 + B_hstep * 13); + __m128 _pe = _mm_loadu_ps(p0 + B_hstep * 14); + __m128 _pf = _mm_loadu_ps(p0 + B_hstep * 15); + + __m512 _t0 = combine4x4_ps(_p0, _p1, _p2, _p3); + __m512 _t1 = combine4x4_ps(_p4, _p5, _p6, _p7); + __m512 _t2 = combine4x4_ps(_p8, _p9, _pa, _pb); + __m512 _t3 = combine4x4_ps(_pc, _pd, _pe, _pf); + + _t0 = _mm512_mul_ps(_t0, _scale); + _t1 = _mm512_mul_ps(_t1, _scale); + _t2 = _mm512_mul_ps(_t2, _scale); + _t3 = _mm512_mul_ps(_t3, _scale); + + __m128i _pp0 = float2int8_avx512(_t0); + __m128i _pp1 = float2int8_avx512(_t1); + __m128i _pp2 = float2int8_avx512(_t2); + __m128i _pp3 = float2int8_avx512(_t3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); - pp[32 + 0] = float2int8(p0[B_hstep * 8] * scale) + 127; - pp[32 + 1] = float2int8(p0[B_hstep * 8 + 1] * scale) + 127; - pp[32 + 2] = float2int8(p0[B_hstep * 8 + 2] * scale) + 127; - pp[32 + 3] = float2int8(p0[B_hstep * 8 + 3] * scale) + 127; - pp[32 + 4] = float2int8(p0[B_hstep * 9] * scale) + 127; - pp[32 + 5] = float2int8(p0[B_hstep * 9 + 1] * scale) + 127; - pp[32 + 6] = float2int8(p0[B_hstep * 9 + 2] * scale) + 127; - pp[32 + 7] = float2int8(p0[B_hstep * 9 + 3] * scale) + 127; - pp[32 + 8] = float2int8(p0[B_hstep * 10] * scale) + 127; - pp[32 + 9] = float2int8(p0[B_hstep * 10 + 1] * scale) + 127; - pp[32 + 10] = float2int8(p0[B_hstep * 10 + 2] * scale) + 127; - pp[32 + 11] = float2int8(p0[B_hstep * 10 + 3] * scale) + 127; - pp[32 + 12] = float2int8(p0[B_hstep * 11] * scale) + 127; - pp[32 + 13] = float2int8(p0[B_hstep * 11 + 1] * scale) + 127; - pp[32 + 14] = float2int8(p0[B_hstep * 11 + 2] * scale) + 127; - pp[32 + 15] = float2int8(p0[B_hstep * 11 + 3] * scale) + 127; - pp[32 + 16] = float2int8(p0[B_hstep * 12] * scale) + 127; - pp[32 + 17] = float2int8(p0[B_hstep * 12 + 1] * scale) + 127; - pp[32 + 18] = float2int8(p0[B_hstep * 12 + 2] * scale) + 127; - pp[32 + 19] = float2int8(p0[B_hstep * 12 + 3] * scale) + 127; - pp[32 + 20] = float2int8(p0[B_hstep * 13] * scale) + 127; - pp[32 + 21] = float2int8(p0[B_hstep * 13 + 1] * scale) + 127; - pp[32 + 22] = float2int8(p0[B_hstep * 13 + 2] * scale) + 127; - pp[32 + 23] = float2int8(p0[B_hstep * 13 + 3] * scale) + 127; - pp[32 + 24] = float2int8(p0[B_hstep * 14] * scale) + 127; - pp[32 + 25] = float2int8(p0[B_hstep * 14 + 1] * scale) + 127; - pp[32 + 26] = float2int8(p0[B_hstep * 14 + 2] * scale) + 127; - pp[32 + 27] = float2int8(p0[B_hstep * 14 + 3] * scale) + 127; - pp[32 + 28] = float2int8(p0[B_hstep * 15] * scale) + 127; - pp[32 + 29] = float2int8(p0[B_hstep * 15 + 1] * scale) + 127; - pp[32 + 30] = float2int8(p0[B_hstep * 15 + 2] * scale) + 127; - pp[32 + 31] = float2int8(p0[B_hstep * 15 + 3] * scale) + 127; + _pp = _mm512_add_epi8(_pp, _v127); + + _mm512_storeu_si512((__m512i*)pp, _pp); pp += 64; p0 += 4; @@ -7421,60 +7212,38 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[B_hstep] * scale); - pp[3] = float2int8(p0[B_hstep + 1] * scale); - pp[4] = float2int8(p0[B_hstep * 2] * scale); - pp[5] = float2int8(p0[B_hstep * 2 + 1] * scale); - pp[6] = float2int8(p0[B_hstep * 3] * scale); - pp[7] = float2int8(p0[B_hstep * 3 + 1] * scale); - pp[8] = float2int8(p0[B_hstep * 4] * scale); - pp[9] = float2int8(p0[B_hstep * 4 + 1] * scale); - pp[10] = float2int8(p0[B_hstep * 5] * scale); - pp[11] = float2int8(p0[B_hstep * 5 + 1] * scale); - pp[12] = float2int8(p0[B_hstep * 6] * scale); - pp[13] = float2int8(p0[B_hstep * 6 + 1] * scale); - pp[14] = float2int8(p0[B_hstep * 7] * scale); - pp[15] = float2int8(p0[B_hstep * 7 + 1] * scale); + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(B_hstep)); + + __m512 _p0 = _mm512_i32gather_ps(_vindex, p0, sizeof(float)); + __m512 _p1 = _mm512_i32gather_ps(_vindex, p0 + 1, sizeof(float)); + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 16), _t1); - pp[16 + 0] = float2int8(p0[B_hstep * 8] * scale); - pp[16 + 1] = float2int8(p0[B_hstep * 8 + 1] * scale); - pp[16 + 2] = float2int8(p0[B_hstep * 9] * scale); - pp[16 + 3] = float2int8(p0[B_hstep * 9 + 1] * scale); - pp[16 + 4] = float2int8(p0[B_hstep * 10] * scale); - pp[16 + 5] = float2int8(p0[B_hstep * 10 + 1] * scale); - pp[16 + 6] = float2int8(p0[B_hstep * 11] * scale); - pp[16 + 7] = float2int8(p0[B_hstep * 11 + 1] * scale); - pp[16 + 8] = float2int8(p0[B_hstep * 12] * scale); - pp[16 + 9] = float2int8(p0[B_hstep * 12 + 1] * scale); - pp[16 + 10] = float2int8(p0[B_hstep * 13] * scale); - pp[16 + 11] = float2int8(p0[B_hstep * 13 + 1] * scale); - pp[16 + 12] = float2int8(p0[B_hstep * 14] * scale); - pp[16 + 13] = float2int8(p0[B_hstep * 14 + 1] * scale); - pp[16 + 14] = float2int8(p0[B_hstep * 15] * scale); - pp[16 + 15] = float2int8(p0[B_hstep * 15 + 1] * scale); pp += 32; p0 += 2; } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[B_hstep] * scale); - pp[2] = float2int8(p0[B_hstep * 2] * scale); - pp[3] = float2int8(p0[B_hstep * 3] * scale); - pp[4] = float2int8(p0[B_hstep * 4] * scale); - pp[5] = float2int8(p0[B_hstep * 5] * scale); - pp[6] = float2int8(p0[B_hstep * 6] * scale); - pp[7] = float2int8(p0[B_hstep * 7] * scale); - pp[8] = float2int8(p0[B_hstep * 8] * scale); - pp[9] = float2int8(p0[B_hstep * 9] * scale); - pp[10] = float2int8(p0[B_hstep * 10] * scale); - pp[11] = float2int8(p0[B_hstep * 11] * scale); - pp[12] = float2int8(p0[B_hstep * 12] * scale); - pp[13] = float2int8(p0[B_hstep * 13] * scale); - pp[14] = float2int8(p0[B_hstep * 14] * scale); - pp[15] = float2int8(p0[B_hstep * 15] * scale); + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(B_hstep)); + + __m512 _p = _mm512_i32gather_ps(_vindex, p0, sizeof(float)); + _p = _mm512_mul_ps(_p, _scale); + + __m128i _pp = float2int8_avx512(_p); + + _mm_storeu_si128((__m128i*)pp, _pp); + pp += 16; p0++; } @@ -8200,275 +7969,91 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int { const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + __m512 _scale = _mm512_set1_ps(scale); +#if __AVX512VNNI__ + __m512i _v127 = _mm512_set1_epi8(127); +#endif // __AVX512VNNI__ + if (elempack == 16) { int kk = 0; #if __AVX512VNNI__ for (; kk + 15 < max_kk; kk += 16) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[1] * scale) + 127; - pp[2] = float2int8(p0[2 + 0] * scale) + 127; - pp[3] = float2int8(p0[2 + 1] * scale) + 127; - pp[4] = float2int8(p0[16] * scale) + 127; - pp[5] = float2int8(p0[17] * scale) + 127; - pp[6] = float2int8(p0[2 + 16] * scale) + 127; - pp[7] = float2int8(p0[2 + 17] * scale) + 127; - pp[8] = float2int8(p0[32] * scale) + 127; - pp[9] = float2int8(p0[33] * scale) + 127; - pp[10] = float2int8(p0[2 + 32] * scale) + 127; - pp[11] = float2int8(p0[2 + 33] * scale) + 127; - pp[12] = float2int8(p0[48] * scale) + 127; - pp[13] = float2int8(p0[49] * scale) + 127; - pp[14] = float2int8(p0[2 + 48] * scale) + 127; - pp[15] = float2int8(p0[2 + 49] * scale) + 127; - pp[16] = float2int8(p0[64] * scale) + 127; - pp[17] = float2int8(p0[65] * scale) + 127; - pp[18] = float2int8(p0[2 + 64] * scale) + 127; - pp[19] = float2int8(p0[2 + 65] * scale) + 127; - pp[20] = float2int8(p0[80] * scale) + 127; - pp[21] = float2int8(p0[81] * scale) + 127; - pp[22] = float2int8(p0[2 + 80] * scale) + 127; - pp[23] = float2int8(p0[2 + 81] * scale) + 127; - pp[24] = float2int8(p0[96] * scale) + 127; - pp[25] = float2int8(p0[97] * scale) + 127; - pp[26] = float2int8(p0[2 + 96] * scale) + 127; - pp[27] = float2int8(p0[2 + 97] * scale) + 127; - pp[28] = float2int8(p0[112] * scale) + 127; - pp[29] = float2int8(p0[113] * scale) + 127; - pp[30] = float2int8(p0[2 + 112] * scale) + 127; - pp[31] = float2int8(p0[2 + 113] * scale) + 127; + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + __m512 _p8 = _mm512_loadu_ps(p0 + 128); + __m512 _p9 = _mm512_loadu_ps(p0 + 128 + 16); + __m512 _pa = _mm512_loadu_ps(p0 + 128 + 32); + __m512 _pb = _mm512_loadu_ps(p0 + 128 + 48); + __m512 _pc = _mm512_loadu_ps(p0 + 128 + 64); + __m512 _pd = _mm512_loadu_ps(p0 + 128 + 80); + __m512 _pe = _mm512_loadu_ps(p0 + 128 + 96); + __m512 _pf = _mm512_loadu_ps(p0 + 128 + 112); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + _p4 = _mm512_mul_ps(_p4, _scale); + _p5 = _mm512_mul_ps(_p5, _scale); + _p6 = _mm512_mul_ps(_p6, _scale); + _p7 = _mm512_mul_ps(_p7, _scale); + _p8 = _mm512_mul_ps(_p8, _scale); + _p9 = _mm512_mul_ps(_p9, _scale); + _pa = _mm512_mul_ps(_pa, _scale); + _pb = _mm512_mul_ps(_pb, _scale); + _pc = _mm512_mul_ps(_pc, _scale); + _pd = _mm512_mul_ps(_pd, _scale); + _pe = _mm512_mul_ps(_pe, _scale); + _pf = _mm512_mul_ps(_pf, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + __m128i _pp8 = float2int8_avx512(_p8); + __m128i _pp9 = float2int8_avx512(_p9); + __m128i _ppa = float2int8_avx512(_pa); + __m128i _ppb = float2int8_avx512(_pb); + __m128i _ppc = float2int8_avx512(_pc); + __m128i _ppd = float2int8_avx512(_pd); + __m128i _ppe = float2int8_avx512(_pe); + __m128i _ppf = float2int8_avx512(_pf); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp4, _pp8, _ppc); + __m512i _t1 = combine4x4_epi32(_pp1, _pp5, _pp9, _ppd); + __m512i _t2 = combine4x4_epi32(_pp2, _pp6, _ppa, _ppe); + __m512i _t3 = combine4x4_epi32(_pp3, _pp7, _ppb, _ppf); + + __m512i _t4 = _mm512_unpacklo_epi32(_t0, _t1); + __m512i _t5 = _mm512_unpackhi_epi32(_t0, _t1); + __m512i _t6 = _mm512_unpacklo_epi32(_t2, _t3); + __m512i _t7 = _mm512_unpackhi_epi32(_t2, _t3); + _t0 = _mm512_unpacklo_epi64(_t4, _t6); + _t1 = _mm512_unpackhi_epi64(_t4, _t6); + _t2 = _mm512_unpacklo_epi64(_t5, _t7); + _t3 = _mm512_unpackhi_epi64(_t5, _t7); - pp[32 + 0] = float2int8(p0[128 + 0] * scale) + 127; - pp[32 + 1] = float2int8(p0[128 + 1] * scale) + 127; - pp[32 + 2] = float2int8(p0[2 + 128 + 0] * scale) + 127; - pp[32 + 3] = float2int8(p0[2 + 128 + 1] * scale) + 127; - pp[32 + 4] = float2int8(p0[128 + 16] * scale) + 127; - pp[32 + 5] = float2int8(p0[128 + 17] * scale) + 127; - pp[32 + 6] = float2int8(p0[2 + 128 + 16] * scale) + 127; - pp[32 + 7] = float2int8(p0[2 + 128 + 17] * scale) + 127; - pp[32 + 8] = float2int8(p0[128 + 32] * scale) + 127; - pp[32 + 9] = float2int8(p0[128 + 33] * scale) + 127; - pp[32 + 10] = float2int8(p0[2 + 128 + 32] * scale) + 127; - pp[32 + 11] = float2int8(p0[2 + 128 + 33] * scale) + 127; - pp[32 + 12] = float2int8(p0[128 + 48] * scale) + 127; - pp[32 + 13] = float2int8(p0[128 + 49] * scale) + 127; - pp[32 + 14] = float2int8(p0[2 + 128 + 48] * scale) + 127; - pp[32 + 15] = float2int8(p0[2 + 128 + 49] * scale) + 127; - pp[32 + 16] = float2int8(p0[128 + 64] * scale) + 127; - pp[32 + 17] = float2int8(p0[128 + 65] * scale) + 127; - pp[32 + 18] = float2int8(p0[2 + 128 + 64] * scale) + 127; - pp[32 + 19] = float2int8(p0[2 + 128 + 65] * scale) + 127; - pp[32 + 20] = float2int8(p0[128 + 80] * scale) + 127; - pp[32 + 21] = float2int8(p0[128 + 81] * scale) + 127; - pp[32 + 22] = float2int8(p0[2 + 128 + 80] * scale) + 127; - pp[32 + 23] = float2int8(p0[2 + 128 + 81] * scale) + 127; - pp[32 + 24] = float2int8(p0[128 + 96] * scale) + 127; - pp[32 + 25] = float2int8(p0[128 + 97] * scale) + 127; - pp[32 + 26] = float2int8(p0[2 + 128 + 96] * scale) + 127; - pp[32 + 27] = float2int8(p0[2 + 128 + 97] * scale) + 127; - pp[32 + 28] = float2int8(p0[128 + 112] * scale) + 127; - pp[32 + 29] = float2int8(p0[128 + 113] * scale) + 127; - pp[32 + 30] = float2int8(p0[2 + 128 + 112] * scale) + 127; - pp[32 + 31] = float2int8(p0[2 + 128 + 113] * scale) + 127; - - pp[64 + 0] = float2int8(p0[4 + 0] * scale) + 127; - pp[64 + 1] = float2int8(p0[4 + 1] * scale) + 127; - pp[64 + 2] = float2int8(p0[6 + 0] * scale) + 127; - pp[64 + 3] = float2int8(p0[6 + 1] * scale) + 127; - pp[64 + 4] = float2int8(p0[4 + 16] * scale) + 127; - pp[64 + 5] = float2int8(p0[4 + 17] * scale) + 127; - pp[64 + 6] = float2int8(p0[6 + 16] * scale) + 127; - pp[64 + 7] = float2int8(p0[6 + 17] * scale) + 127; - pp[64 + 8] = float2int8(p0[4 + 32] * scale) + 127; - pp[64 + 9] = float2int8(p0[4 + 33] * scale) + 127; - pp[64 + 10] = float2int8(p0[6 + 32] * scale) + 127; - pp[64 + 11] = float2int8(p0[6 + 33] * scale) + 127; - pp[64 + 12] = float2int8(p0[4 + 48] * scale) + 127; - pp[64 + 13] = float2int8(p0[4 + 49] * scale) + 127; - pp[64 + 14] = float2int8(p0[6 + 48] * scale) + 127; - pp[64 + 15] = float2int8(p0[6 + 49] * scale) + 127; - pp[64 + 16] = float2int8(p0[4 + 64] * scale) + 127; - pp[64 + 17] = float2int8(p0[4 + 65] * scale) + 127; - pp[64 + 18] = float2int8(p0[6 + 64] * scale) + 127; - pp[64 + 19] = float2int8(p0[6 + 65] * scale) + 127; - pp[64 + 20] = float2int8(p0[4 + 80] * scale) + 127; - pp[64 + 21] = float2int8(p0[4 + 81] * scale) + 127; - pp[64 + 22] = float2int8(p0[6 + 80] * scale) + 127; - pp[64 + 23] = float2int8(p0[6 + 81] * scale) + 127; - pp[64 + 24] = float2int8(p0[4 + 96] * scale) + 127; - pp[64 + 25] = float2int8(p0[4 + 97] * scale) + 127; - pp[64 + 26] = float2int8(p0[6 + 96] * scale) + 127; - pp[64 + 27] = float2int8(p0[6 + 97] * scale) + 127; - pp[64 + 28] = float2int8(p0[4 + 112] * scale) + 127; - pp[64 + 29] = float2int8(p0[4 + 113] * scale) + 127; - pp[64 + 30] = float2int8(p0[6 + 112] * scale) + 127; - pp[64 + 31] = float2int8(p0[6 + 113] * scale) + 127; - - pp[96 + 0] = float2int8(p0[4 + 128 + 0] * scale) + 127; - pp[96 + 1] = float2int8(p0[4 + 128 + 1] * scale) + 127; - pp[96 + 2] = float2int8(p0[6 + 128 + 0] * scale) + 127; - pp[96 + 3] = float2int8(p0[6 + 128 + 1] * scale) + 127; - pp[96 + 4] = float2int8(p0[4 + 128 + 16] * scale) + 127; - pp[96 + 5] = float2int8(p0[4 + 128 + 17] * scale) + 127; - pp[96 + 6] = float2int8(p0[6 + 128 + 16] * scale) + 127; - pp[96 + 7] = float2int8(p0[6 + 128 + 17] * scale) + 127; - pp[96 + 8] = float2int8(p0[4 + 128 + 32] * scale) + 127; - pp[96 + 9] = float2int8(p0[4 + 128 + 33] * scale) + 127; - pp[96 + 10] = float2int8(p0[6 + 128 + 32] * scale) + 127; - pp[96 + 11] = float2int8(p0[6 + 128 + 33] * scale) + 127; - pp[96 + 12] = float2int8(p0[4 + 128 + 48] * scale) + 127; - pp[96 + 13] = float2int8(p0[4 + 128 + 49] * scale) + 127; - pp[96 + 14] = float2int8(p0[6 + 128 + 48] * scale) + 127; - pp[96 + 15] = float2int8(p0[6 + 128 + 49] * scale) + 127; - pp[96 + 16] = float2int8(p0[4 + 128 + 64] * scale) + 127; - pp[96 + 17] = float2int8(p0[4 + 128 + 65] * scale) + 127; - pp[96 + 18] = float2int8(p0[6 + 128 + 64] * scale) + 127; - pp[96 + 19] = float2int8(p0[6 + 128 + 65] * scale) + 127; - pp[96 + 20] = float2int8(p0[4 + 128 + 80] * scale) + 127; - pp[96 + 21] = float2int8(p0[4 + 128 + 81] * scale) + 127; - pp[96 + 22] = float2int8(p0[6 + 128 + 80] * scale) + 127; - pp[96 + 23] = float2int8(p0[6 + 128 + 81] * scale) + 127; - pp[96 + 24] = float2int8(p0[4 + 128 + 96] * scale) + 127; - pp[96 + 25] = float2int8(p0[4 + 128 + 97] * scale) + 127; - pp[96 + 26] = float2int8(p0[6 + 128 + 96] * scale) + 127; - pp[96 + 27] = float2int8(p0[6 + 128 + 97] * scale) + 127; - pp[96 + 28] = float2int8(p0[4 + 128 + 112] * scale) + 127; - pp[96 + 29] = float2int8(p0[4 + 128 + 113] * scale) + 127; - pp[96 + 30] = float2int8(p0[6 + 128 + 112] * scale) + 127; - pp[96 + 31] = float2int8(p0[6 + 128 + 113] * scale) + 127; - - pp[128 + 0] = float2int8(p0[8 + 0] * scale) + 127; - pp[128 + 1] = float2int8(p0[8 + 1] * scale) + 127; - pp[128 + 2] = float2int8(p0[10 + 0] * scale) + 127; - pp[128 + 3] = float2int8(p0[10 + 1] * scale) + 127; - pp[128 + 4] = float2int8(p0[8 + 16] * scale) + 127; - pp[128 + 5] = float2int8(p0[8 + 17] * scale) + 127; - pp[128 + 6] = float2int8(p0[10 + 16] * scale) + 127; - pp[128 + 7] = float2int8(p0[10 + 17] * scale) + 127; - pp[128 + 8] = float2int8(p0[8 + 32] * scale) + 127; - pp[128 + 9] = float2int8(p0[8 + 33] * scale) + 127; - pp[128 + 10] = float2int8(p0[10 + 32] * scale) + 127; - pp[128 + 11] = float2int8(p0[10 + 33] * scale) + 127; - pp[128 + 12] = float2int8(p0[8 + 48] * scale) + 127; - pp[128 + 13] = float2int8(p0[8 + 49] * scale) + 127; - pp[128 + 14] = float2int8(p0[10 + 48] * scale) + 127; - pp[128 + 15] = float2int8(p0[10 + 49] * scale) + 127; - pp[128 + 16] = float2int8(p0[8 + 64] * scale) + 127; - pp[128 + 17] = float2int8(p0[8 + 65] * scale) + 127; - pp[128 + 18] = float2int8(p0[10 + 64] * scale) + 127; - pp[128 + 19] = float2int8(p0[10 + 65] * scale) + 127; - pp[128 + 20] = float2int8(p0[8 + 80] * scale) + 127; - pp[128 + 21] = float2int8(p0[8 + 81] * scale) + 127; - pp[128 + 22] = float2int8(p0[10 + 80] * scale) + 127; - pp[128 + 23] = float2int8(p0[10 + 81] * scale) + 127; - pp[128 + 24] = float2int8(p0[8 + 96] * scale) + 127; - pp[128 + 25] = float2int8(p0[8 + 97] * scale) + 127; - pp[128 + 26] = float2int8(p0[10 + 96] * scale) + 127; - pp[128 + 27] = float2int8(p0[10 + 97] * scale) + 127; - pp[128 + 28] = float2int8(p0[8 + 112] * scale) + 127; - pp[128 + 29] = float2int8(p0[8 + 113] * scale) + 127; - pp[128 + 30] = float2int8(p0[10 + 112] * scale) + 127; - pp[128 + 31] = float2int8(p0[10 + 113] * scale) + 127; - - pp[160 + 0] = float2int8(p0[8 + 128 + 0] * scale) + 127; - pp[160 + 1] = float2int8(p0[8 + 128 + 1] * scale) + 127; - pp[160 + 2] = float2int8(p0[10 + 128 + 0] * scale) + 127; - pp[160 + 3] = float2int8(p0[10 + 128 + 1] * scale) + 127; - pp[160 + 4] = float2int8(p0[8 + 128 + 16] * scale) + 127; - pp[160 + 5] = float2int8(p0[8 + 128 + 17] * scale) + 127; - pp[160 + 6] = float2int8(p0[10 + 128 + 16] * scale) + 127; - pp[160 + 7] = float2int8(p0[10 + 128 + 17] * scale) + 127; - pp[160 + 8] = float2int8(p0[8 + 128 + 32] * scale) + 127; - pp[160 + 9] = float2int8(p0[8 + 128 + 33] * scale) + 127; - pp[160 + 10] = float2int8(p0[10 + 128 + 32] * scale) + 127; - pp[160 + 11] = float2int8(p0[10 + 128 + 33] * scale) + 127; - pp[160 + 12] = float2int8(p0[8 + 128 + 48] * scale) + 127; - pp[160 + 13] = float2int8(p0[8 + 128 + 49] * scale) + 127; - pp[160 + 14] = float2int8(p0[10 + 128 + 48] * scale) + 127; - pp[160 + 15] = float2int8(p0[10 + 128 + 49] * scale) + 127; - pp[160 + 16] = float2int8(p0[8 + 128 + 64] * scale) + 127; - pp[160 + 17] = float2int8(p0[8 + 128 + 65] * scale) + 127; - pp[160 + 18] = float2int8(p0[10 + 128 + 64] * scale) + 127; - pp[160 + 19] = float2int8(p0[10 + 128 + 65] * scale) + 127; - pp[160 + 20] = float2int8(p0[8 + 128 + 80] * scale) + 127; - pp[160 + 21] = float2int8(p0[8 + 128 + 81] * scale) + 127; - pp[160 + 22] = float2int8(p0[10 + 128 + 80] * scale) + 127; - pp[160 + 23] = float2int8(p0[10 + 128 + 81] * scale) + 127; - pp[160 + 24] = float2int8(p0[8 + 128 + 96] * scale) + 127; - pp[160 + 25] = float2int8(p0[8 + 128 + 97] * scale) + 127; - pp[160 + 26] = float2int8(p0[10 + 128 + 96] * scale) + 127; - pp[160 + 27] = float2int8(p0[10 + 128 + 97] * scale) + 127; - pp[160 + 28] = float2int8(p0[8 + 128 + 112] * scale) + 127; - pp[160 + 29] = float2int8(p0[8 + 128 + 113] * scale) + 127; - pp[160 + 30] = float2int8(p0[10 + 128 + 112] * scale) + 127; - pp[160 + 31] = float2int8(p0[10 + 128 + 113] * scale) + 127; - - pp[192 + 0] = float2int8(p0[12 + 0] * scale) + 127; - pp[192 + 1] = float2int8(p0[12 + 1] * scale) + 127; - pp[192 + 2] = float2int8(p0[14 + 0] * scale) + 127; - pp[192 + 3] = float2int8(p0[14 + 1] * scale) + 127; - pp[192 + 4] = float2int8(p0[12 + 16] * scale) + 127; - pp[192 + 5] = float2int8(p0[12 + 17] * scale) + 127; - pp[192 + 6] = float2int8(p0[14 + 16] * scale) + 127; - pp[192 + 7] = float2int8(p0[14 + 17] * scale) + 127; - pp[192 + 8] = float2int8(p0[12 + 32] * scale) + 127; - pp[192 + 9] = float2int8(p0[12 + 33] * scale) + 127; - pp[192 + 10] = float2int8(p0[14 + 32] * scale) + 127; - pp[192 + 11] = float2int8(p0[14 + 33] * scale) + 127; - pp[192 + 12] = float2int8(p0[12 + 48] * scale) + 127; - pp[192 + 13] = float2int8(p0[12 + 49] * scale) + 127; - pp[192 + 14] = float2int8(p0[14 + 48] * scale) + 127; - pp[192 + 15] = float2int8(p0[14 + 49] * scale) + 127; - pp[192 + 16] = float2int8(p0[12 + 64] * scale) + 127; - pp[192 + 17] = float2int8(p0[12 + 65] * scale) + 127; - pp[192 + 18] = float2int8(p0[14 + 64] * scale) + 127; - pp[192 + 19] = float2int8(p0[14 + 65] * scale) + 127; - pp[192 + 20] = float2int8(p0[12 + 80] * scale) + 127; - pp[192 + 21] = float2int8(p0[12 + 81] * scale) + 127; - pp[192 + 22] = float2int8(p0[14 + 80] * scale) + 127; - pp[192 + 23] = float2int8(p0[14 + 81] * scale) + 127; - pp[192 + 24] = float2int8(p0[12 + 96] * scale) + 127; - pp[192 + 25] = float2int8(p0[12 + 97] * scale) + 127; - pp[192 + 26] = float2int8(p0[14 + 96] * scale) + 127; - pp[192 + 27] = float2int8(p0[14 + 97] * scale) + 127; - pp[192 + 28] = float2int8(p0[12 + 112] * scale) + 127; - pp[192 + 29] = float2int8(p0[12 + 113] * scale) + 127; - pp[192 + 30] = float2int8(p0[14 + 112] * scale) + 127; - pp[192 + 31] = float2int8(p0[14 + 113] * scale) + 127; - - pp[224 + 0] = float2int8(p0[12 + 128 + 0] * scale) + 127; - pp[224 + 1] = float2int8(p0[12 + 128 + 1] * scale) + 127; - pp[224 + 2] = float2int8(p0[14 + 128 + 0] * scale) + 127; - pp[224 + 3] = float2int8(p0[14 + 128 + 1] * scale) + 127; - pp[224 + 4] = float2int8(p0[12 + 128 + 16] * scale) + 127; - pp[224 + 5] = float2int8(p0[12 + 128 + 17] * scale) + 127; - pp[224 + 6] = float2int8(p0[14 + 128 + 16] * scale) + 127; - pp[224 + 7] = float2int8(p0[14 + 128 + 17] * scale) + 127; - pp[224 + 8] = float2int8(p0[12 + 128 + 32] * scale) + 127; - pp[224 + 9] = float2int8(p0[12 + 128 + 33] * scale) + 127; - pp[224 + 10] = float2int8(p0[14 + 128 + 32] * scale) + 127; - pp[224 + 11] = float2int8(p0[14 + 128 + 33] * scale) + 127; - pp[224 + 12] = float2int8(p0[12 + 128 + 48] * scale) + 127; - pp[224 + 13] = float2int8(p0[12 + 128 + 49] * scale) + 127; - pp[224 + 14] = float2int8(p0[14 + 128 + 48] * scale) + 127; - pp[224 + 15] = float2int8(p0[14 + 128 + 49] * scale) + 127; - pp[224 + 16] = float2int8(p0[12 + 128 + 64] * scale) + 127; - pp[224 + 17] = float2int8(p0[12 + 128 + 65] * scale) + 127; - pp[224 + 18] = float2int8(p0[14 + 128 + 64] * scale) + 127; - pp[224 + 19] = float2int8(p0[14 + 128 + 65] * scale) + 127; - pp[224 + 20] = float2int8(p0[12 + 128 + 80] * scale) + 127; - pp[224 + 21] = float2int8(p0[12 + 128 + 81] * scale) + 127; - pp[224 + 22] = float2int8(p0[14 + 128 + 80] * scale) + 127; - pp[224 + 23] = float2int8(p0[14 + 128 + 81] * scale) + 127; - pp[224 + 24] = float2int8(p0[12 + 128 + 96] * scale) + 127; - pp[224 + 25] = float2int8(p0[12 + 128 + 97] * scale) + 127; - pp[224 + 26] = float2int8(p0[14 + 128 + 96] * scale) + 127; - pp[224 + 27] = float2int8(p0[14 + 128 + 97] * scale) + 127; - pp[224 + 28] = float2int8(p0[12 + 128 + 112] * scale) + 127; - pp[224 + 29] = float2int8(p0[12 + 128 + 113] * scale) + 127; - pp[224 + 30] = float2int8(p0[14 + 128 + 112] * scale) + 127; - pp[224 + 31] = float2int8(p0[14 + 128 + 113] * scale) + 127; + _t0 = _mm512_add_epi8(_t0, _v127); + _t1 = _mm512_add_epi8(_t1, _v127); + _t2 = _mm512_add_epi8(_t2, _v127); + _t3 = _mm512_add_epi8(_t3, _v127); + + _mm512_storeu_si512((__m512i*)pp, _t0); + _mm512_storeu_si512((__m512i*)(pp + 64), _t1); + _mm512_storeu_si512((__m512i*)(pp + 128), _t2); + _mm512_storeu_si512((__m512i*)(pp + 192), _t3); pp += 256; p0 += B_hstep * 16; @@ -8476,420 +8061,137 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int #else // __AVX512VNNI__ for (; kk + 15 < max_kk; kk += 16) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[16] * scale); - pp[3] = float2int8(p0[17] * scale); - pp[4] = float2int8(p0[32] * scale); - pp[5] = float2int8(p0[33] * scale); - pp[6] = float2int8(p0[48] * scale); - pp[7] = float2int8(p0[49] * scale); - pp[8] = float2int8(p0[64] * scale); - pp[9] = float2int8(p0[65] * scale); - pp[10] = float2int8(p0[80] * scale); - pp[11] = float2int8(p0[81] * scale); - pp[12] = float2int8(p0[96] * scale); - pp[13] = float2int8(p0[97] * scale); - pp[14] = float2int8(p0[112] * scale); - pp[15] = float2int8(p0[113] * scale); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + __m512 _p8 = _mm512_loadu_ps(p0 + 128); + __m512 _p9 = _mm512_loadu_ps(p0 + 128 + 16); + __m512 _pa = _mm512_loadu_ps(p0 + 128 + 32); + __m512 _pb = _mm512_loadu_ps(p0 + 128 + 48); + __m512 _pc = _mm512_loadu_ps(p0 + 128 + 64); + __m512 _pd = _mm512_loadu_ps(p0 + 128 + 80); + __m512 _pe = _mm512_loadu_ps(p0 + 128 + 96); + __m512 _pf = _mm512_loadu_ps(p0 + 128 + 112); - pp[16 + 0] = float2int8(p0[128 + 0] * scale); - pp[16 + 1] = float2int8(p0[128 + 1] * scale); - pp[16 + 2] = float2int8(p0[128 + 16] * scale); - pp[16 + 3] = float2int8(p0[128 + 17] * scale); - pp[16 + 4] = float2int8(p0[128 + 32] * scale); - pp[16 + 5] = float2int8(p0[128 + 33] * scale); - pp[16 + 6] = float2int8(p0[128 + 48] * scale); - pp[16 + 7] = float2int8(p0[128 + 49] * scale); - pp[16 + 8] = float2int8(p0[128 + 64] * scale); - pp[16 + 9] = float2int8(p0[128 + 65] * scale); - pp[16 + 10] = float2int8(p0[128 + 80] * scale); - pp[16 + 11] = float2int8(p0[128 + 81] * scale); - pp[16 + 12] = float2int8(p0[128 + 96] * scale); - pp[16 + 13] = float2int8(p0[128 + 97] * scale); - pp[16 + 14] = float2int8(p0[128 + 112] * scale); - pp[16 + 15] = float2int8(p0[128 + 113] * scale); - - pp[32 + 0] = float2int8(p0[2 + 0] * scale); - pp[32 + 1] = float2int8(p0[2 + 1] * scale); - pp[32 + 2] = float2int8(p0[2 + 16] * scale); - pp[32 + 3] = float2int8(p0[2 + 17] * scale); - pp[32 + 4] = float2int8(p0[2 + 32] * scale); - pp[32 + 5] = float2int8(p0[2 + 33] * scale); - pp[32 + 6] = float2int8(p0[2 + 48] * scale); - pp[32 + 7] = float2int8(p0[2 + 49] * scale); - pp[32 + 8] = float2int8(p0[2 + 64] * scale); - pp[32 + 9] = float2int8(p0[2 + 65] * scale); - pp[32 + 10] = float2int8(p0[2 + 80] * scale); - pp[32 + 11] = float2int8(p0[2 + 81] * scale); - pp[32 + 12] = float2int8(p0[2 + 96] * scale); - pp[32 + 13] = float2int8(p0[2 + 97] * scale); - pp[32 + 14] = float2int8(p0[2 + 112] * scale); - pp[32 + 15] = float2int8(p0[2 + 113] * scale); - - pp[48 + 0] = float2int8(p0[2 + 128 + 0] * scale); - pp[48 + 1] = float2int8(p0[2 + 128 + 1] * scale); - pp[48 + 2] = float2int8(p0[2 + 128 + 16] * scale); - pp[48 + 3] = float2int8(p0[2 + 128 + 17] * scale); - pp[48 + 4] = float2int8(p0[2 + 128 + 32] * scale); - pp[48 + 5] = float2int8(p0[2 + 128 + 33] * scale); - pp[48 + 6] = float2int8(p0[2 + 128 + 48] * scale); - pp[48 + 7] = float2int8(p0[2 + 128 + 49] * scale); - pp[48 + 8] = float2int8(p0[2 + 128 + 64] * scale); - pp[48 + 9] = float2int8(p0[2 + 128 + 65] * scale); - pp[48 + 10] = float2int8(p0[2 + 128 + 80] * scale); - pp[48 + 11] = float2int8(p0[2 + 128 + 81] * scale); - pp[48 + 12] = float2int8(p0[2 + 128 + 96] * scale); - pp[48 + 13] = float2int8(p0[2 + 128 + 97] * scale); - pp[48 + 14] = float2int8(p0[2 + 128 + 112] * scale); - pp[48 + 15] = float2int8(p0[2 + 128 + 113] * scale); - - pp[64 + 0] = float2int8(p0[4 + 0] * scale); - pp[64 + 1] = float2int8(p0[4 + 1] * scale); - pp[64 + 2] = float2int8(p0[4 + 16] * scale); - pp[64 + 3] = float2int8(p0[4 + 17] * scale); - pp[64 + 4] = float2int8(p0[4 + 32] * scale); - pp[64 + 5] = float2int8(p0[4 + 33] * scale); - pp[64 + 6] = float2int8(p0[4 + 48] * scale); - pp[64 + 7] = float2int8(p0[4 + 49] * scale); - pp[64 + 8] = float2int8(p0[4 + 64] * scale); - pp[64 + 9] = float2int8(p0[4 + 65] * scale); - pp[64 + 10] = float2int8(p0[4 + 80] * scale); - pp[64 + 11] = float2int8(p0[4 + 81] * scale); - pp[64 + 12] = float2int8(p0[4 + 96] * scale); - pp[64 + 13] = float2int8(p0[4 + 97] * scale); - pp[64 + 14] = float2int8(p0[4 + 112] * scale); - pp[64 + 15] = float2int8(p0[4 + 113] * scale); - - pp[80 + 0] = float2int8(p0[4 + 128 + 0] * scale); - pp[80 + 1] = float2int8(p0[4 + 128 + 1] * scale); - pp[80 + 2] = float2int8(p0[4 + 128 + 16] * scale); - pp[80 + 3] = float2int8(p0[4 + 128 + 17] * scale); - pp[80 + 4] = float2int8(p0[4 + 128 + 32] * scale); - pp[80 + 5] = float2int8(p0[4 + 128 + 33] * scale); - pp[80 + 6] = float2int8(p0[4 + 128 + 48] * scale); - pp[80 + 7] = float2int8(p0[4 + 128 + 49] * scale); - pp[80 + 8] = float2int8(p0[4 + 128 + 64] * scale); - pp[80 + 9] = float2int8(p0[4 + 128 + 65] * scale); - pp[80 + 10] = float2int8(p0[4 + 128 + 80] * scale); - pp[80 + 11] = float2int8(p0[4 + 128 + 81] * scale); - pp[80 + 12] = float2int8(p0[4 + 128 + 96] * scale); - pp[80 + 13] = float2int8(p0[4 + 128 + 97] * scale); - pp[80 + 14] = float2int8(p0[4 + 128 + 112] * scale); - pp[80 + 15] = float2int8(p0[4 + 128 + 113] * scale); - - pp[96 + 0] = float2int8(p0[6 + 0] * scale); - pp[96 + 1] = float2int8(p0[6 + 1] * scale); - pp[96 + 2] = float2int8(p0[6 + 16] * scale); - pp[96 + 3] = float2int8(p0[6 + 17] * scale); - pp[96 + 4] = float2int8(p0[6 + 32] * scale); - pp[96 + 5] = float2int8(p0[6 + 33] * scale); - pp[96 + 6] = float2int8(p0[6 + 48] * scale); - pp[96 + 7] = float2int8(p0[6 + 49] * scale); - pp[96 + 8] = float2int8(p0[6 + 64] * scale); - pp[96 + 9] = float2int8(p0[6 + 65] * scale); - pp[96 + 10] = float2int8(p0[6 + 80] * scale); - pp[96 + 11] = float2int8(p0[6 + 81] * scale); - pp[96 + 12] = float2int8(p0[6 + 96] * scale); - pp[96 + 13] = float2int8(p0[6 + 97] * scale); - pp[96 + 14] = float2int8(p0[6 + 112] * scale); - pp[96 + 15] = float2int8(p0[6 + 113] * scale); - - pp[112 + 0] = float2int8(p0[6 + 128 + 0] * scale); - pp[112 + 1] = float2int8(p0[6 + 128 + 1] * scale); - pp[112 + 2] = float2int8(p0[6 + 128 + 16] * scale); - pp[112 + 3] = float2int8(p0[6 + 128 + 17] * scale); - pp[112 + 4] = float2int8(p0[6 + 128 + 32] * scale); - pp[112 + 5] = float2int8(p0[6 + 128 + 33] * scale); - pp[112 + 6] = float2int8(p0[6 + 128 + 48] * scale); - pp[112 + 7] = float2int8(p0[6 + 128 + 49] * scale); - pp[112 + 8] = float2int8(p0[6 + 128 + 64] * scale); - pp[112 + 9] = float2int8(p0[6 + 128 + 65] * scale); - pp[112 + 10] = float2int8(p0[6 + 128 + 80] * scale); - pp[112 + 11] = float2int8(p0[6 + 128 + 81] * scale); - pp[112 + 12] = float2int8(p0[6 + 128 + 96] * scale); - pp[112 + 13] = float2int8(p0[6 + 128 + 97] * scale); - pp[112 + 14] = float2int8(p0[6 + 128 + 112] * scale); - pp[112 + 15] = float2int8(p0[6 + 128 + 113] * scale); - - pp[128 + 0] = float2int8(p0[8 + 0] * scale); - pp[128 + 1] = float2int8(p0[8 + 1] * scale); - pp[128 + 2] = float2int8(p0[8 + 16] * scale); - pp[128 + 3] = float2int8(p0[8 + 17] * scale); - pp[128 + 4] = float2int8(p0[8 + 32] * scale); - pp[128 + 5] = float2int8(p0[8 + 33] * scale); - pp[128 + 6] = float2int8(p0[8 + 48] * scale); - pp[128 + 7] = float2int8(p0[8 + 49] * scale); - pp[128 + 8] = float2int8(p0[8 + 64] * scale); - pp[128 + 9] = float2int8(p0[8 + 65] * scale); - pp[128 + 10] = float2int8(p0[8 + 80] * scale); - pp[128 + 11] = float2int8(p0[8 + 81] * scale); - pp[128 + 12] = float2int8(p0[8 + 96] * scale); - pp[128 + 13] = float2int8(p0[8 + 97] * scale); - pp[128 + 14] = float2int8(p0[8 + 112] * scale); - pp[128 + 15] = float2int8(p0[8 + 113] * scale); - - pp[16 + 128 + 0] = float2int8(p0[8 + 128 + 0] * scale); - pp[16 + 128 + 1] = float2int8(p0[8 + 128 + 1] * scale); - pp[16 + 128 + 2] = float2int8(p0[8 + 128 + 16] * scale); - pp[16 + 128 + 3] = float2int8(p0[8 + 128 + 17] * scale); - pp[16 + 128 + 4] = float2int8(p0[8 + 128 + 32] * scale); - pp[16 + 128 + 5] = float2int8(p0[8 + 128 + 33] * scale); - pp[16 + 128 + 6] = float2int8(p0[8 + 128 + 48] * scale); - pp[16 + 128 + 7] = float2int8(p0[8 + 128 + 49] * scale); - pp[16 + 128 + 8] = float2int8(p0[8 + 128 + 64] * scale); - pp[16 + 128 + 9] = float2int8(p0[8 + 128 + 65] * scale); - pp[16 + 128 + 10] = float2int8(p0[8 + 128 + 80] * scale); - pp[16 + 128 + 11] = float2int8(p0[8 + 128 + 81] * scale); - pp[16 + 128 + 12] = float2int8(p0[8 + 128 + 96] * scale); - pp[16 + 128 + 13] = float2int8(p0[8 + 128 + 97] * scale); - pp[16 + 128 + 14] = float2int8(p0[8 + 128 + 112] * scale); - pp[16 + 128 + 15] = float2int8(p0[8 + 128 + 113] * scale); - - pp[32 + 128 + 0] = float2int8(p0[10 + 0] * scale); - pp[32 + 128 + 1] = float2int8(p0[10 + 1] * scale); - pp[32 + 128 + 2] = float2int8(p0[10 + 16] * scale); - pp[32 + 128 + 3] = float2int8(p0[10 + 17] * scale); - pp[32 + 128 + 4] = float2int8(p0[10 + 32] * scale); - pp[32 + 128 + 5] = float2int8(p0[10 + 33] * scale); - pp[32 + 128 + 6] = float2int8(p0[10 + 48] * scale); - pp[32 + 128 + 7] = float2int8(p0[10 + 49] * scale); - pp[32 + 128 + 8] = float2int8(p0[10 + 64] * scale); - pp[32 + 128 + 9] = float2int8(p0[10 + 65] * scale); - pp[32 + 128 + 10] = float2int8(p0[10 + 80] * scale); - pp[32 + 128 + 11] = float2int8(p0[10 + 81] * scale); - pp[32 + 128 + 12] = float2int8(p0[10 + 96] * scale); - pp[32 + 128 + 13] = float2int8(p0[10 + 97] * scale); - pp[32 + 128 + 14] = float2int8(p0[10 + 112] * scale); - pp[32 + 128 + 15] = float2int8(p0[10 + 113] * scale); - - pp[48 + 128 + 0] = float2int8(p0[10 + 128 + 0] * scale); - pp[48 + 128 + 1] = float2int8(p0[10 + 128 + 1] * scale); - pp[48 + 128 + 2] = float2int8(p0[10 + 128 + 16] * scale); - pp[48 + 128 + 3] = float2int8(p0[10 + 128 + 17] * scale); - pp[48 + 128 + 4] = float2int8(p0[10 + 128 + 32] * scale); - pp[48 + 128 + 5] = float2int8(p0[10 + 128 + 33] * scale); - pp[48 + 128 + 6] = float2int8(p0[10 + 128 + 48] * scale); - pp[48 + 128 + 7] = float2int8(p0[10 + 128 + 49] * scale); - pp[48 + 128 + 8] = float2int8(p0[10 + 128 + 64] * scale); - pp[48 + 128 + 9] = float2int8(p0[10 + 128 + 65] * scale); - pp[48 + 128 + 10] = float2int8(p0[10 + 128 + 80] * scale); - pp[48 + 128 + 11] = float2int8(p0[10 + 128 + 81] * scale); - pp[48 + 128 + 12] = float2int8(p0[10 + 128 + 96] * scale); - pp[48 + 128 + 13] = float2int8(p0[10 + 128 + 97] * scale); - pp[48 + 128 + 14] = float2int8(p0[10 + 128 + 112] * scale); - pp[48 + 128 + 15] = float2int8(p0[10 + 128 + 113] * scale); - - pp[64 + 128 + 0] = float2int8(p0[12 + 0] * scale); - pp[64 + 128 + 1] = float2int8(p0[12 + 1] * scale); - pp[64 + 128 + 2] = float2int8(p0[12 + 16] * scale); - pp[64 + 128 + 3] = float2int8(p0[12 + 17] * scale); - pp[64 + 128 + 4] = float2int8(p0[12 + 32] * scale); - pp[64 + 128 + 5] = float2int8(p0[12 + 33] * scale); - pp[64 + 128 + 6] = float2int8(p0[12 + 48] * scale); - pp[64 + 128 + 7] = float2int8(p0[12 + 49] * scale); - pp[64 + 128 + 8] = float2int8(p0[12 + 64] * scale); - pp[64 + 128 + 9] = float2int8(p0[12 + 65] * scale); - pp[64 + 128 + 10] = float2int8(p0[12 + 80] * scale); - pp[64 + 128 + 11] = float2int8(p0[12 + 81] * scale); - pp[64 + 128 + 12] = float2int8(p0[12 + 96] * scale); - pp[64 + 128 + 13] = float2int8(p0[12 + 97] * scale); - pp[64 + 128 + 14] = float2int8(p0[12 + 112] * scale); - pp[64 + 128 + 15] = float2int8(p0[12 + 113] * scale); - - pp[80 + 128 + 0] = float2int8(p0[12 + 128 + 0] * scale); - pp[80 + 128 + 1] = float2int8(p0[12 + 128 + 1] * scale); - pp[80 + 128 + 2] = float2int8(p0[12 + 128 + 16] * scale); - pp[80 + 128 + 3] = float2int8(p0[12 + 128 + 17] * scale); - pp[80 + 128 + 4] = float2int8(p0[12 + 128 + 32] * scale); - pp[80 + 128 + 5] = float2int8(p0[12 + 128 + 33] * scale); - pp[80 + 128 + 6] = float2int8(p0[12 + 128 + 48] * scale); - pp[80 + 128 + 7] = float2int8(p0[12 + 128 + 49] * scale); - pp[80 + 128 + 8] = float2int8(p0[12 + 128 + 64] * scale); - pp[80 + 128 + 9] = float2int8(p0[12 + 128 + 65] * scale); - pp[80 + 128 + 10] = float2int8(p0[12 + 128 + 80] * scale); - pp[80 + 128 + 11] = float2int8(p0[12 + 128 + 81] * scale); - pp[80 + 128 + 12] = float2int8(p0[12 + 128 + 96] * scale); - pp[80 + 128 + 13] = float2int8(p0[12 + 128 + 97] * scale); - pp[80 + 128 + 14] = float2int8(p0[12 + 128 + 112] * scale); - pp[80 + 128 + 15] = float2int8(p0[12 + 128 + 113] * scale); - - pp[96 + 128 + 0] = float2int8(p0[14 + 0] * scale); - pp[96 + 128 + 1] = float2int8(p0[14 + 1] * scale); - pp[96 + 128 + 2] = float2int8(p0[14 + 16] * scale); - pp[96 + 128 + 3] = float2int8(p0[14 + 17] * scale); - pp[96 + 128 + 4] = float2int8(p0[14 + 32] * scale); - pp[96 + 128 + 5] = float2int8(p0[14 + 33] * scale); - pp[96 + 128 + 6] = float2int8(p0[14 + 48] * scale); - pp[96 + 128 + 7] = float2int8(p0[14 + 49] * scale); - pp[96 + 128 + 8] = float2int8(p0[14 + 64] * scale); - pp[96 + 128 + 9] = float2int8(p0[14 + 65] * scale); - pp[96 + 128 + 10] = float2int8(p0[14 + 80] * scale); - pp[96 + 128 + 11] = float2int8(p0[14 + 81] * scale); - pp[96 + 128 + 12] = float2int8(p0[14 + 96] * scale); - pp[96 + 128 + 13] = float2int8(p0[14 + 97] * scale); - pp[96 + 128 + 14] = float2int8(p0[14 + 112] * scale); - pp[96 + 128 + 15] = float2int8(p0[14 + 113] * scale); - - pp[112 + 128 + 0] = float2int8(p0[14 + 128 + 0] * scale); - pp[112 + 128 + 1] = float2int8(p0[14 + 128 + 1] * scale); - pp[112 + 128 + 2] = float2int8(p0[14 + 128 + 16] * scale); - pp[112 + 128 + 3] = float2int8(p0[14 + 128 + 17] * scale); - pp[112 + 128 + 4] = float2int8(p0[14 + 128 + 32] * scale); - pp[112 + 128 + 5] = float2int8(p0[14 + 128 + 33] * scale); - pp[112 + 128 + 6] = float2int8(p0[14 + 128 + 48] * scale); - pp[112 + 128 + 7] = float2int8(p0[14 + 128 + 49] * scale); - pp[112 + 128 + 8] = float2int8(p0[14 + 128 + 64] * scale); - pp[112 + 128 + 9] = float2int8(p0[14 + 128 + 65] * scale); - pp[112 + 128 + 10] = float2int8(p0[14 + 128 + 80] * scale); - pp[112 + 128 + 11] = float2int8(p0[14 + 128 + 81] * scale); - pp[112 + 128 + 12] = float2int8(p0[14 + 128 + 96] * scale); - pp[112 + 128 + 13] = float2int8(p0[14 + 128 + 97] * scale); - pp[112 + 128 + 14] = float2int8(p0[14 + 128 + 112] * scale); - pp[112 + 128 + 15] = float2int8(p0[14 + 128 + 113] * scale); + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + _p4 = _mm512_mul_ps(_p4, _scale); + _p5 = _mm512_mul_ps(_p5, _scale); + _p6 = _mm512_mul_ps(_p6, _scale); + _p7 = _mm512_mul_ps(_p7, _scale); + _p8 = _mm512_mul_ps(_p8, _scale); + _p9 = _mm512_mul_ps(_p9, _scale); + _pa = _mm512_mul_ps(_pa, _scale); + _pb = _mm512_mul_ps(_pb, _scale); + _pc = _mm512_mul_ps(_pc, _scale); + _pd = _mm512_mul_ps(_pd, _scale); + _pe = _mm512_mul_ps(_pe, _scale); + _pf = _mm512_mul_ps(_pf, _scale); - pp += 256; - p0 += B_hstep * 16; - } -#endif // __AVX512VNNI__ - } - if (elempack == 8) - { - int kk = 0; -#if __AVX512VNNI__ - for (; kk + 7 < max_kk; kk += 8) - { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[1] * scale) + 127; - pp[2] = float2int8(p0[2] * scale) + 127; - pp[3] = float2int8(p0[3] * scale) + 127; - pp[4] = float2int8(p0[8] * scale) + 127; - pp[5] = float2int8(p0[9] * scale) + 127; - pp[6] = float2int8(p0[10] * scale) + 127; - pp[7] = float2int8(p0[11] * scale) + 127; - pp[8] = float2int8(p0[16] * scale) + 127; - pp[9] = float2int8(p0[17] * scale) + 127; - pp[10] = float2int8(p0[18] * scale) + 127; - pp[11] = float2int8(p0[19] * scale) + 127; - pp[12] = float2int8(p0[24] * scale) + 127; - pp[13] = float2int8(p0[25] * scale) + 127; - pp[14] = float2int8(p0[26] * scale) + 127; - pp[15] = float2int8(p0[27] * scale) + 127; - pp[16] = float2int8(p0[32] * scale) + 127; - pp[17] = float2int8(p0[33] * scale) + 127; - pp[18] = float2int8(p0[34] * scale) + 127; - pp[19] = float2int8(p0[35] * scale) + 127; - pp[20] = float2int8(p0[40] * scale) + 127; - pp[21] = float2int8(p0[41] * scale) + 127; - pp[22] = float2int8(p0[42] * scale) + 127; - pp[23] = float2int8(p0[43] * scale) + 127; - pp[24] = float2int8(p0[48] * scale) + 127; - pp[25] = float2int8(p0[49] * scale) + 127; - pp[26] = float2int8(p0[50] * scale) + 127; - pp[27] = float2int8(p0[51] * scale) + 127; - pp[28] = float2int8(p0[56] * scale) + 127; - pp[29] = float2int8(p0[57] * scale) + 127; - pp[30] = float2int8(p0[58] * scale) + 127; - pp[31] = float2int8(p0[59] * scale) + 127; + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + __m128i _pp8 = float2int8_avx512(_p8); + __m128i _pp9 = float2int8_avx512(_p9); + __m128i _ppa = float2int8_avx512(_pa); + __m128i _ppb = float2int8_avx512(_pb); + __m128i _ppc = float2int8_avx512(_pc); + __m128i _ppd = float2int8_avx512(_pd); + __m128i _ppe = float2int8_avx512(_pe); + __m128i _ppf = float2int8_avx512(_pf); - pp[32 + 0] = float2int8(p0[64 + 0] * scale) + 127; - pp[32 + 1] = float2int8(p0[64 + 1] * scale) + 127; - pp[32 + 2] = float2int8(p0[64 + 2] * scale) + 127; - pp[32 + 3] = float2int8(p0[64 + 3] * scale) + 127; - pp[32 + 4] = float2int8(p0[64 + 8] * scale) + 127; - pp[32 + 5] = float2int8(p0[64 + 9] * scale) + 127; - pp[32 + 6] = float2int8(p0[64 + 10] * scale) + 127; - pp[32 + 7] = float2int8(p0[64 + 11] * scale) + 127; - pp[32 + 8] = float2int8(p0[64 + 16] * scale) + 127; - pp[32 + 9] = float2int8(p0[64 + 17] * scale) + 127; - pp[32 + 10] = float2int8(p0[64 + 18] * scale) + 127; - pp[32 + 11] = float2int8(p0[64 + 19] * scale) + 127; - pp[32 + 12] = float2int8(p0[64 + 24] * scale) + 127; - pp[32 + 13] = float2int8(p0[64 + 25] * scale) + 127; - pp[32 + 14] = float2int8(p0[64 + 26] * scale) + 127; - pp[32 + 15] = float2int8(p0[64 + 27] * scale) + 127; - pp[32 + 16] = float2int8(p0[64 + 32] * scale) + 127; - pp[32 + 17] = float2int8(p0[64 + 33] * scale) + 127; - pp[32 + 18] = float2int8(p0[64 + 34] * scale) + 127; - pp[32 + 19] = float2int8(p0[64 + 35] * scale) + 127; - pp[32 + 20] = float2int8(p0[64 + 40] * scale) + 127; - pp[32 + 21] = float2int8(p0[64 + 41] * scale) + 127; - pp[32 + 22] = float2int8(p0[64 + 42] * scale) + 127; - pp[32 + 23] = float2int8(p0[64 + 43] * scale) + 127; - pp[32 + 24] = float2int8(p0[64 + 48] * scale) + 127; - pp[32 + 25] = float2int8(p0[64 + 49] * scale) + 127; - pp[32 + 26] = float2int8(p0[64 + 50] * scale) + 127; - pp[32 + 27] = float2int8(p0[64 + 51] * scale) + 127; - pp[32 + 28] = float2int8(p0[64 + 56] * scale) + 127; - pp[32 + 29] = float2int8(p0[64 + 57] * scale) + 127; - pp[32 + 30] = float2int8(p0[64 + 58] * scale) + 127; - pp[32 + 31] = float2int8(p0[64 + 59] * scale) + 127; - - pp[64 + 0] = float2int8(p0[4] * scale) + 127; - pp[64 + 1] = float2int8(p0[5] * scale) + 127; - pp[64 + 2] = float2int8(p0[6] * scale) + 127; - pp[64 + 3] = float2int8(p0[7] * scale) + 127; - pp[64 + 4] = float2int8(p0[12] * scale) + 127; - pp[64 + 5] = float2int8(p0[13] * scale) + 127; - pp[64 + 6] = float2int8(p0[14] * scale) + 127; - pp[64 + 7] = float2int8(p0[15] * scale) + 127; - pp[64 + 8] = float2int8(p0[20] * scale) + 127; - pp[64 + 9] = float2int8(p0[21] * scale) + 127; - pp[64 + 10] = float2int8(p0[22] * scale) + 127; - pp[64 + 11] = float2int8(p0[23] * scale) + 127; - pp[64 + 12] = float2int8(p0[28] * scale) + 127; - pp[64 + 13] = float2int8(p0[29] * scale) + 127; - pp[64 + 14] = float2int8(p0[30] * scale) + 127; - pp[64 + 15] = float2int8(p0[31] * scale) + 127; - pp[64 + 16] = float2int8(p0[36] * scale) + 127; - pp[64 + 17] = float2int8(p0[37] * scale) + 127; - pp[64 + 18] = float2int8(p0[38] * scale) + 127; - pp[64 + 19] = float2int8(p0[39] * scale) + 127; - pp[64 + 20] = float2int8(p0[44] * scale) + 127; - pp[64 + 21] = float2int8(p0[45] * scale) + 127; - pp[64 + 22] = float2int8(p0[46] * scale) + 127; - pp[64 + 23] = float2int8(p0[47] * scale) + 127; - pp[64 + 24] = float2int8(p0[52] * scale) + 127; - pp[64 + 25] = float2int8(p0[53] * scale) + 127; - pp[64 + 26] = float2int8(p0[54] * scale) + 127; - pp[64 + 27] = float2int8(p0[55] * scale) + 127; - pp[64 + 28] = float2int8(p0[60] * scale) + 127; - pp[64 + 29] = float2int8(p0[61] * scale) + 127; - pp[64 + 30] = float2int8(p0[62] * scale) + 127; - pp[64 + 31] = float2int8(p0[63] * scale) + 127; - - pp[96 + 0] = float2int8(p0[64 + 4] * scale) + 127; - pp[96 + 1] = float2int8(p0[64 + 5] * scale) + 127; - pp[96 + 2] = float2int8(p0[64 + 6] * scale) + 127; - pp[96 + 3] = float2int8(p0[64 + 7] * scale) + 127; - pp[96 + 4] = float2int8(p0[64 + 12] * scale) + 127; - pp[96 + 5] = float2int8(p0[64 + 13] * scale) + 127; - pp[96 + 6] = float2int8(p0[64 + 14] * scale) + 127; - pp[96 + 7] = float2int8(p0[64 + 15] * scale) + 127; - pp[96 + 8] = float2int8(p0[64 + 20] * scale) + 127; - pp[96 + 9] = float2int8(p0[64 + 21] * scale) + 127; - pp[96 + 10] = float2int8(p0[64 + 22] * scale) + 127; - pp[96 + 11] = float2int8(p0[64 + 23] * scale) + 127; - pp[96 + 12] = float2int8(p0[64 + 28] * scale) + 127; - pp[96 + 13] = float2int8(p0[64 + 29] * scale) + 127; - pp[96 + 14] = float2int8(p0[64 + 30] * scale) + 127; - pp[96 + 15] = float2int8(p0[64 + 31] * scale) + 127; - pp[96 + 16] = float2int8(p0[64 + 36] * scale) + 127; - pp[96 + 17] = float2int8(p0[64 + 37] * scale) + 127; - pp[96 + 18] = float2int8(p0[64 + 38] * scale) + 127; - pp[96 + 19] = float2int8(p0[64 + 39] * scale) + 127; - pp[96 + 20] = float2int8(p0[64 + 44] * scale) + 127; - pp[96 + 21] = float2int8(p0[64 + 45] * scale) + 127; - pp[96 + 22] = float2int8(p0[64 + 46] * scale) + 127; - pp[96 + 23] = float2int8(p0[64 + 47] * scale) + 127; - pp[96 + 24] = float2int8(p0[64 + 52] * scale) + 127; - pp[96 + 25] = float2int8(p0[64 + 53] * scale) + 127; - pp[96 + 26] = float2int8(p0[64 + 54] * scale) + 127; - pp[96 + 27] = float2int8(p0[64 + 55] * scale) + 127; - pp[96 + 28] = float2int8(p0[64 + 60] * scale) + 127; - pp[96 + 29] = float2int8(p0[64 + 61] * scale) + 127; - pp[96 + 30] = float2int8(p0[64 + 62] * scale) + 127; - pp[96 + 31] = float2int8(p0[64 + 63] * scale) + 127; + __m512i _t0 = combine4x4_epi32(_pp0, _pp4, _pp8, _ppc); + __m512i _t1 = combine4x4_epi32(_pp1, _pp5, _pp9, _ppd); + __m512i _t2 = combine4x4_epi32(_pp2, _pp6, _ppa, _ppe); + __m512i _t3 = combine4x4_epi32(_pp3, _pp7, _ppb, _ppf); + + __m512i _t4 = _mm512_unpacklo_epi16(_t0, _t1); + __m512i _t5 = _mm512_unpackhi_epi16(_t0, _t1); + __m512i _t6 = _mm512_unpacklo_epi16(_t2, _t3); + __m512i _t7 = _mm512_unpackhi_epi16(_t2, _t3); + + _t0 = _mm512_unpacklo_epi32(_t4, _t6); + _t1 = _mm512_unpackhi_epi32(_t4, _t6); + _t2 = _mm512_unpacklo_epi32(_t5, _t7); + _t3 = _mm512_unpackhi_epi32(_t5, _t7); + + _t0 = _mm512_permutex_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_permutex_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0)); + _t2 = _mm512_permutex_epi64(_t2, _MM_SHUFFLE(3, 1, 2, 0)); + _t3 = _mm512_permutex_epi64(_t3, _MM_SHUFFLE(3, 1, 2, 0)); + _t0 = _mm512_shuffle_i32x4(_t0, _t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_shuffle_i32x4(_t1, _t1, _MM_SHUFFLE(3, 1, 2, 0)); + _t2 = _mm512_shuffle_i32x4(_t2, _t2, _MM_SHUFFLE(3, 1, 2, 0)); + _t3 = _mm512_shuffle_i32x4(_t3, _t3, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm512_storeu_si512((__m512i*)pp, _t0); + _mm512_storeu_si512((__m512i*)(pp + 64), _t1); + _mm512_storeu_si512((__m512i*)(pp + 128), _t2); + _mm512_storeu_si512((__m512i*)(pp + 192), _t3); + + pp += 256; + p0 += B_hstep * 16; + } +#endif // __AVX512VNNI__ + } + if (elempack == 8) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + _p4 = _mm512_mul_ps(_p4, _scale); + _p5 = _mm512_mul_ps(_p5, _scale); + _p6 = _mm512_mul_ps(_p6, _scale); + _p7 = _mm512_mul_ps(_p7, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp2, _pp4, _pp6); + __m512i _t1 = combine4x4_epi32(_pp1, _pp3, _pp5, _pp7); + + __m512i _t2 = _mm512_unpacklo_epi32(_t0, _t1); + __m512i _t3 = _mm512_unpackhi_epi32(_t0, _t1); + __m512i _ppa = _mm512_unpacklo_epi32(_t2, _t3); + __m512i _ppb = _mm512_unpackhi_epi32(_t2, _t3); + + _ppa = _mm512_add_epi8(_ppa, _v127); + _ppb = _mm512_add_epi8(_ppb, _v127); + + _mm512_storeu_si512((__m512i*)pp, _ppa); + _mm512_storeu_si512((__m512i*)(pp + 64), _ppb); pp += 128; p0 += B_hstep * 8; @@ -8897,141 +8199,47 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int #else // __AVX512VNNI__ for (; kk + 7 < max_kk; kk += 8) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[8] * scale); - pp[3] = float2int8(p0[9] * scale); - pp[4] = float2int8(p0[16] * scale); - pp[5] = float2int8(p0[17] * scale); - pp[6] = float2int8(p0[24] * scale); - pp[7] = float2int8(p0[25] * scale); - pp[8] = float2int8(p0[32] * scale); - pp[9] = float2int8(p0[33] * scale); - pp[10] = float2int8(p0[40] * scale); - pp[11] = float2int8(p0[41] * scale); - pp[12] = float2int8(p0[48] * scale); - pp[13] = float2int8(p0[49] * scale); - pp[14] = float2int8(p0[56] * scale); - pp[15] = float2int8(p0[57] * scale); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + _p4 = _mm512_mul_ps(_p4, _scale); + _p5 = _mm512_mul_ps(_p5, _scale); + _p6 = _mm512_mul_ps(_p6, _scale); + _p7 = _mm512_mul_ps(_p7, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp2, _pp4, _pp6); + __m512i _t1 = combine4x4_epi32(_pp1, _pp3, _pp5, _pp7); + + __m512i _t2 = _mm512_unpacklo_epi16(_t0, _t1); + __m512i _t3 = _mm512_unpackhi_epi16(_t0, _t1); + _t0 = _mm512_unpacklo_epi16(_t2, _t3); + _t1 = _mm512_unpackhi_epi16(_t2, _t3); + _t0 = _mm512_permutex_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_permutex_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0)); + __m512i _ppa = _mm512_shuffle_i32x4(_t0, _t0, _MM_SHUFFLE(3, 1, 2, 0)); + __m512i _ppb = _mm512_shuffle_i32x4(_t1, _t1, _MM_SHUFFLE(3, 1, 2, 0)); - pp[16 + 0] = float2int8(p0[64 + 0] * scale); - pp[16 + 1] = float2int8(p0[64 + 1] * scale); - pp[16 + 2] = float2int8(p0[64 + 8] * scale); - pp[16 + 3] = float2int8(p0[64 + 9] * scale); - pp[16 + 4] = float2int8(p0[64 + 16] * scale); - pp[16 + 5] = float2int8(p0[64 + 17] * scale); - pp[16 + 6] = float2int8(p0[64 + 24] * scale); - pp[16 + 7] = float2int8(p0[64 + 25] * scale); - pp[16 + 8] = float2int8(p0[64 + 32] * scale); - pp[16 + 9] = float2int8(p0[64 + 33] * scale); - pp[16 + 10] = float2int8(p0[64 + 40] * scale); - pp[16 + 11] = float2int8(p0[64 + 41] * scale); - pp[16 + 12] = float2int8(p0[64 + 48] * scale); - pp[16 + 13] = float2int8(p0[64 + 49] * scale); - pp[16 + 14] = float2int8(p0[64 + 56] * scale); - pp[16 + 15] = float2int8(p0[64 + 57] * scale); - - pp[32 + 0] = float2int8(p0[2] * scale); - pp[32 + 1] = float2int8(p0[3] * scale); - pp[32 + 2] = float2int8(p0[10] * scale); - pp[32 + 3] = float2int8(p0[11] * scale); - pp[32 + 4] = float2int8(p0[18] * scale); - pp[32 + 5] = float2int8(p0[19] * scale); - pp[32 + 6] = float2int8(p0[26] * scale); - pp[32 + 7] = float2int8(p0[27] * scale); - pp[32 + 8] = float2int8(p0[34] * scale); - pp[32 + 9] = float2int8(p0[35] * scale); - pp[32 + 10] = float2int8(p0[42] * scale); - pp[32 + 11] = float2int8(p0[43] * scale); - pp[32 + 12] = float2int8(p0[50] * scale); - pp[32 + 13] = float2int8(p0[51] * scale); - pp[32 + 14] = float2int8(p0[58] * scale); - pp[32 + 15] = float2int8(p0[59] * scale); - - pp[48 + 0] = float2int8(p0[64 + 2] * scale); - pp[48 + 1] = float2int8(p0[64 + 3] * scale); - pp[48 + 2] = float2int8(p0[64 + 10] * scale); - pp[48 + 3] = float2int8(p0[64 + 11] * scale); - pp[48 + 4] = float2int8(p0[64 + 18] * scale); - pp[48 + 5] = float2int8(p0[64 + 19] * scale); - pp[48 + 6] = float2int8(p0[64 + 26] * scale); - pp[48 + 7] = float2int8(p0[64 + 27] * scale); - pp[48 + 8] = float2int8(p0[64 + 34] * scale); - pp[48 + 9] = float2int8(p0[64 + 35] * scale); - pp[48 + 10] = float2int8(p0[64 + 42] * scale); - pp[48 + 11] = float2int8(p0[64 + 43] * scale); - pp[48 + 12] = float2int8(p0[64 + 50] * scale); - pp[48 + 13] = float2int8(p0[64 + 51] * scale); - pp[48 + 14] = float2int8(p0[64 + 58] * scale); - pp[48 + 15] = float2int8(p0[64 + 59] * scale); - - pp[64 + 0] = float2int8(p0[4] * scale); - pp[64 + 1] = float2int8(p0[5] * scale); - pp[64 + 2] = float2int8(p0[12] * scale); - pp[64 + 3] = float2int8(p0[13] * scale); - pp[64 + 4] = float2int8(p0[20] * scale); - pp[64 + 5] = float2int8(p0[21] * scale); - pp[64 + 6] = float2int8(p0[28] * scale); - pp[64 + 7] = float2int8(p0[29] * scale); - pp[64 + 8] = float2int8(p0[36] * scale); - pp[64 + 9] = float2int8(p0[37] * scale); - pp[64 + 10] = float2int8(p0[44] * scale); - pp[64 + 11] = float2int8(p0[45] * scale); - pp[64 + 12] = float2int8(p0[52] * scale); - pp[64 + 13] = float2int8(p0[53] * scale); - pp[64 + 14] = float2int8(p0[60] * scale); - pp[64 + 15] = float2int8(p0[61] * scale); - - pp[80 + 0] = float2int8(p0[64 + 4] * scale); - pp[80 + 1] = float2int8(p0[64 + 5] * scale); - pp[80 + 2] = float2int8(p0[64 + 12] * scale); - pp[80 + 3] = float2int8(p0[64 + 13] * scale); - pp[80 + 4] = float2int8(p0[64 + 20] * scale); - pp[80 + 5] = float2int8(p0[64 + 21] * scale); - pp[80 + 6] = float2int8(p0[64 + 28] * scale); - pp[80 + 7] = float2int8(p0[64 + 29] * scale); - pp[80 + 8] = float2int8(p0[64 + 36] * scale); - pp[80 + 9] = float2int8(p0[64 + 37] * scale); - pp[80 + 10] = float2int8(p0[64 + 44] * scale); - pp[80 + 11] = float2int8(p0[64 + 45] * scale); - pp[80 + 12] = float2int8(p0[64 + 52] * scale); - pp[80 + 13] = float2int8(p0[64 + 53] * scale); - pp[80 + 14] = float2int8(p0[64 + 60] * scale); - pp[80 + 15] = float2int8(p0[64 + 61] * scale); - - pp[96 + 0] = float2int8(p0[6] * scale); - pp[96 + 1] = float2int8(p0[7] * scale); - pp[96 + 2] = float2int8(p0[14] * scale); - pp[96 + 3] = float2int8(p0[15] * scale); - pp[96 + 4] = float2int8(p0[22] * scale); - pp[96 + 5] = float2int8(p0[23] * scale); - pp[96 + 6] = float2int8(p0[30] * scale); - pp[96 + 7] = float2int8(p0[31] * scale); - pp[96 + 8] = float2int8(p0[38] * scale); - pp[96 + 9] = float2int8(p0[39] * scale); - pp[96 + 10] = float2int8(p0[46] * scale); - pp[96 + 11] = float2int8(p0[47] * scale); - pp[96 + 12] = float2int8(p0[54] * scale); - pp[96 + 13] = float2int8(p0[55] * scale); - pp[96 + 14] = float2int8(p0[62] * scale); - pp[96 + 15] = float2int8(p0[63] * scale); - - pp[112 + 0] = float2int8(p0[64 + 6] * scale); - pp[112 + 1] = float2int8(p0[64 + 7] * scale); - pp[112 + 2] = float2int8(p0[64 + 14] * scale); - pp[112 + 3] = float2int8(p0[64 + 15] * scale); - pp[112 + 4] = float2int8(p0[64 + 22] * scale); - pp[112 + 5] = float2int8(p0[64 + 23] * scale); - pp[112 + 6] = float2int8(p0[64 + 30] * scale); - pp[112 + 7] = float2int8(p0[64 + 31] * scale); - pp[112 + 8] = float2int8(p0[64 + 38] * scale); - pp[112 + 9] = float2int8(p0[64 + 39] * scale); - pp[112 + 10] = float2int8(p0[64 + 46] * scale); - pp[112 + 11] = float2int8(p0[64 + 47] * scale); - pp[112 + 12] = float2int8(p0[64 + 54] * scale); - pp[112 + 13] = float2int8(p0[64 + 55] * scale); - pp[112 + 14] = float2int8(p0[64 + 62] * scale); - pp[112 + 15] = float2int8(p0[64 + 63] * scale); + _mm512_storeu_si512((__m512i*)pp, _ppa); + _mm512_storeu_si512((__m512i*)(pp + 64), _ppb); pp += 128; p0 += B_hstep * 8; @@ -9044,71 +8252,26 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int #if __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[1] * scale) + 127; - pp[2] = float2int8(p0[2] * scale) + 127; - pp[3] = float2int8(p0[3] * scale) + 127; - pp[4] = float2int8(p0[4] * scale) + 127; - pp[5] = float2int8(p0[5] * scale) + 127; - pp[6] = float2int8(p0[6] * scale) + 127; - pp[7] = float2int8(p0[7] * scale) + 127; - pp[8] = float2int8(p0[8] * scale) + 127; - pp[9] = float2int8(p0[9] * scale) + 127; - pp[10] = float2int8(p0[10] * scale) + 127; - pp[11] = float2int8(p0[11] * scale) + 127; - pp[12] = float2int8(p0[12] * scale) + 127; - pp[13] = float2int8(p0[13] * scale) + 127; - pp[14] = float2int8(p0[14] * scale) + 127; - pp[15] = float2int8(p0[15] * scale) + 127; - pp[16] = float2int8(p0[16] * scale) + 127; - pp[17] = float2int8(p0[17] * scale) + 127; - pp[18] = float2int8(p0[18] * scale) + 127; - pp[19] = float2int8(p0[19] * scale) + 127; - pp[20] = float2int8(p0[20] * scale) + 127; - pp[21] = float2int8(p0[21] * scale) + 127; - pp[22] = float2int8(p0[22] * scale) + 127; - pp[23] = float2int8(p0[23] * scale) + 127; - pp[24] = float2int8(p0[24] * scale) + 127; - pp[25] = float2int8(p0[25] * scale) + 127; - pp[26] = float2int8(p0[26] * scale) + 127; - pp[27] = float2int8(p0[27] * scale) + 127; - pp[28] = float2int8(p0[28] * scale) + 127; - pp[29] = float2int8(p0[29] * scale) + 127; - pp[30] = float2int8(p0[30] * scale) + 127; - pp[31] = float2int8(p0[31] * scale) + 127; + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp = _mm512_add_epi8(_pp, _v127); - pp[32 + 0] = float2int8(p0[32 + 0] * scale) + 127; - pp[32 + 1] = float2int8(p0[32 + 1] * scale) + 127; - pp[32 + 2] = float2int8(p0[32 + 2] * scale) + 127; - pp[32 + 3] = float2int8(p0[32 + 3] * scale) + 127; - pp[32 + 4] = float2int8(p0[32 + 4] * scale) + 127; - pp[32 + 5] = float2int8(p0[32 + 5] * scale) + 127; - pp[32 + 6] = float2int8(p0[32 + 6] * scale) + 127; - pp[32 + 7] = float2int8(p0[32 + 7] * scale) + 127; - pp[32 + 8] = float2int8(p0[32 + 8] * scale) + 127; - pp[32 + 9] = float2int8(p0[32 + 9] * scale) + 127; - pp[32 + 10] = float2int8(p0[32 + 10] * scale) + 127; - pp[32 + 11] = float2int8(p0[32 + 11] * scale) + 127; - pp[32 + 12] = float2int8(p0[32 + 12] * scale) + 127; - pp[32 + 13] = float2int8(p0[32 + 13] * scale) + 127; - pp[32 + 14] = float2int8(p0[32 + 14] * scale) + 127; - pp[32 + 15] = float2int8(p0[32 + 15] * scale) + 127; - pp[32 + 16] = float2int8(p0[32 + 16] * scale) + 127; - pp[32 + 17] = float2int8(p0[32 + 17] * scale) + 127; - pp[32 + 18] = float2int8(p0[32 + 18] * scale) + 127; - pp[32 + 19] = float2int8(p0[32 + 19] * scale) + 127; - pp[32 + 20] = float2int8(p0[32 + 20] * scale) + 127; - pp[32 + 21] = float2int8(p0[32 + 21] * scale) + 127; - pp[32 + 22] = float2int8(p0[32 + 22] * scale) + 127; - pp[32 + 23] = float2int8(p0[32 + 23] * scale) + 127; - pp[32 + 24] = float2int8(p0[32 + 24] * scale) + 127; - pp[32 + 25] = float2int8(p0[32 + 25] * scale) + 127; - pp[32 + 26] = float2int8(p0[32 + 26] * scale) + 127; - pp[32 + 27] = float2int8(p0[32 + 27] * scale) + 127; - pp[32 + 28] = float2int8(p0[32 + 28] * scale) + 127; - pp[32 + 29] = float2int8(p0[32 + 29] * scale) + 127; - pp[32 + 30] = float2int8(p0[32 + 30] * scale) + 127; - pp[32 + 31] = float2int8(p0[32 + 31] * scale) + 127; + _mm512_storeu_si512((__m512i*)pp, _pp); pp += 64; p0 += B_hstep * 4; @@ -9116,73 +8279,33 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int #else // __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[4] * scale); - pp[3] = float2int8(p0[5] * scale); - pp[4] = float2int8(p0[8] * scale); - pp[5] = float2int8(p0[9] * scale); - pp[6] = float2int8(p0[12] * scale); - pp[7] = float2int8(p0[13] * scale); - pp[8] = float2int8(p0[16] * scale); - pp[9] = float2int8(p0[17] * scale); - pp[10] = float2int8(p0[20] * scale); - pp[11] = float2int8(p0[21] * scale); - pp[12] = float2int8(p0[24] * scale); - pp[13] = float2int8(p0[25] * scale); - pp[14] = float2int8(p0[28] * scale); - pp[15] = float2int8(p0[29] * scale); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m256i _pp02 = combine4x2_epi32(_pp0, _pp2); + __m256i _pp13 = combine4x2_epi32(_pp1, _pp3); - pp[16 + 0] = float2int8(p0[32 + 0] * scale); - pp[16 + 1] = float2int8(p0[32 + 1] * scale); - pp[16 + 2] = float2int8(p0[32 + 4] * scale); - pp[16 + 3] = float2int8(p0[32 + 5] * scale); - pp[16 + 4] = float2int8(p0[32 + 8] * scale); - pp[16 + 5] = float2int8(p0[32 + 9] * scale); - pp[16 + 6] = float2int8(p0[32 + 12] * scale); - pp[16 + 7] = float2int8(p0[32 + 13] * scale); - pp[16 + 8] = float2int8(p0[32 + 16] * scale); - pp[16 + 9] = float2int8(p0[32 + 17] * scale); - pp[16 + 10] = float2int8(p0[32 + 20] * scale); - pp[16 + 11] = float2int8(p0[32 + 21] * scale); - pp[16 + 12] = float2int8(p0[32 + 24] * scale); - pp[16 + 13] = float2int8(p0[32 + 25] * scale); - pp[16 + 14] = float2int8(p0[32 + 28] * scale); - pp[16 + 15] = float2int8(p0[32 + 29] * scale); - - pp[32 + 0] = float2int8(p0[2] * scale); - pp[32 + 1] = float2int8(p0[3] * scale); - pp[32 + 2] = float2int8(p0[6] * scale); - pp[32 + 3] = float2int8(p0[7] * scale); - pp[32 + 4] = float2int8(p0[10] * scale); - pp[32 + 5] = float2int8(p0[11] * scale); - pp[32 + 6] = float2int8(p0[14] * scale); - pp[32 + 7] = float2int8(p0[15] * scale); - pp[32 + 8] = float2int8(p0[18] * scale); - pp[32 + 9] = float2int8(p0[19] * scale); - pp[32 + 10] = float2int8(p0[22] * scale); - pp[32 + 11] = float2int8(p0[23] * scale); - pp[32 + 12] = float2int8(p0[26] * scale); - pp[32 + 13] = float2int8(p0[27] * scale); - pp[32 + 14] = float2int8(p0[30] * scale); - pp[32 + 15] = float2int8(p0[31] * scale); - - pp[48 + 0] = float2int8(p0[32 + 2] * scale); - pp[48 + 1] = float2int8(p0[32 + 3] * scale); - pp[48 + 2] = float2int8(p0[32 + 6] * scale); - pp[48 + 3] = float2int8(p0[32 + 7] * scale); - pp[48 + 4] = float2int8(p0[32 + 10] * scale); - pp[48 + 5] = float2int8(p0[32 + 11] * scale); - pp[48 + 6] = float2int8(p0[32 + 14] * scale); - pp[48 + 7] = float2int8(p0[32 + 15] * scale); - pp[48 + 8] = float2int8(p0[32 + 18] * scale); - pp[48 + 9] = float2int8(p0[32 + 19] * scale); - pp[48 + 10] = float2int8(p0[32 + 22] * scale); - pp[48 + 11] = float2int8(p0[32 + 23] * scale); - pp[48 + 12] = float2int8(p0[32 + 26] * scale); - pp[48 + 13] = float2int8(p0[32 + 27] * scale); - pp[48 + 14] = float2int8(p0[32 + 30] * scale); - pp[48 + 15] = float2int8(p0[32 + 31] * scale); + __m256i _t0 = _mm256_unpacklo_epi16(_pp02, _pp13); + __m256i _t1 = _mm256_unpackhi_epi16(_pp02, _pp13); + __m256i _t2 = _mm256_unpacklo_epi16(_t0, _t1); + __m256i _t3 = _mm256_unpackhi_epi16(_t0, _t1); + _t0 = _mm256_unpacklo_epi16(_t2, _t3); + _t1 = _mm256_unpackhi_epi16(_t2, _t3); + + _mm256_storeu_si256((__m256i*)pp, _t0); + _mm256_storeu_si256((__m256i*)(pp + 32), _t1); pp += 64; p0 += B_hstep * 4; @@ -9195,131 +8318,64 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int #if __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[B_hstep] * scale) + 127; - pp[2] = float2int8(p0[B_hstep * 2] * scale) + 127; - pp[3] = float2int8(p0[B_hstep * 3] * scale) + 127; - pp[4] = float2int8(p0[1] * scale) + 127; - pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127; - pp[6] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127; - pp[7] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127; - pp[8] = float2int8(p0[2] * scale) + 127; - pp[9] = float2int8(p0[B_hstep + 2] * scale) + 127; - pp[10] = float2int8(p0[B_hstep * 2 + 2] * scale) + 127; - pp[11] = float2int8(p0[B_hstep * 3 + 2] * scale) + 127; - pp[12] = float2int8(p0[3] * scale) + 127; - pp[13] = float2int8(p0[B_hstep + 3] * scale) + 127; - pp[14] = float2int8(p0[B_hstep * 2 + 3] * scale) + 127; - pp[15] = float2int8(p0[B_hstep * 3 + 3] * scale) + 127; - pp[16] = float2int8(p0[4] * scale) + 127; - pp[17] = float2int8(p0[B_hstep + 4] * scale) + 127; - pp[18] = float2int8(p0[B_hstep * 2 + 4] * scale) + 127; - pp[19] = float2int8(p0[B_hstep * 3 + 4] * scale) + 127; - pp[20] = float2int8(p0[5] * scale) + 127; - pp[21] = float2int8(p0[B_hstep + 5] * scale) + 127; - pp[22] = float2int8(p0[B_hstep * 2 + 5] * scale) + 127; - pp[23] = float2int8(p0[B_hstep * 3 + 5] * scale) + 127; - pp[24] = float2int8(p0[6] * scale) + 127; - pp[25] = float2int8(p0[B_hstep + 6] * scale) + 127; - pp[26] = float2int8(p0[B_hstep * 2 + 6] * scale) + 127; - pp[27] = float2int8(p0[B_hstep * 3 + 6] * scale) + 127; - pp[28] = float2int8(p0[7] * scale) + 127; - pp[29] = float2int8(p0[B_hstep + 7] * scale) + 127; - pp[30] = float2int8(p0[B_hstep * 2 + 7] * scale) + 127; - pp[31] = float2int8(p0[B_hstep * 3 + 7] * scale) + 127; + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + B_hstep); + __m512 _p2 = _mm512_loadu_ps(p0 + B_hstep * 2); + __m512 _p3 = _mm512_loadu_ps(p0 + B_hstep * 3); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp = _mm512_add_epi8(_pp, _v127); + + _mm512_storeu_si512((__m512i*)pp, _pp); - pp[32 + 0] = float2int8(p0[8] * scale) + 127; - pp[32 + 1] = float2int8(p0[B_hstep + 8] * scale) + 127; - pp[32 + 2] = float2int8(p0[B_hstep * 2 + 8] * scale) + 127; - pp[32 + 3] = float2int8(p0[B_hstep * 3 + 8] * scale) + 127; - pp[32 + 4] = float2int8(p0[9] * scale) + 127; - pp[32 + 5] = float2int8(p0[B_hstep + 9] * scale) + 127; - pp[32 + 6] = float2int8(p0[B_hstep * 2 + 9] * scale) + 127; - pp[32 + 7] = float2int8(p0[B_hstep * 3 + 9] * scale) + 127; - pp[32 + 8] = float2int8(p0[10] * scale) + 127; - pp[32 + 9] = float2int8(p0[B_hstep + 10] * scale) + 127; - pp[32 + 10] = float2int8(p0[B_hstep * 2 + 10] * scale) + 127; - pp[32 + 11] = float2int8(p0[B_hstep * 3 + 10] * scale) + 127; - pp[32 + 12] = float2int8(p0[11] * scale) + 127; - pp[32 + 13] = float2int8(p0[B_hstep + 11] * scale) + 127; - pp[32 + 14] = float2int8(p0[B_hstep * 2 + 11] * scale) + 127; - pp[32 + 15] = float2int8(p0[B_hstep * 3 + 11] * scale) + 127; - pp[32 + 16] = float2int8(p0[12] * scale) + 127; - pp[32 + 17] = float2int8(p0[B_hstep + 12] * scale) + 127; - pp[32 + 18] = float2int8(p0[B_hstep * 2 + 12] * scale) + 127; - pp[32 + 19] = float2int8(p0[B_hstep * 3 + 12] * scale) + 127; - pp[32 + 20] = float2int8(p0[13] * scale) + 127; - pp[32 + 21] = float2int8(p0[B_hstep + 13] * scale) + 127; - pp[32 + 22] = float2int8(p0[B_hstep * 2 + 13] * scale) + 127; - pp[32 + 23] = float2int8(p0[B_hstep * 3 + 13] * scale) + 127; - pp[32 + 24] = float2int8(p0[14] * scale) + 127; - pp[32 + 25] = float2int8(p0[B_hstep + 14] * scale) + 127; - pp[32 + 26] = float2int8(p0[B_hstep * 2 + 14] * scale) + 127; - pp[32 + 27] = float2int8(p0[B_hstep * 3 + 14] * scale) + 127; - pp[32 + 28] = float2int8(p0[15] * scale) + 127; - pp[32 + 29] = float2int8(p0[B_hstep + 15] * scale) + 127; - pp[32 + 30] = float2int8(p0[B_hstep * 2 + 15] * scale) + 127; - pp[32 + 31] = float2int8(p0[B_hstep * 3 + 15] * scale) + 127; pp += 64; p0 += B_hstep * 4; } #endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[B_hstep] * scale); - pp[2] = float2int8(p0[1] * scale); - pp[3] = float2int8(p0[B_hstep + 1] * scale); - pp[4] = float2int8(p0[2] * scale); - pp[5] = float2int8(p0[B_hstep + 2] * scale); - pp[6] = float2int8(p0[3] * scale); - pp[7] = float2int8(p0[B_hstep + 3] * scale); - pp[8] = float2int8(p0[4] * scale); - pp[9] = float2int8(p0[B_hstep + 4] * scale); - pp[10] = float2int8(p0[5] * scale); - pp[11] = float2int8(p0[B_hstep + 5] * scale); - pp[12] = float2int8(p0[6] * scale); - pp[13] = float2int8(p0[B_hstep + 6] * scale); - pp[14] = float2int8(p0[7] * scale); - pp[15] = float2int8(p0[B_hstep + 7] * scale); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + B_hstep); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + // transpose16x2_epi8 + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 16), _t1); - pp[16 + 0] = float2int8(p0[8] * scale); - pp[16 + 1] = float2int8(p0[B_hstep + 8] * scale); - pp[16 + 2] = float2int8(p0[9] * scale); - pp[16 + 3] = float2int8(p0[B_hstep + 9] * scale); - pp[16 + 4] = float2int8(p0[10] * scale); - pp[16 + 5] = float2int8(p0[B_hstep + 10] * scale); - pp[16 + 6] = float2int8(p0[11] * scale); - pp[16 + 7] = float2int8(p0[B_hstep + 11] * scale); - pp[16 + 8] = float2int8(p0[12] * scale); - pp[16 + 9] = float2int8(p0[B_hstep + 12] * scale); - pp[16 + 10] = float2int8(p0[13] * scale); - pp[16 + 11] = float2int8(p0[B_hstep + 13] * scale); - pp[16 + 12] = float2int8(p0[14] * scale); - pp[16 + 13] = float2int8(p0[B_hstep + 14] * scale); - pp[16 + 14] = float2int8(p0[15] * scale); - pp[16 + 15] = float2int8(p0[B_hstep + 15] * scale); pp += 32; p0 += B_hstep * 2; } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[2] * scale); - pp[3] = float2int8(p0[3] * scale); - pp[4] = float2int8(p0[4] * scale); - pp[5] = float2int8(p0[5] * scale); - pp[6] = float2int8(p0[6] * scale); - pp[7] = float2int8(p0[7] * scale); - pp[8] = float2int8(p0[8] * scale); - pp[9] = float2int8(p0[9] * scale); - pp[10] = float2int8(p0[10] * scale); - pp[11] = float2int8(p0[11] * scale); - pp[12] = float2int8(p0[12] * scale); - pp[13] = float2int8(p0[13] * scale); - pp[14] = float2int8(p0[14] * scale); - pp[15] = float2int8(p0[15] * scale); + __m512 _p = _mm512_loadu_ps(p0); + + _p = _mm512_mul_ps(_p, _scale); + + __m128i _pp = float2int8_avx512(_p); + + _mm_storeu_si128((__m128i*)pp, _pp); + pp += 16; p0 += B_hstep; }