Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dalek NEON v7 #691

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion curve25519-dalek/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ legacy_compatibility = []
group = ["dep:group", "rand_core"]
group-bits = ["group", "ff/bits"]

[target.'cfg(all(not(curve25519_dalek_backend = "fiat"), not(curve25519_dalek_backend = "serial"), target_arch = "x86_64"))'.dependencies]
[target.'cfg(all(not(curve25519_dalek_backend = "fiat"), not(curve25519_dalek_backend = "serial"), any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))'.dependencies]
curve25519-dalek-derive = { version = "0.1", path = "../curve25519-dalek-derive" }

[lints.rust.unexpected_cfgs]
Expand Down
37 changes: 20 additions & 17 deletions curve25519-dalek/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#![deny(clippy::unwrap_used, dead_code)]

#[allow(non_camel_case_types)]
#[derive(PartialEq, Debug)]
#[derive(PartialEq, Clone, Copy, Debug)]
enum DalekBits {
Dalek32,
Dalek64,
Expand Down Expand Up @@ -51,31 +51,34 @@ fn main() {
}

// Backend overrides / defaults
let curve25519_dalek_backend =
match std::env::var("CARGO_CFG_CURVE25519_DALEK_BACKEND").as_deref() {
Ok("fiat") => "fiat",
Ok("serial") => "serial",
Ok("simd") => {
// simd can only be enabled on x86_64 & 64bit target_pointer_width
match is_capable_simd(&target_arch, curve25519_dalek_bits) {
let curve25519_dalek_backend = match std::env::var("CARGO_CFG_CURVE25519_DALEK_BACKEND")
.as_deref()
{
Ok("fiat") => "fiat",
Ok("serial") => "serial",
Ok("simd") => {
// simd can only be enabled on x86_64 or aarch64 & 64bit target_pointer_width, or
// arm (e.g. armv7) & 32bit target_pointer_width
match is_capable_simd(&target_arch, curve25519_dalek_bits) {
true => "simd",
// If override is not possible this must result to compile error
// See: issues/532
false => panic!("Could not override curve25519_dalek_backend to simd"),
false => panic!("Could not override curve25519_dalek_backend to simd for arch {target_arch} and {curve25519_dalek_bits} bits"),
}
}
// default between serial / simd (if potentially capable)
_ => match is_capable_simd(&target_arch, curve25519_dalek_bits) {
true => "simd",
false => "serial",
},
};
}
// default between serial / simd (if potentially capable)
_ => match is_capable_simd(&target_arch, curve25519_dalek_bits) {
true => "simd",
false => "serial",
},
};
println!("cargo:rustc-cfg=curve25519_dalek_backend=\"{curve25519_dalek_backend}\"");
}

/// Reports whether the target arch and `curve25519_dalek_bits` are potentially SIMD capable.
///
/// Returns `true` for `x86_64` or `aarch64` combined with `DalekBits::Dalek64`, and for
/// `arm` combined with `DalekBits::Dalek32`. Every other combination returns `false`.
///
/// This only compares the target description handed in by the caller; it does not
/// probe for CPU features.
fn is_capable_simd(arch: &str, bits: DalekBits) -> bool {
    matches!(
        (arch, bits),
        ("x86_64" | "aarch64", DalekBits::Dalek64) | ("arm", DalekBits::Dalek32)
    )
}

// Deterministic cfg(curve25519_dalek_bits) when this is not explicitly set.
Expand Down
74 changes: 54 additions & 20 deletions curve25519-dalek/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,22 @@ pub mod vector;

/// Tag naming the backend implementation that the dispatch functions in this module call.
///
/// The vector variants only exist when the `simd` backend is configured for a matching
/// `target_arch`; `Serial` is always present.
#[derive(Copy, Clone)]
enum BackendKind {
    /// AVX2 vector backend: `simd` backend on `x86_64`.
    #[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
    Avx2,
    /// AVX-512 vector backend: `simd` backend on `x86_64` with the `nightly` cfg set.
    #[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
    Avx512,
    /// NEON vector backend: `simd` backend on `aarch64` with the `nightly` cfg set.
    // NOTE(review): gated on `aarch64` only — confirm whether 32-bit `arm` builds are
    // meant to fall through to `Serial`.
    #[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
    Neon,
    /// Serial backend; not gated by any cfg.
    Serial,
}

#[inline]
fn get_selected_backend() -> BackendKind {
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
{
return BackendKind::Neon;
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
{
cpufeatures::new!(cpuid_avx512, "avx512ifma", "avx512vl");
let token_avx512: cpuid_avx512::InitToken = cpuid_avx512::init();
Expand All @@ -62,7 +68,7 @@ fn get_selected_backend() -> BackendKind {
}
}

#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
{
cpufeatures::new!(cpuid_avx2, "avx2");
let token_avx2: cpuid_avx2::InitToken = cpuid_avx2::init();
Expand All @@ -85,25 +91,30 @@ where
use crate::traits::VartimeMultiscalarMul;

match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
BackendKind::Avx2 =>
vector::scalar_mul::pippenger::spec_avx2::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
BackendKind::Avx512 =>
vector::scalar_mul::pippenger::spec_avx512ifma_avx512vl::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
BackendKind::Neon =>
vector::scalar_mul::pippenger::spec_neon::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
BackendKind::Serial =>
serial::scalar_mul::pippenger::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
}
}

/// Backend-tagged wrapper: each variant holds the `VartimePrecomputedStraus` type of one
/// backend implementation.
///
/// Only compiled with the `alloc` feature. The vector variants additionally require the
/// `simd` backend on the matching `target_arch`.
#[cfg(feature = "alloc")]
pub(crate) enum VartimePrecomputedStraus {
    /// AVX2 implementation (`simd` backend on `x86_64`).
    #[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
    Avx2(vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus),
    /// AVX-512 IFMA/VL implementation (`simd` backend on `x86_64`, `nightly` cfg set).
    #[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
    Avx512ifma(
        vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus,
    ),
    /// NEON implementation (`simd` backend on `aarch64`, `nightly` cfg set).
    #[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
    Neon(vector::scalar_mul::precomputed_straus::spec_neon::VartimePrecomputedStraus),
    /// Serial implementation; carries no backend cfg of its own.
    Scalar(serial::scalar_mul::precomputed_straus::VartimePrecomputedStraus),
}

Expand All @@ -117,12 +128,15 @@ impl VartimePrecomputedStraus {
use crate::traits::VartimePrecomputedMultiscalarMul;

match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
BackendKind::Avx2 =>
VartimePrecomputedStraus::Avx2(vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus::new(static_points)),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
BackendKind::Avx512 =>
VartimePrecomputedStraus::Avx512ifma(vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus::new(static_points)),
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
BackendKind::Neon =>
VartimePrecomputedStraus::Neon(vector::scalar_mul::precomputed_straus::spec_neon::VartimePrecomputedStraus::new(static_points)),
BackendKind::Serial =>
VartimePrecomputedStraus::Scalar(serial::scalar_mul::precomputed_straus::VartimePrecomputedStraus::new(static_points))
}
Expand All @@ -144,18 +158,24 @@ impl VartimePrecomputedStraus {
use crate::traits::VartimePrecomputedMultiscalarMul;

match self {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
VartimePrecomputedStraus::Avx2(inner) => inner.optional_mixed_multiscalar_mul(
static_scalars,
dynamic_scalars,
dynamic_points,
),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
VartimePrecomputedStraus::Avx512ifma(inner) => inner.optional_mixed_multiscalar_mul(
static_scalars,
dynamic_scalars,
dynamic_points,
),
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
VartimePrecomputedStraus::Neon(inner) => inner.optional_mixed_multiscalar_mul(
static_scalars,
dynamic_scalars,
dynamic_points,
),
VartimePrecomputedStraus::Scalar(inner) => inner.optional_mixed_multiscalar_mul(
static_scalars,
dynamic_scalars,
Expand All @@ -177,16 +197,20 @@ where
use crate::traits::MultiscalarMul;

match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
BackendKind::Avx2 => {
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul::<I, J>(scalars, points)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
BackendKind::Avx512 => {
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul::<I, J>(
scalars, points,
)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
BackendKind::Neon => {
vector::scalar_mul::straus::spec_neon::Straus::multiscalar_mul::<I, J>(scalars, points)
}
BackendKind::Serial => {
serial::scalar_mul::straus::Straus::multiscalar_mul::<I, J>(scalars, points)
}
Expand All @@ -204,19 +228,25 @@ where
use crate::traits::VartimeMultiscalarMul;

match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
BackendKind::Avx2 => {
vector::scalar_mul::straus::spec_avx2::Straus::optional_multiscalar_mul::<I, J>(
scalars, points,
)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
BackendKind::Avx512 => {
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::optional_multiscalar_mul::<
I,
J,
>(scalars, points)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
BackendKind::Neon => {
vector::scalar_mul::straus::spec_neon::Straus::optional_multiscalar_mul::<I, J>(
scalars, points,
)
}
BackendKind::Serial => {
serial::scalar_mul::straus::Straus::optional_multiscalar_mul::<I, J>(scalars, points)
}
Expand All @@ -226,12 +256,14 @@ where
/// Perform constant-time, variable-base scalar multiplication.
///
/// Forwards `point` and `scalar` to the `variable_base` implementation of whichever
/// backend `get_selected_backend()` reports.
pub fn variable_base_mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint {
    let backend = get_selected_backend();
    match backend {
        // Vector arms exist only when the matching `BackendKind` variant is compiled in.
        #[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
        BackendKind::Avx2 => {
            vector::scalar_mul::variable_base::spec_avx2::mul(point, scalar)
        }
        #[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
        BackendKind::Avx512 => {
            vector::scalar_mul::variable_base::spec_avx512ifma_avx512vl::mul(point, scalar)
        }
        #[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
        BackendKind::Neon => {
            vector::scalar_mul::variable_base::spec_neon::mul(point, scalar)
        }
        BackendKind::Serial => {
            serial::scalar_mul::variable_base::mul(point, scalar)
        }
    }
}
Expand All @@ -240,12 +272,14 @@ pub fn variable_base_mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint
// The upper-case `A` parameter is why `non_snake_case` is allowed here.
#[allow(non_snake_case)]
pub fn vartime_double_base_mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> EdwardsPoint {
    // Forward to the `vartime_double_base` implementation of the selected backend.
    let backend = get_selected_backend();
    match backend {
        // Vector arms exist only when the matching `BackendKind` variant is compiled in.
        #[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
        BackendKind::Avx2 => {
            vector::scalar_mul::vartime_double_base::spec_avx2::mul(a, A, b)
        }
        #[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
        BackendKind::Avx512 => {
            vector::scalar_mul::vartime_double_base::spec_avx512ifma_avx512vl::mul(a, A, b)
        }
        #[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
        BackendKind::Neon => {
            vector::scalar_mul::vartime_double_base::spec_neon::mul(a, A, b)
        }
        BackendKind::Serial => {
            serial::scalar_mul::vartime_double_base::mul(a, A, b)
        }
    }
}
7 changes: 6 additions & 1 deletion curve25519-dalek/src/backend/vector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@
#![doc = include_str!("../../../docs/parallel-formulas.md")]

// x86_64-only modules.
#[cfg(target_arch = "x86_64")]
#[allow(missing_docs)]
pub mod packed_simd;

#[cfg(target_arch = "x86_64")]
pub mod avx2;

// Requires the `nightly` cfg in addition to x86_64.
#[cfg(all(nightly, target_arch = "x86_64"))]
pub mod ifma;

// NEON module: compiled for both `arm` and `aarch64`, and only with the `nightly` cfg.
#[cfg(all(nightly, any(target_arch = "arm", target_arch = "aarch64")))]
pub mod neon;

// Not gated on architecture at this level.
pub mod scalar_mul;
Loading
Loading