From c0f1a0989538aa8ee3bef1b874c9342595eabc02 Mon Sep 17 00:00:00 2001 From: Daniel Lemire Date: Tue, 28 May 2024 00:13:27 -0400 Subject: [PATCH] fix: correct performance problem with arm function, it was due to Vector128.shuffle (DO NOT USE) --- README.md | 1 + benchmark/Benchmark.cs | 1 - src/UTF8.cs | 15 ++++++--------- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 6915db7..dba977e 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,7 @@ You can print the content of a vector register like so: ## Performance tips - Be careful: `Vector128.Shuffle` is not the same as `Ssse3.Shuffle` nor is `Vector128.Shuffle` the same as `Avx2.Shuffle`. Prefer the latter. +- Similarly `Vector128.Shuffle` is not the same as `AdvSimd.Arm64.VectorTableLookup`, use the latter. ## More reading diff --git a/benchmark/Benchmark.cs b/benchmark/Benchmark.cs index 9a56903..2861b3a 100644 --- a/benchmark/Benchmark.cs +++ b/benchmark/Benchmark.cs @@ -210,7 +210,6 @@ public unsafe void Utf8ValidationRealDataScalar() } } - [Benchmark] [BenchmarkCategory("arm64")] public unsafe void SIMDUtf8ValidationRealDataArm64() diff --git a/src/UTF8.cs b/src/UTF8.cs index 608d319..b957b95 100644 --- a/src/UTF8.cs +++ b/src/UTF8.cs @@ -790,7 +790,6 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust( int asciibytes = 0; // number of ascii bytes in the block (could also be called n1) int contbytes = 0; // number of continuation bytes in the block int n4 = 0; // number of 4-byte sequences that start in this block - for (; processedLength + 16 <= inputLength; processedLength += 16) { @@ -817,9 +816,10 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust( { // Contains non-ASCII characters, we need to do non-trivial processing Vector128 prev1 = AdvSimd.ExtractVector128(prevInputBlock, currentBlock, (byte)(16 - 1)); - Vector128 byte_1_high = Vector128.Shuffle(shuf1, AdvSimd.ShiftRightLogical(prev1.AsUInt16(), 4).AsByte() & v0f); - Vector128 byte_1_low = Vector128.Shuffle(shuf2, (prev1 & v0f)); - Vector128 byte_2_high = Vector128.Shuffle(shuf3, AdvSimd.ShiftRightLogical(currentBlock.AsUInt16(), 4).AsByte() & v0f); + // Vector128.Shuffle vs AdvSimd.Arm64.VectorTableLookup: prefer the latter!!! + Vector128 byte_1_high = AdvSimd.Arm64.VectorTableLookup(shuf1, AdvSimd.ShiftRightLogical(prev1.AsUInt16(), 4).AsByte() & v0f); + Vector128 byte_1_low = AdvSimd.Arm64.VectorTableLookup (shuf2, (prev1 & v0f)); + Vector128 byte_2_high = AdvSimd.Arm64.VectorTableLookup (shuf3, AdvSimd.ShiftRightLogical(currentBlock.AsUInt16(), 4).AsByte() & v0f); Vector128 sc = AdvSimd.And(AdvSimd.And(byte_1_high, byte_1_low), byte_2_high); Vector128 prev2 = AdvSimd.ExtractVector128(prevInputBlock, currentBlock, (byte)(16 - 2)); Vector128 prev3 = AdvSimd.ExtractVector128(prevInputBlock, currentBlock, (byte)(16 - 3)); @@ -849,13 +849,11 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust( } prevIncomplete = AdvSimd.SubtractSaturate(currentBlock, maxValue); Vector128 largestcont = Vector128.Create((sbyte)-65); // -65 => 0b10111111 - contbytes += 16 - AdvSimd.Arm64.AddAcross(AdvSimd.CompareGreaterThan(Vector128.AsSByte(currentBlock), largestcont)).ToScalar(); + contbytes += -AdvSimd.Arm64.AddAcross(AdvSimd.CompareLessThanOrEqual(Vector128.AsSByte(currentBlock), largestcont)).ToScalar(); Vector128 fourthByteMinusOne = Vector128.Create((byte)(0b11110000u - 1)); n4 += (int)(AdvSimd.Arm64.AddAcross(AdvSimd.SubtractSaturate(currentBlock, fourthByteMinusOne)).ToScalar()); } - - asciibytes -= (int)AdvSimd.Arm64.AddAcross(AdvSimd.CompareGreaterThanOrEqual(currentBlock, v80)).ToScalar(); - + asciibytes -= (sbyte)AdvSimd.Arm64.AddAcross(AdvSimd.CompareLessThan(currentBlock, v80)).ToScalar(); } int totalbyte = processedLength - start_point; @@ -886,7 +884,6 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust( } utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment + TailUtf16CodeUnitCountAdjustment; scalarCountAdjustment = TempScalarCountAdjustment + TailScalarCodeUnitCountAdjustment; - return pInputBuffer + inputLength; } public unsafe static byte* GetPointerToFirstInvalidByte(byte* pInputBuffer, int inputLength, out int Utf16CodeUnitCountAdjustment, out int ScalarCodeUnitCountAdjustment)