Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Scalar constant time canonical check #384

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion curve25519-dalek/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ required-features = ["alloc", "rand_core"]
cfg-if = "1"
rand_core = { version = "0.6.4", default-features = false, optional = true }
digest = { version = "0.10", default-features = false, optional = true }
subtle = { version = "2.3.0", default-features = false }
subtle = { version = "2.4", default-features = false }
serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] }
zeroize = { version = "1", default-features = false, optional = true }

Expand Down
25 changes: 22 additions & 3 deletions curve25519-dalek/benches/dalek_benchmarks.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(non_snake_case)]

use rand::{rngs::OsRng, thread_rng};
use rand::{thread_rng, Rng};

use criterion::{
criterion_main, measurement::Measurement, BatchSize, BenchmarkGroup, BenchmarkId, Criterion,
Expand Down Expand Up @@ -249,7 +249,7 @@ mod ristretto_benches {
BenchmarkId::new("Batch Ristretto double-and-encode", *batch_size),
&batch_size,
|b, &&size| {
let mut rng = OsRng;
let mut rng = thread_rng();
let points: Vec<RistrettoPoint> = (0..size)
.map(|_| RistrettoPoint::random(&mut rng))
.collect();
Expand Down Expand Up @@ -336,7 +336,7 @@ mod scalar_benches {
BenchmarkId::new("Batch scalar inversion", *batch_size),
&batch_size,
|b, &&size| {
let mut rng = OsRng;
let mut rng = thread_rng();
let scalars: Vec<Scalar> =
(0..size).map(|_| Scalar::random(&mut rng)).collect();
b.iter(|| {
Expand All @@ -347,13 +347,32 @@ mod scalar_benches {
);
}
}
fn scalar_from_canonical_bytes<M: Measurement>(c: &mut BenchmarkGroup<M>) {
let bytes = [0xFF; 32];
c.bench_function("Max scalar from_canonical_bytes", move |b| {
b.iter(|| Scalar::from_canonical_bytes(bytes))
});
let bytes = [0u8; 32];
c.bench_function("Zero scalar from_canonical_bytes", move |b| {
b.iter(|| Scalar::from_canonical_bytes(bytes))
});
c.bench_function("Rand scalar from_canonical_bytes", |bench| {
let mut rng = thread_rng();
bench.iter_batched(
|| rng.gen(),
Scalar::from_canonical_bytes,
BatchSize::SmallInput,
);
});
}

pub(crate) fn scalar_benches() {
let mut c = Criterion::default();
let mut g = c.benchmark_group("scalar benches");

scalar_arith(&mut g);
batch_scalar_inversion(&mut g);
scalar_from_canonical_bytes(&mut g);
}
}

Expand Down
64 changes: 48 additions & 16 deletions curve25519-dalek/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,8 @@ use digest::generic_array::typenum::U64;
#[cfg(feature = "digest")]
use digest::Digest;

use subtle::Choice;
use subtle::ConditionallySelectable;
use subtle::ConstantTimeEq;
use subtle::CtOption;
use subtle::{Choice, CtOption};
use subtle::{ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater};

#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
Expand Down Expand Up @@ -253,9 +251,8 @@ impl Scalar {
/// if `bytes` is a canonical byte representation modulo the group order \\( \ell \\);
/// - `None` if `bytes` is not a canonical byte representation.
pub fn from_canonical_bytes(bytes: [u8; 32]) -> CtOption<Scalar> {
let high_bit_unset = (bytes[31] >> 7).ct_eq(&0);
let candidate = Scalar { bytes };
CtOption::new(candidate, high_bit_unset & candidate.is_canonical())
CtOption::new(candidate, candidate.is_canonical())
}

/// Construct a `Scalar` from the low 255 bits of a 256-bit integer. This breaks the invariant
Expand Down Expand Up @@ -1125,7 +1122,15 @@ impl Scalar {
/// Check whether this `Scalar` is the canonical representative mod \\(\ell\\). This is not
/// public because any `Scalar` that is publicly observed is reduced, by scalar invariant #2.
fn is_canonical(&self) -> Choice {
self.ct_eq(&self.reduce())
let mut over = Choice::from(0);
let mut under = Choice::from(0);
for (this, l) in self.unpack().0.iter().zip(&constants::L.0).rev() {
let gt = this.ct_gt(l);
let eq = this.ct_eq(l);
under |= (!gt & !eq) & !over;
over |= gt;
}
under
}
}

Expand Down Expand Up @@ -1650,15 +1655,42 @@ pub(crate) mod test {
0, 0, 128,
];

assert!(bool::from(
Scalar::from_canonical_bytes(canonical_bytes).is_some()
));
assert!(bool::from(
Scalar::from_canonical_bytes(non_canonical_bytes_because_unreduced).is_none()
));
assert!(bool::from(
Scalar::from_canonical_bytes(non_canonical_bytes_because_highbit).is_none()
));
let canonical_l_minus_one = [
237, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15,
];
let canonical_zero = [0u8; 32];
let canonical_255_minus_1 = [
132, 52, 71, 117, 71, 74, 127, 151, 35, 182, 58, 139, 233, 42, 231, 109, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 15,
];
let non_canonical_l = [
237, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16,
];
let non_canonical_l_plus_one = [
237, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17,
];
let non_canonical_full = [0xFF; 32];
let non_canonical_255_minus_1 = {
let mut non_canonical_255_minus_1 = [0xFF; 32];
non_canonical_255_minus_1[31] = 0b0111_1111;
non_canonical_255_minus_1
};

let from_canonical_option = |b| Option::<Scalar>::from(Scalar::from_canonical_bytes(b));

assert!(from_canonical_option(canonical_bytes).is_some());
assert!(from_canonical_option(canonical_l_minus_one).is_some());
assert!(from_canonical_option(canonical_zero).is_some());
assert!(from_canonical_option(canonical_255_minus_1).is_some());
assert!(from_canonical_option(non_canonical_bytes_because_unreduced).is_none());
assert!(from_canonical_option(non_canonical_bytes_because_highbit).is_none());
assert!(from_canonical_option(non_canonical_l).is_none());
assert!(from_canonical_option(non_canonical_l_plus_one).is_none());
assert!(from_canonical_option(non_canonical_full).is_none());
assert!(from_canonical_option(non_canonical_255_minus_1).is_none());
}

#[test]
Expand Down