Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SVE support for Graviton4 #364

Merged
merged 1 commit into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 13 additions & 36 deletions arch/arm64-sve/rpo/library.c
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#include <stddef.h>
#include <arm_sve.h>
#include "library.h"
#include "rpo_hash.h"
#include "rpo_hash_128bit.h"
#include "rpo_hash_256bit.h"

// The STATE_WIDTH of RPO hash is 12x u64 elements.
// The current generation of SVE-enabled processors - Neoverse V1
Expand Down Expand Up @@ -31,48 +32,24 @@

bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector

if (vl != 4) {

if (vl == 2) {
return add_constants_and_apply_sbox_128(state, constants);
} else if (vl == 4) {
return add_constants_and_apply_sbox_256(state, constants);
} else {
return false;
}

svbool_t ptrue = svptrue_b64();

svuint64_t state1 = svld1(ptrue, state + 0*vl);
svuint64_t state2 = svld1(ptrue, state + 1*vl);

svuint64_t const1 = svld1(ptrue, constants + 0*vl);
svuint64_t const2 = svld1(ptrue, constants + 1*vl);

add_constants(ptrue, &state1, &const1, &state2, &const2, state+8, constants+8);
apply_sbox(ptrue, &state1, &state2, state+8);

svst1(ptrue, state + 0*vl, state1);
svst1(ptrue, state + 1*vl, state2);

return true;
}

bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector

if (vl != 4) {
if (vl == 2) {
return add_constants_and_apply_inv_sbox_128(state, constants);
} else if (vl == 4) {
return add_constants_and_apply_inv_sbox_256(state, constants);
} else {
return false;
}

svbool_t ptrue = svptrue_b64();

svuint64_t state1 = svld1(ptrue, state + 0 * vl);
svuint64_t state2 = svld1(ptrue, state + 1 * vl);

svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
svuint64_t const2 = svld1(ptrue, constants + 1 * vl);

add_constants(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8);
apply_inv_sbox(ptrue, &state1, &state2, state + 8);

svst1(ptrue, state + 0 * vl, state1);
svst1(ptrue, state + 1 * vl, state2);

return true;
}
318 changes: 318 additions & 0 deletions arch/arm64-sve/rpo/rpo_hash_128bit.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
#ifndef RPO_SVE_RPO_HASH_128_H
#define RPO_SVE_RPO_HASH_128_H

#include <arm_sve.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define STATE_WIDTH 12

#define COPY_128(NAME, VIN1, VIN2, VIN3, VIN4, SIN) \
svuint64_t NAME ## _1 = VIN1; \
svuint64_t NAME ## _2 = VIN2; \
svuint64_t NAME ## _3 = VIN3; \
svuint64_t NAME ## _4 = VIN4; \
uint64_t NAME ## _tail[4]; \
memcpy(NAME ## _tail, SIN, 4 * sizeof(uint64_t))

#define MULTIPLY_128(PRED, DEST, OP) \
mul_128(PRED, &DEST ## _1, &OP ## _1, &DEST ## _2, &OP ## _2, &DEST ## _3, &OP ## _3, &DEST ## _4, &OP ## _4, DEST ## _tail, OP ## _tail)

#define SQUARE_128(PRED, NAME) \
sq_128(PRED, &NAME ## _1, &NAME ## _2, &NAME ## _3, &NAME ## _4, NAME ## _tail)

#define SQUARE_DEST_128(PRED, DEST, SRC) \
COPY_128(DEST, SRC ## _1, SRC ## _2, SRC ## _3, SRC ## _4, SRC ## _tail); \
SQUARE_128(PRED, DEST);

#define POW_ACC_128(PRED, NAME, CNT, TAIL) \
for (size_t i = 0; i < CNT; i++) { \
SQUARE_128(PRED, NAME); \
} \
MULTIPLY_128(PRED, NAME, TAIL);

#define POW_ACC_DEST(PRED, DEST, CNT, HEAD, TAIL) \
COPY_128(DEST, HEAD ## _1, HEAD ## _2, HEAD ## _3, HEAD ## _4, HEAD ## _tail); \
POW_ACC_128(PRED, DEST, CNT, TAIL)

extern inline void add_constants_128(
svbool_t pg,
svuint64_t *state1,
svuint64_t *const1,
svuint64_t *state2,
svuint64_t *const2,
svuint64_t *state3,
svuint64_t *const3,
svuint64_t *state4,
svuint64_t *const4,

uint64_t *state_tail,
uint64_t *const_tail
) {
uint64_t Ms = 0xFFFFFFFF00000001ull;
svuint64_t Mv = svindex_u64(Ms, 0);

uint64_t p_1 = Ms - const_tail[0];
uint64_t p_2 = Ms - const_tail[1];
uint64_t p_3 = Ms - const_tail[2];
uint64_t p_4 = Ms - const_tail[3];

uint64_t x_1, x_2, x_3, x_4;
uint32_t adj_1 = -__builtin_sub_overflow(state_tail[0], p_1, &x_1);
uint32_t adj_2 = -__builtin_sub_overflow(state_tail[1], p_2, &x_2);
uint32_t adj_3 = -__builtin_sub_overflow(state_tail[2], p_3, &x_3);
uint32_t adj_4 = -__builtin_sub_overflow(state_tail[3], p_4, &x_4);

state_tail[0] = x_1 - (uint64_t)adj_1;
state_tail[1] = x_2 - (uint64_t)adj_2;
state_tail[2] = x_3 - (uint64_t)adj_3;
state_tail[3] = x_4 - (uint64_t)adj_4;

svuint64_t p1 = svsub_x(pg, Mv, *const1);
svuint64_t p2 = svsub_x(pg, Mv, *const2);
svuint64_t p3 = svsub_x(pg, Mv, *const3);
svuint64_t p4 = svsub_x(pg, Mv, *const4);

svuint64_t x1 = svsub_x(pg, *state1, p1);
svuint64_t x2 = svsub_x(pg, *state2, p2);
svuint64_t x3 = svsub_x(pg, *state3, p3);
svuint64_t x4 = svsub_x(pg, *state4, p4);

svbool_t pt1 = svcmplt_u64(pg, *state1, p1);
svbool_t pt2 = svcmplt_u64(pg, *state2, p2);
svbool_t pt3 = svcmplt_u64(pg, *state3, p3);
svbool_t pt4 = svcmplt_u64(pg, *state4, p4);

*state1 = svsub_m(pt1, x1, (uint32_t)-1);
*state2 = svsub_m(pt2, x2, (uint32_t)-1);
*state3 = svsub_m(pt3, x3, (uint32_t)-1);
*state4 = svsub_m(pt4, x4, (uint32_t)-1);
}

extern inline void mul_128(
svbool_t pg,
svuint64_t *r1,
const svuint64_t *op1,
svuint64_t *r2,
const svuint64_t *op2,
svuint64_t *r3,
const svuint64_t *op3,
svuint64_t *r4,
const svuint64_t *op4,
uint64_t *r_tail,
const uint64_t *op_tail
) {
__uint128_t x_1 = r_tail[0];
__uint128_t x_2 = r_tail[1];
__uint128_t x_3 = r_tail[2];
__uint128_t x_4 = r_tail[3];

x_1 *= (__uint128_t) op_tail[0];
x_2 *= (__uint128_t) op_tail[1];
x_3 *= (__uint128_t) op_tail[2];
x_4 *= (__uint128_t) op_tail[3];

uint64_t x0_1 = x_1;
uint64_t x0_2 = x_2;
uint64_t x0_3 = x_3;
uint64_t x0_4 = x_4;

svuint64_t l1 = svmul_x(pg, *r1, *op1);
svuint64_t l2 = svmul_x(pg, *r2, *op2);
svuint64_t l3 = svmul_x(pg, *r3, *op3);
svuint64_t l4 = svmul_x(pg, *r4, *op4);

uint64_t x1_1 = (x_1 >> 64);
uint64_t x1_2 = (x_2 >> 64);
uint64_t x1_3 = (x_3 >> 64);
uint64_t x1_4 = (x_4 >> 64);

uint64_t a_1, a_2, a_3, a_4;
uint64_t e_1 = __builtin_add_overflow(x0_1, (x0_1 << 32), &a_1);
uint64_t e_2 = __builtin_add_overflow(x0_2, (x0_2 << 32), &a_2);
uint64_t e_3 = __builtin_add_overflow(x0_3, (x0_3 << 32), &a_3);
uint64_t e_4 = __builtin_add_overflow(x0_4, (x0_4 << 32), &a_4);

svuint64_t ls1 = svlsl_x(pg, l1, 32);
svuint64_t ls2 = svlsl_x(pg, l2, 32);
svuint64_t ls3 = svlsl_x(pg, l3, 32);
svuint64_t ls4 = svlsl_x(pg, l4, 32);

svuint64_t a1 = svadd_x(pg, l1, ls1);
svuint64_t a2 = svadd_x(pg, l2, ls2);
svuint64_t a3 = svadd_x(pg, l3, ls3);
svuint64_t a4 = svadd_x(pg, l4, ls4);

svbool_t e1 = svcmplt(pg, a1, l1);
svbool_t e2 = svcmplt(pg, a2, l2);
svbool_t e3 = svcmplt(pg, a3, l3);
svbool_t e4 = svcmplt(pg, a4, l4);

svuint64_t as1 = svlsr_x(pg, a1, 32);
svuint64_t as2 = svlsr_x(pg, a2, 32);
svuint64_t as3 = svlsr_x(pg, a3, 32);
svuint64_t as4 = svlsr_x(pg, a4, 32);

svuint64_t b1 = svsub_x(pg, a1, as1);
svuint64_t b2 = svsub_x(pg, a2, as2);
svuint64_t b3 = svsub_x(pg, a3, as3);
svuint64_t b4 = svsub_x(pg, a4, as4);

b1 = svsub_m(e1, b1, 1);
b2 = svsub_m(e2, b2, 1);
b3 = svsub_m(e3, b3, 1);
b4 = svsub_m(e4, b4, 1);

uint64_t b_1 = a_1 - (a_1 >> 32) - e_1;
uint64_t b_2 = a_2 - (a_2 >> 32) - e_2;
uint64_t b_3 = a_3 - (a_3 >> 32) - e_3;
uint64_t b_4 = a_4 - (a_4 >> 32) - e_4;

uint64_t r_1, r_2, r_3, r_4;
uint32_t c_1 = __builtin_sub_overflow(x1_1, b_1, &r_1);
uint32_t c_2 = __builtin_sub_overflow(x1_2, b_2, &r_2);
uint32_t c_3 = __builtin_sub_overflow(x1_3, b_3, &r_3);
uint32_t c_4 = __builtin_sub_overflow(x1_4, b_4, &r_4);

svuint64_t h1 = svmulh_x(pg, *r1, *op1);
svuint64_t h2 = svmulh_x(pg, *r2, *op2);
svuint64_t h3 = svmulh_x(pg, *r3, *op3);
svuint64_t h4 = svmulh_x(pg, *r4, *op4);

svuint64_t tr1 = svsub_x(pg, h1, b1);
svuint64_t tr2 = svsub_x(pg, h2, b2);
svuint64_t tr3 = svsub_x(pg, h3, b3);
svuint64_t tr4 = svsub_x(pg, h4, b4);

svbool_t c1 = svcmplt_u64(pg, h1, b1);
svbool_t c2 = svcmplt_u64(pg, h2, b2);
svbool_t c3 = svcmplt_u64(pg, h3, b3);
svbool_t c4 = svcmplt_u64(pg, h4, b4);

*r1 = svsub_m(c1, tr1, (uint32_t) -1);
*r2 = svsub_m(c2, tr2, (uint32_t) -1);
*r3 = svsub_m(c3, tr3, (uint32_t) -1);
*r4 = svsub_m(c4, tr4, (uint32_t) -1);

uint32_t minus1_1 = 0 - c_1;
uint32_t minus1_2 = 0 - c_2;
uint32_t minus1_3 = 0 - c_3;
uint32_t minus1_4 = 0 - c_4;

r_tail[0] = r_1 - (uint64_t)minus1_1;
r_tail[1] = r_2 - (uint64_t)minus1_2;
r_tail[2] = r_3 - (uint64_t)minus1_3;
r_tail[3] = r_4 - (uint64_t)minus1_4;
}

extern inline void sq_128(svbool_t pg, svuint64_t *a, svuint64_t *b, svuint64_t *c, svuint64_t *d, uint64_t *e) {
mul_128(pg, a, a, b, b, c, c, d, d, e, e);
}

extern inline void apply_sbox_128(
svbool_t pg,
svuint64_t *state1,
svuint64_t *state2,
svuint64_t *state3,
svuint64_t *state4,
uint64_t *state_tail
) {
COPY_128(x, *state1, *state2, *state3, *state4, state_tail); // copy input to x
SQUARE_128(pg, x); // x contains input^2
mul_128(pg, state1, &x_1, state2, &x_2, state3, &x_3, state4, &x_4, state_tail, x_tail); // state contains input^3
SQUARE_128(pg, x); // x contains input^4
mul_128(pg, state1, &x_1, state2, &x_2, state3, &x_3, state4, &x_4, state_tail, x_tail); // state contains input^7
}

extern inline void apply_inv_sbox_128(
svbool_t pg,
svuint64_t *state1,
svuint64_t *state2,
svuint64_t *state3,
svuint64_t *state4,
uint64_t *state_tail
) {
// base^10
COPY_128(t1, *state1, *state2, *state3, *state4, state_tail);
SQUARE_128(pg, t1);

// base^100
SQUARE_DEST_128(pg, t2, t1);

// base^100100
POW_ACC_DEST(pg, t3, 3, t2, t2);

// base^100100100100
POW_ACC_DEST(pg, t4, 6, t3, t3);

// compute base^100100100100100100100100
POW_ACC_DEST(pg, t5, 12, t4, t4);

// compute base^100100100100100100100100100100
POW_ACC_DEST(pg, t6, 6, t5, t3);

// compute base^1001001001001001001001001001000100100100100100100100100100100
POW_ACC_DEST(pg, t7, 31, t6, t6);

// compute base^1001001001001001001001001001000110110110110110110110110110110111
SQUARE_128(pg, t7);
MULTIPLY_128(pg, t7, t6);
SQUARE_128(pg, t7);
SQUARE_128(pg, t7);
MULTIPLY_128(pg, t7, t1);
MULTIPLY_128(pg, t7, t2);
mul_128(pg, state1, &t7_1, state2, &t7_2, state3, &t7_3, state4, &t7_4, state_tail, t7_tail);
}

bool add_constants_and_apply_sbox_128(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = 2; // number of u64 numbers in one 128 bit SVE vector
svbool_t ptrue = svptrue_b64();

svuint64_t state1 = svld1(ptrue, state + 0 * vl);
svuint64_t state2 = svld1(ptrue, state + 1 * vl);
svuint64_t state3 = svld1(ptrue, state + 2 * vl);
svuint64_t state4 = svld1(ptrue, state + 3 * vl);

svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
svuint64_t const2 = svld1(ptrue, constants + 1 * vl);
svuint64_t const3 = svld1(ptrue, constants + 2 * vl);
svuint64_t const4 = svld1(ptrue, constants + 3 * vl);

add_constants_128(ptrue, &state1, &const1, &state2, &const2, &state3, &const3, &state4, &const4, state + 8, constants + 8);
apply_sbox_128(ptrue, &state1, &state2, &state3, &state4, state + 8);

svst1(ptrue, state + 0 * vl, state1);
svst1(ptrue, state + 1 * vl, state2);
svst1(ptrue, state + 2 * vl, state3);
svst1(ptrue, state + 3 * vl, state4);

return true;
}

bool add_constants_and_apply_inv_sbox_128(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = 2; // number of u64 numbers in one 128 bit SVE vector
svbool_t ptrue = svptrue_b64();

svuint64_t state1 = svld1(ptrue, state + 0 * vl);
svuint64_t state2 = svld1(ptrue, state + 1 * vl);
svuint64_t state3 = svld1(ptrue, state + 2 * vl);
svuint64_t state4 = svld1(ptrue, state + 3 * vl);

svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
svuint64_t const2 = svld1(ptrue, constants + 1 * vl);
svuint64_t const3 = svld1(ptrue, constants + 2 * vl);
svuint64_t const4 = svld1(ptrue, constants + 3 * vl);

add_constants_128(ptrue, &state1, &const1, &state2, &const2, &state3, &const3, &state4, &const4, state + 8, constants + 8);
apply_inv_sbox_128(ptrue, &state1, &state2, &state3, &state4, state + 8);

svst1(ptrue, state + 0 * vl, state1);
svst1(ptrue, state + 1 * vl, state2);
svst1(ptrue, state + 2 * vl, state3);
svst1(ptrue, state + 3 * vl, state4);

return true;
}

#endif //RPO_SVE_RPO_HASH_128_H
Loading
Loading