Skip to content

Commit

Permalink
Upgrade the vectorized Atbash cipher to use 256-bit vectors
Browse files: browse the repository at this point in the history
  • Loading branch information
Peter-Juhasz committed Dec 27, 2024
1 parent bc55993 commit b75ea71
Showing 1 changed file with 38 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
using System;
using System.Composition;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

using TVector = System.Runtime.Intrinsics.Vector128<short>;
using TVector = System.Runtime.Intrinsics.Vector256<short>;

namespace Science.Cryptography.Ciphers.Specialized;

Expand All @@ -18,12 +19,12 @@ public class AsciiAtbashCipher : ReciprocalCipher
private const int APlusZ = (int)('A' + 'Z');
private const int LowercaseAPlusZ = (int)('a' + 'z');

private static readonly TVector VectorOfAPlusZ = Vector128.Create((short)('A' + 'Z'));
private static readonly TVector VectorOfLowercaseAPlusZ = Vector128.Create((short)('a' + 'z'));
private static readonly TVector VectorOfAMinus1 = Vector128.Create((short)('A' - 1));
private static readonly TVector VectorOfZPlus1 = Vector128.Create((short)('Z' + 1));
private static readonly TVector VectorOfLowercaseAMinus1 = Vector128.Create((short)('a' - 1));
private static readonly TVector VectorOfLowercaseZPlus1 = Vector128.Create((short)('z' + 1));
private static readonly TVector VectorOfAPlusZ = Vector256.Create((short)('A' + 'Z'));
private static readonly TVector VectorOfLowercaseAPlusZ = Vector256.Create((short)('a' + 'z'));
private static readonly TVector VectorOfAMinus1 = Vector256.Create((short)('A' - 1));
private static readonly TVector VectorOfZPlus1 = Vector256.Create((short)('Z' + 1));
private static readonly TVector VectorOfLowercaseAMinus1 = Vector256.Create((short)('a' - 1));
private static readonly TVector VectorOfLowercaseZPlus1 = Vector256.Create((short)('z' + 1));

protected override void Crypt(ReadOnlySpan<char> text, Span<char> result, out int written)
{
Expand All @@ -32,34 +33,32 @@ protected override void Crypt(ReadOnlySpan<char> text, Span<char> result, out in
throw new ArgumentException("Size of output buffer is insufficient.", nameof(result));
}

if (Avx2.IsSupported)
var totalVectorizedLength = 0;

// process vectorized
if (Avx2.IsSupported && Vector256.IsHardwareAccelerated)
{
// process the vectorized input
var vectorCount = text.Length / TVector.Count;
var totalVectorizedLength = vectorCount * TVector.Count;
var vectorizedText = MemoryMarshal.Cast<char, short>(text);
var vectorizedResult = MemoryMarshal.Cast<char, short>(result);
totalVectorizedLength = vectorCount * TVector.Count;
var inputAsShort = MemoryMarshal.Cast<char, short>(text);
var outputAsShort = MemoryMarshal.Cast<char, short>(result);
for (int offset = 0; offset < totalVectorizedLength; offset += TVector.Count)
{
var input = Vector128.LoadUnsafe(ref MemoryMarshal.GetReference(vectorizedText[offset..]));
var output = CryptBlockAvx2(input);
output.StoreUnsafe(ref MemoryMarshal.GetReference(vectorizedResult[offset..]));
}

// process the remaining input
if (totalVectorizedLength < text.Length)
{
var remainingInput = text[totalVectorizedLength..];
var remainingOutput = result[totalVectorizedLength..];
CryptSlow(remainingInput, remainingOutput);
var inputBlock = Vector256.LoadUnsafe(ref MemoryMarshal.GetReference(inputAsShort[offset..]));
var outputBlock = CryptBlockAvx2(inputBlock);
outputBlock.StoreUnsafe(ref MemoryMarshal.GetReference(outputAsShort[offset..]));
}
}
else
{
CryptSlow(text, result);
}

written = text.Length;
// process the remaining input
if (totalVectorizedLength < text.Length)
{
var remainingInput = text[totalVectorizedLength..];
var remainingOutput = result[totalVectorizedLength..];
CryptSlow(remainingInput, remainingOutput);
}

written = text.Length;
}

internal static void CryptSlow(ReadOnlySpan<char> text, Span<char> result)
Expand All @@ -78,39 +77,23 @@ internal static void CryptSlow(ReadOnlySpan<char> text, Span<char> result)
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static TVector CryptBlockAvx2(TVector input)
private static TVector CryptBlockAvx2(TVector input)
{
// uppercase
var isUpperMask = Avx2.And(
Avx2.CompareGreaterThan(input, VectorOfAMinus1),
Avx2.CompareLessThan(input, VectorOfZPlus1)
var isUpperMask = Avx2.AndNot(
Avx2.CompareGreaterThan(input, VectorOfZPlus1),
Avx2.CompareGreaterThan(input, VectorOfAMinus1)
);
TVector transformedUppercase;
if (isUpperMask == TVector.Zero)
{
transformedUppercase = input;
}
else
{
transformedUppercase = Avx2.Subtract(VectorOfAPlusZ, input);
transformedUppercase = Avx2.BlendVariable(input, transformedUppercase, isUpperMask);
}
var transformedUppercase = Avx2.Subtract(VectorOfAPlusZ, input);
transformedUppercase = Avx2.BlendVariable(input, transformedUppercase, isUpperMask);

// lowercase
var isLowerMask = Avx2.And(
Avx2.CompareGreaterThan(input, VectorOfLowercaseAMinus1),
Avx2.CompareLessThan(input, VectorOfLowercaseZPlus1)
var isLowerMask = Avx2.AndNot(
Avx2.CompareGreaterThan(input, VectorOfLowercaseZPlus1),
Avx2.CompareGreaterThan(input, VectorOfLowercaseAMinus1)
);
TVector transformedLowercase;
if (isLowerMask == TVector.Zero)
{
transformedLowercase = input;
}
else
{
transformedLowercase = Avx2.Subtract(VectorOfLowercaseAPlusZ, input);
transformedLowercase = Avx2.BlendVariable(transformedLowercase, transformedLowercase, isLowerMask);
}
var transformedLowercase = Avx2.Subtract(VectorOfLowercaseAPlusZ, input);
transformedLowercase = Avx2.BlendVariable(transformedLowercase, transformedLowercase, isLowerMask);

// merge
var transformed = Avx2.BlendVariable(transformedUppercase, transformedLowercase, isLowerMask);
Expand Down

0 comments on commit b75ea71

Please sign in to comment.