From 8c1d140669d967c8fe62def160832d88e0217f15 Mon Sep 17 00:00:00 2001 From: aPere3 Date: Thu, 14 Apr 2022 17:35:28 +0200 Subject: [PATCH] feat(concrete_csprng): refactor the csprng to prepare core backend split BREAKING_CHANGE: this commit completely breaks the previous API --- concrete-csprng/Cargo.toml | 14 +- concrete-csprng/README.md | 7 + concrete-csprng/benches/benchmark.rs | 43 +- concrete-csprng/src/counter/mod.rs | 271 ------- concrete-csprng/src/counter/state.rs | 669 ------------------ concrete-csprng/src/counter/test.rs | 187 ----- concrete-csprng/src/counter/test_aes.rs | 20 - .../src/generators/aes_ctr/block_cipher.rs | 20 + .../src/generators/aes_ctr/generic.rs | 377 ++++++++++ .../src/generators/aes_ctr/index.rs | 365 ++++++++++ concrete-csprng/src/generators/aes_ctr/mod.rs | 223 ++++++ .../src/generators/aes_ctr/parallel.rs | 222 ++++++ .../src/generators/aes_ctr/states.rs | 176 +++++ .../implem/aesni/block_cipher.rs} | 83 +-- .../src/generators/implem/aesni/generator.rs | 110 +++ .../src/generators/implem/aesni/mod.rs | 16 + .../src/generators/implem/aesni/parallel.rs | 94 +++ concrete-csprng/src/generators/implem/mod.rs | 9 + .../generators/implem/soft/block_cipher.rs | 114 +++ .../src/generators/implem/soft/generator.rs | 110 +++ .../src/generators/implem/soft/mod.rs | 12 + .../src/generators/implem/soft/parallel.rs | 93 +++ concrete-csprng/src/generators/mod.rs | 196 +++++ concrete-csprng/src/lib.rs | 407 +++-------- .../src/{generate_random.rs => main.rs} | 10 +- concrete-csprng/src/seeders/implem/linux.rs | 59 ++ concrete-csprng/src/seeders/implem/mod.rs | 9 + concrete-csprng/src/seeders/implem/rdseed.rs | 46 ++ concrete-csprng/src/seeders/mod.rs | 43 ++ concrete-csprng/src/software.rs | 193 ----- 30 files changed, 2468 insertions(+), 1730 deletions(-) delete mode 100644 concrete-csprng/src/counter/mod.rs delete mode 100644 concrete-csprng/src/counter/state.rs delete mode 100644 concrete-csprng/src/counter/test.rs delete mode 100644 concrete-csprng/src/counter/test_aes.rs create mode 100644 concrete-csprng/src/generators/aes_ctr/block_cipher.rs create mode 100644 concrete-csprng/src/generators/aes_ctr/generic.rs create mode 100644 concrete-csprng/src/generators/aes_ctr/index.rs create mode 100644 concrete-csprng/src/generators/aes_ctr/mod.rs create mode 100644 concrete-csprng/src/generators/aes_ctr/parallel.rs create mode 100644 concrete-csprng/src/generators/aes_ctr/states.rs rename concrete-csprng/src/{aesni.rs => generators/implem/aesni/block_cipher.rs} (76%) create mode 100644 concrete-csprng/src/generators/implem/aesni/generator.rs create mode 100644 concrete-csprng/src/generators/implem/aesni/mod.rs create mode 100644 concrete-csprng/src/generators/implem/aesni/parallel.rs create mode 100644 concrete-csprng/src/generators/implem/mod.rs create mode 100644 concrete-csprng/src/generators/implem/soft/block_cipher.rs create mode 100644 concrete-csprng/src/generators/implem/soft/generator.rs create mode 100644 concrete-csprng/src/generators/implem/soft/mod.rs create mode 100644 concrete-csprng/src/generators/implem/soft/parallel.rs create mode 100644 concrete-csprng/src/generators/mod.rs rename concrete-csprng/src/{generate_random.rs => main.rs} (54%) create mode 100644 concrete-csprng/src/seeders/implem/linux.rs create mode 100644 concrete-csprng/src/seeders/implem/mod.rs create mode 100644 concrete-csprng/src/seeders/implem/rdseed.rs create mode 100644 concrete-csprng/src/seeders/mod.rs delete mode 100644 concrete-csprng/src/software.rs diff --git 
a/concrete-csprng/Cargo.toml b/concrete-csprng/Cargo.toml index 0101c26278..d4fbb5d789 100644 --- a/concrete-csprng/Cargo.toml +++ b/concrete-csprng/Cargo.toml @@ -20,13 +20,19 @@ rand = "0.8.3" criterion = "0.3" [features] -slow = [] -multithread = ["rayon"] +parallel = ["rayon"] +seeder_x86_64_rdseed = [] +seeder_linux = [] +generator_x86_64_aesni = [] +generator_soft = [] [[bench]] name = "benchmark" +path = "benches/benchmark.rs" harness = false +required-features = ["seeder_x86_64_rdseed", "generator_x86_64_aesni"] [[bin]] -name = "generate_random" -path = "src/generate_random.rs" +name = "generate" +path = "src/main.rs" +required-features = ["seeder_x86_64_rdseed", "generator_x86_64_aesni"] diff --git a/concrete-csprng/README.md b/concrete-csprng/README.md index 1fabcf063d..a345e4284b 100644 --- a/concrete-csprng/README.md +++ b/concrete-csprng/README.md @@ -9,6 +9,13 @@ The implementation is based on the AES blockcipher used in CTR mode, as describe The current implementation uses special instructions existing on modern *intel* cpus. We may add a generic implementation in the future. +## Running the benchmarks + +To execute the benchmarks on an x86_64 platform: +```shell +RUSTFLAGS="-Ctarget-cpu=native" cargo bench --features=seeder_rdseed,generator_aesni +``` + ## License This software is distributed under the BSD-3-Clause-Clear license. If you have any questions, diff --git a/concrete-csprng/benches/benchmark.rs b/concrete-csprng/benches/benchmark.rs index 3e13d9242e..6de6bf5523 100644 --- a/concrete-csprng/benches/benchmark.rs +++ b/concrete-csprng/benches/benchmark.rs @@ -1,34 +1,53 @@ -use concrete_csprng::RandomGenerator; -use criterion::{criterion_group, criterion_main, Criterion}; +use concrete_csprng::generators::{ + AesniRandomGenerator, BytesPerChild, ChildrenCount, RandomGenerator, +}; +use concrete_csprng::seeders::{RdseedSeeder, Seeder}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; const N_GEN: usize = 1_000_000; -fn unbounded_benchmark(c: &mut Criterion) { - let mut generator = RandomGenerator::new_hardware(None).unwrap(); - c.bench_function("unbounded", |b| { +fn parent_generate(c: &mut Criterion) { + let mut seeder = RdseedSeeder; + let mut generator = AesniRandomGenerator::new(seeder.seed()); + c.bench_function("parent_generate", |b| { b.iter(|| { (0..N_GEN).for_each(|_| { - generator.generate_next(); + generator.next(); }) }) }); } -fn bounded_benchmark(c: &mut Criterion) { - let mut generator = RandomGenerator::new_hardware(None).unwrap(); +fn child_generate(c: &mut Criterion) { + let mut seeder = RdseedSeeder; + let mut generator = AesniRandomGenerator::new(seeder.seed()); let mut generator = generator - .try_fork(1, N_GEN * 10_000) + .try_fork(ChildrenCount(1), BytesPerChild(N_GEN * 10_000)) .unwrap() .next() .unwrap(); - c.bench_function("bounded", |b| { + c.bench_function("child_generate", |b| { b.iter(|| { (0..N_GEN).for_each(|_| { - generator.generate_next(); + generator.next(); }) }) }); } -criterion_group!(benches, unbounded_benchmark, bounded_benchmark); +fn fork(c: &mut Criterion) { + let mut seeder = RdseedSeeder; + let mut generator = AesniRandomGenerator::new(seeder.seed()); + c.bench_function("fork", |b| { + b.iter(|| { + black_box( + generator + .try_fork(ChildrenCount(2048), BytesPerChild(2048)) + .unwrap(), + ) + }) + }); +} + +criterion_group!(benches, parent_generate, child_generate, fork); criterion_main!(benches); diff --git a/concrete-csprng/src/counter/mod.rs b/concrete-csprng/src/counter/mod.rs deleted file mode 
100644 index dff4ad190e..0000000000 --- a/concrete-csprng/src/counter/mod.rs +++ /dev/null @@ -1,271 +0,0 @@ -use crate::{aesni, software}; -#[cfg(feature = "multithread")] -use rayon::{iter::IndexedParallelIterator, prelude::*}; -use std::error::Error; -use std::fmt::{Display, Formatter}; - -#[cfg(test)] -mod test; - -#[cfg(all( - test, - target_arch = "x86_64", - target_feature = "aes", - target_feature = "sse2", - target_feature = "rdseed" -))] -mod test_aes; - -mod state; -pub use state::*; - -/// Represents a key used in the AES ciphertext. -#[derive(Clone, Copy)] -pub struct AesKey(pub u128); - -/// A trait for batched generators, i.e. generators that creates 128 bytes of random values at a -/// time. -pub trait AesBatchedGenerator: Clone { - /// Instantiate a new generator from a secret key. - fn new(key: Option) -> Self; - /// Generates the batch corresponding to the given counter. - fn generate_batch(&mut self, index: AesIndex) -> [u8; 128]; -} - -/// A generator that uses the software implementation. -pub type SoftAesCtrGenerator = AesCtrGenerator; - -/// A generator that uses the hardware implementation. -pub type HardAesCtrGenerator = AesCtrGenerator; - -/// An error occuring during a generator fork. -#[derive(Debug)] -pub enum ForkError { - ForkTooLarge, - ZeroChildrenCount, - ZeroBytesPerChild, -} - -impl Display for ForkError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - ForkError::ForkTooLarge => { - write!( - f, - "The children generators would output bytes after the parent bound. " - ) - } - ForkError::ZeroChildrenCount => { - write!( - f, - "The number of children in the fork must be greater than zero." - ) - } - ForkError::ZeroBytesPerChild => { - write!( - f, - "The number of bytes per child must be greater than zero." - ) - } - } - } -} -impl Error for ForkError {} - -/// A csprng which operates in batch mode. -#[derive(Clone)] -pub struct AesCtrGenerator { - generator: G, - state: State, - bound: TableIndex, - last: TableIndex, - buffer: [u8; 128], -} - -impl AesCtrGenerator { - /// Generates a new csprng. - /// - /// If not given, the key is automatically selected, and the state is set to zero. - /// - /// Note : - /// ------ - /// - /// The state given in input, points to the first byte that will be outputted by the generator. - /// The bound points to the first byte that can not be outputted by the generator. - pub fn new( - key: Option, - start_index: Option, - bound_index: Option, - ) -> AesCtrGenerator { - AesCtrGenerator::from_generator( - G::new(key), - start_index.unwrap_or(TableIndex::SECOND), - bound_index.unwrap_or(TableIndex::LAST), - ) - } - - /// Generates a csprng from an existing generator. - pub fn from_generator( - generator: G, - start_index: TableIndex, - bound_index: TableIndex, - ) -> AesCtrGenerator { - assert!(start_index < bound_index); - let last = bound_index.decremented(); - let buffer = [0u8; 128]; - AesCtrGenerator { - generator, - state: State::new(start_index), - bound: bound_index, - last, - buffer, - } - } - - /// Returns the table index related to the last yielded byte. - pub fn last_table_index(&self) -> TableIndex { - self.state.table_index() - } - - /// Returns the bound of the generator if any. - /// - /// The bound is the table index of the first byte that can not be outputted by the generator. - pub fn get_bound(&self) -> TableIndex { - self.bound - } - - /// Returns whether the generator is bounded or not. 
- pub fn is_bounded(&self) -> bool { - self.bound != TableIndex::LAST - } - - /// Computes the number of bytes that can still be outputted by the generator. - /// - /// Note : - /// ------ - /// - /// Note that `ByteCount` uses the `u128` datatype to store the byte count. Unfortunately, the - /// number of remaining bytes is in ⟦0;2¹³² -1⟧. When the number is greater than 2¹²⁸ - 1, - /// we saturate the count at 2¹²⁸ - 1. - pub fn remaining_bytes(&self) -> ByteCount { - TableIndex::distance(&self.last, &self.state.table_index()).unwrap() - } - - /// Yields the next random byte. - pub fn generate_next(&mut self) -> u8 { - assert!(self.state.table_index() < self.last,); - match self.state.increment() { - ShiftAction::YieldByte(BufferPointer(ptr)) => self.buffer[ptr], - ShiftAction::RefreshBatchAndYieldByte(aes_index, BufferPointer(ptr)) => { - self.buffer = self.generator.generate_batch(aes_index); - self.buffer[ptr] - } - } - } - - /// Tries to fork the current generator into `n_child` generators each able to yield - /// `child_bytes` random bytes. - pub fn try_fork( - &mut self, - n_child: ChildrenCount, - child_bytes: BytesPerChild, - ) -> Result>, ForkError> { - if n_child.0 == 0 { - return Err(ForkError::ZeroChildrenCount); - } - if child_bytes.0 == 0 { - return Err(ForkError::ZeroBytesPerChild); - } - if !self.is_fork_in_bound(n_child, child_bytes) { - return Err(ForkError::ForkTooLarge); - } - - let generator = self.generator.clone(); - // The state currently stored in the parent generator points to the table index of the last - // generated byte. The first index to be generated is the next one : - let first_index = self.state.table_index().incremented(); - let output = (0..n_child.0).map(move |i| { - // The first index to be outputted by the child is the `first_index` shifted by the - // proper amount of `child_bytes`. - let child_first_index = first_index.increased(child_bytes.0 * i); - // The bound of the child is the first index of its next sibling. - let child_bound_index = first_index.increased(child_bytes.0 * (i + 1)); - AesCtrGenerator::from_generator(generator.clone(), child_first_index, child_bound_index) - }); - // The parent next index is the bound of the last child. - let next_index = first_index.increased(child_bytes.0 * n_child.0); - self.state = State::new(next_index); - - Ok(output) - } - - /// Tries to fork the current generator into `n_child` generators each able to yield - /// `child_bytes` random bytes as a parallel iterator. - /// - /// # Notes - /// - /// This method necessitate the "multithread" feature. - #[cfg(feature = "multithread")] - pub fn par_try_fork( - &mut self, - n_child: ChildrenCount, - child_bytes: BytesPerChild, - ) -> Result>, ForkError> - where - G: Send + Sync, - { - if n_child.0 == 0 { - return Err(ForkError::ZeroChildrenCount); - } - if child_bytes.0 == 0 { - return Err(ForkError::ZeroBytesPerChild); - } - if !self.is_fork_in_bound(n_child, child_bytes) { - return Err(ForkError::ForkTooLarge); - } - - let generator = self.generator.clone(); - // The state currently stored in the parent generator points to the table index of the last - // generated byte. The first index to be generated is the next one : - let first_index = self.state.table_index().incremented(); - let output = (0..n_child.0).into_par_iter().map(move |i| { - // The first index to be outputted by the child is the `first_index` shifted by the - // proper amount of `child_bytes`. 
- let child_first_index = first_index.increased(child_bytes.0 * i); - // The bound of the child is the first index of its next sibling. - let child_bound_index = first_index.increased(child_bytes.0 * (i + 1)); - AesCtrGenerator::from_generator(generator.clone(), child_first_index, child_bound_index) - }); - // The parent next index is the bound of the last child. - let next_index = first_index.increased(child_bytes.0 * n_child.0); - self.state = State::new(next_index); - - Ok(output) - } - - fn is_fork_in_bound(&self, n_child: ChildrenCount, child_bytes: BytesPerChild) -> bool { - let mut end = self.state.table_index(); - end.increase(n_child.0 * child_bytes.0); - end < self.bound - } -} - -impl Iterator for AesCtrGenerator { - type Item = u8; - - fn next(&mut self) -> Option { - if self.state.table_index() < self.last { - None - } else { - Some(self.generate_next()) - } - } -} - -/// The number of children created when a generator is forked. -#[derive(Debug, Copy, Clone)] -pub struct ChildrenCount(pub usize); - -/// The number of bytes each children can generate, when a generator is forked. -#[derive(Debug, Copy, Clone)] -pub struct BytesPerChild(pub usize); diff --git a/concrete-csprng/src/counter/state.rs b/concrete-csprng/src/counter/state.rs deleted file mode 100644 index 37ec8c4d0f..0000000000 --- a/concrete-csprng/src/counter/state.rs +++ /dev/null @@ -1,669 +0,0 @@ -//! A module to manipulate aes counter states. -//! -//! Coarse-grained pseudo-random table lookup -//! ========================================= -//! -//! To generate random values, we use the AES block cipher in counter mode. If we denote f the aes -//! encryption function, we have: -//! ```ascii -//! f: ⟦0;2¹²⁸ -1⟧ X ⟦0;2¹²⁸ -1⟧ ↦ ⟦0;2¹²⁸ -1⟧ -//! f(secret_key, input) ↦ output -//! ``` -//! -//! If we fix the secret key to a value k, we have a pure function fₖ from ⟦0;2¹²⁸ -1⟧ to -//! ⟦0;2¹²⁸-1⟧, transforming the state of the counter into a pseudo random value. Essentially, this -//! fₖ function can be considered as a lookup function into a table of 2¹²⁸ pseudo-random values: -//! ```ascii -//! ╭──────────────┬──────────────┬─────┬──────────────╮ -//! │ fₖ(0) │ fₖ(1) │ │ fₖ(2¹²⁸ -1) │ -//! ╔═══════↧══════╦═══════↧══════╦═════╦═══════↧══════╗ -//! ║┏━━━━━━━━━━━━┓║┏━━━━━━━━━━━━┓║ ║┏━━━━━━━━━━━━┓║ -//! ║┃ u128 ┃║┃ u128 ┃║ ... ║┃ u128 ┃║ -//! ║┗━━━━━━━━━━━━┛║┗━━━━━━━━━━━━┛║ ║┗━━━━━━━━━━━━┛║ -//! ╚══════════════╩══════════════╩═════╩══════════════╝ -//! ``` -//! -//! We call this input to the fₖ function, an _aes index_ of the pseudo-random table. The -//! [`AesIndex`] structure defined in this module represents such an index in the code. -//! -//! Fine-grained pseudo-random table lookup -//! ======================================= -//! -//! Unfortunately this is not enough to handle our situation, since we want to deliver the -//! pseudo-random bytes one by one. Fortunately, each `u128` value yielded by fₖ can be seen as a -//! table of 16 `u8`: -//! ```ascii -//! ╭──────────────┬──────────────┬─────┬──────────────╮ -//! │ fₖ(0) │ fₖ(1) │ │ fₖ(2¹²⁸ -1) │ -//! ╔═══════↧══════╦═══════↧══════╦═════╦═══════↧══════╗ -//! ║┏━━━━━━━━━━━━┓║┏━━━━━━━━━━━━┓║ ║┏━━━━━━━━━━━━┓║ -//! ║┃ u128 ┃║┃ u128 ┃║ ║┃ u128 ┃║ -//! ║┣━━┯━━┯━━━┯━━┫║┣━━┯━━┯━━━┯━━┫║ ... ║┣━━┯━━┯━━━┯━━┫║ -//! ║┃u8│u8│...│u8┃║┃u8│u8│...│u8┃║ ║┃u8│u8│...│u8┃║ -//! ║┗━━┷━━┷━━━┷━━┛║┗━━┷━━┷━━━┷━━┛║ ║┗━━┷━━┷━━━┷━━┛║ -//! ╚══════════════╩══════════════╩═════╩══════════════╝ -//! ``` -//! -//! We introduce a second function to index into this table of small integers: -//! 
```ascii -//! g: ⟦0;2¹²⁸ -1⟧ X ⟦0;15⟧ ↦ ⟦0;2⁸ -1⟧ -//! g(big_int, index) ↦ byte -//! ``` -//! -//! If we fix the `u128` value to a value e, we have a pure function gₑ from ⟦0;15⟧ to ⟦0;2⁸ -1⟧ -//! transforming an index into a pseudo-random byte: -//! ```ascii -//! ┏━━━━━━━━┯━━━━━━━━┯━━━┯━━━━━━━━┓ -//! ┃ u8 │ u8 │...│ u8 ┃ -//! ┗━━━━━━━━┷━━━━━━━━┷━━━┷━━━━━━━━┛ -//! │ gₑ(0) │ gₑ(1) │ │ gₑ(15) │ -//! ╰────────┴─────-──┴───┴────────╯ -//! ``` -//! -//! We call this input to the gₑ function, a _byte index_ of the pseudo-random table. The -//! [`ByteIndex`] structure defined in this module represents such an index in the code. -//! -//! By using both the g and the fₖ functions, we can define a new function l which allows to index -//! any byte of the pseudo-random table: -//! ```ascii -//! l: ⟦0;2¹²⁸ -1⟧ X ⟦0;15⟧ ↦ ⟦0;2⁸ -1⟧ -//! l(aes_index, byte_index) ↦ g(fₖ(aes_index), byte_index) -//! ``` -//! -//! In this sense, any member of ⟦0;2¹²⁸ -1⟧ X ⟦0;15⟧ uniquely defines a byte in this pseudo-random -//! table: -//! ```ascii -//! ╭──────────────────────────────────────────────────╮ -//! │ e = fₖ(a) │ -//! ╔══════════════╦═══════↧══════╦═════╦══════════════╗ -//! ║┏━━━━━━━━━━━━┓║┏━━━━━━━━━━━━┓║ ║┏━━━━━━━━━━━━┓║ -//! ║┃ u128 ┃║┃ u128 ┃║ ║┃ u128 ┃║ -//! ║┣━━┯━━┯━━━┯━━┫║┣━━┯━━┯━━━┯━━┫║ ... ║┣━━┯━━┯━━━┯━━┫║ -//! ║┃u8│u8│...│u8┃║┃u8│u8│...│u8┃║ ║┃u8│u8│...│u8┃║ -//! ║┗━━┷━━┷━━━┷━━┛║┗━━┷↥━┷━━━┷━━┛║ ║┗━━┷━━┷━━━┷━━┛║ -//! ║ ║│ gₑ(b) │║ ║ ║ -//! ║ ║╰───-────────╯║ ║ ║ -//! ╚══════════════╩══════════════╩═════╩══════════════╝ -//! ``` -//! -//! We call this input to the l function, a _table index_ of the pseudo-random table. The -//! [`TableIndex`] structure defined in this module represents such an index in the code. -//! -//! Prngs current table index -//! ========================= -//! -//! When created, a prng is given an initial _table index_, denoted (a₀, b₀), which identifies the -//! first byte of the table to be outputted by the prng. Then, each time the prng is queried for a -//! new value, the byte corresponding to the current _table index_ is returned, and the current -//! _table index_ is incremented: -//! ```ascii -//! ╭─────────────────────────────────────────╮ ╭─────────────────────────────────────────╮ -//! │ e = fₖ(a₀) │ │ e = fₖ(a₁) │ -//! ╔═════↧═════╦═══════════╦═════╦═══════════╗ ╔═══════════╦═════↧═════╦═════╦═══════════╗ -//! ║┏━┯━┯━━━┯━┓║┏━┯━┯━━━┯━┓║ ... ║┏━┯━┯━━━┯━┓║ ║┏━┯━┯━━━┯━┓║┏━┯━┯━━━┯━┓║ ... ║┏━┯━┯━━━┯━┓║ -//! ║┃ │ │...│ ┃║┃ │ │...│ ┃║ ║┃ │ │...│ ┃║ ║┃ │ │...│ ┃║┃ │ │...│ ┃║ ║┃ │ │...│ ┃║ -//! ║┗━┷━┷━━━┷↥┛║┗━┷━┷━━━┷━┛║ ║┗━┷━┷━━━┷━┛║ → ║┗━┷━┷━━━┷━┛║┗↥┷━┷━━━┷━┛║ ║┗━┷━┷━━━┷━┛║ -//! ║│ gₑ(b₀) │║ ║ ║ ║ ║ ║│ gₑ(b₁) │║ ║ ║ -//! ║╰─────────╯║ ║ ║ ║ ║ ║╰─────────╯║ ║ ║ -//! ╚═══════════╩═══════════╩═════╩═══════════╝ ╚═══════════╩═══════════╩═════╩═══════════╝ -//! ``` -//! -//! Prng bound -//! ========== -//! -//! When created, a prng is also given a _bound_ (aₘ, bₘ) , that is a table index which it is not -//! allowed to exceed: -//! ```ascii -//! ╭─────────────────────────────────────────╮ -//! │ e = fₖ(a₀) │ -//! ╔═════↧═════╦═══════════╦═════╦═══════════╗ -//! ║┏━┯━┯━━━┯━┓║┏━┯━┯━━━┯━┓║ ... ║┏━┯━┯━━━┯━┓║ -//! ║┃ │ │...│ ┃║┃ │╳│...│╳┃║ ║┃╳│╳│...│╳┃║ -//! ║┗━┷━┷━━━┷↥┛║┗━┷━┷━━━┷━┛║ ║┗━┷━┷━━━┷━┛║ The current byte can be returned. -//! ║│ gₑ(b₀) │║ ║ ║ ║ -//! ║╰─────────╯║ ║ ║ ║ -//! ╚═══════════╩═══════════╩═════╩═══════════╝ -//! -//! ╭─────────────────────────────────────────╮ -//! │ e = fₖ(aₘ) │ -//! ╔═══════════╦═════↧═════╦═════╦═══════════╗ -//! ║┏━┯━┯━━━┯━┓║┏━┯━┯━━━┯━┓║ ... 
║┏━┯━┯━━━┯━┓║ -//! ║┃ │ │...│ ┃║┃ │╳│...│╳┃║ ║┃╳│╳│...│╳┃║ The table index reached the bound, -//! ║┗━┷━┷━━━┷━┛║┗━┷↥┷━━━┷━┛║ ║┗━┷━┷━━━┷━┛║ the current byte can not be -//! ║ ║│ gₑ(bₘ) │║ ║ ║ returned. -//! ║ ║╰─────────╯║ ║ ║ -//! ╚═══════════╩═══════════╩═════╩═══════════╝ -//! ``` -//! -//! Buffering -//! ========= -//! -//! Calling the aes function every time we need to yield a single byte would be a huge waste of -//! resources. In practice, we call aes 8 times in a row, for 8 successive values of aes index, and -//! store the results in a buffer. For platforms which have a dedicated aes chip, this allows to -//! fill the unit pipeline and reduces the amortized cost of the aes function. -//! -//! Together with the current table index of the prng, we also store a pointer p (initialized at -//! p₀=b₀) to the current byte in the buffer. If we denote v the lookup function we have : -//! ```ascii -//! ╭───────────────────────────────────────────────╮ -//! │ e = fₖ(a₀) │ Buffer(length=128) -//! ╔═════╦═══════════╦═════↧═════╦═══════════╦═════╗ ┏━┯━┯━┯━┯━┯━┯━┯━┯━━━┯━┓ -//! ║ ... ║┏━┯━┯━━━┯━┓║┏━┯━┯━━━┯━┓║┏━┯━┯━━━┯━┓║ ... ║ ┃▓│▓│▓│▓│▓│▓│▓│▓│...│▓┃ -//! ║ ║┃ │ │...│ ┃║┃▓│▓│...│▓┃║┃▓│▓│...│▓┃║ ║ ┗━┷↥┷━┷━┷━┷━┷━┷━┷━━━┷━┛ -//! ║ ║┗━┷━┷━━━┷━┛║┗━┷↥┷━━━┷━┛║┗━┷━┷━━━┷━┛║ ║ │ v(p₀) │ -//! ║ ║ ║│ gₑ(b₀) │║ ║ ║ ╰─────────────────────╯ -//! ║ ║ ║╰─────────╯║ ║ ║ -//! ╚═════╩═══════════╩═══════════╩═══════════╩═════╝ -//! ``` -//! -//! We call this input to the v function, a _buffer pointer_. The [`BufferPointer`] structure -//! defined in this module represents such a pointer in the code. -//! -//! When the table index is incremented, the buffer pointer is incremented alongside: -//! ```ascii -//! ╭───────────────────────────────────────────────╮ -//! │ e = fₖ(a) │ Buffer(length=128) -//! ╔═════╦═══════════╦═════↧═════╦═══════════╦═════╗ ┏━┯━┯━┯━┯━┯━┯━┯━┯━━━┯━┓ -//! ║ ... ║┏━┯━┯━━━┯━┓║┏━┯━┯━━━┯━┓║┏━┯━┯━━━┯━┓║ ... ║ ┃▓│▓│▓│▓│▓│▓│▓│▓│...│▓┃ -//! ║ ║┃ │ │...│ ┃║┃▓│▓│...│▓┃║┃▓│▓│...│▓┃║ ║ ┗━┷━┷↥┷━┷━┷━┷━┷━┷━━━┷━┛ -//! ║ ║┗━┷━┷━━━┷━┛║┗━┷━┷↥━━┷━┛║┗━┷━┷━━━┷━┛║ ║ │ v(p) │ -//! ║ ║ ║│ gₑ(b) │║ ║ ║ ╰─────────────────────╯ -//! ║ ║ ║╰─────────╯║ ║ ║ -//! ╚═════╩═══════════╩═══════════╩═══════════╩═════╝ -//! ``` -//! -//! When the buffer pointer is incremented it is checked against the size of the buffer, and if -//! necessary, a new batch of aes index values. -use std::cmp::Ordering; - -/// A structure representing an [aes index](#coarse-grained-pseudo-random-table-lookup). -#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] -pub struct AesIndex(pub u128); - -/// A structuure representing a [byte index](#fine-grained-pseudo-random-table-lookup). -#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] -pub struct ByteIndex(pub usize); - -/// A structure representing the number of bytes between two table indices. -#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] -pub struct ByteCount(pub u128); - -/// A structure representing a [table index](#fine-grained-pseudo-random-table-lookup) -#[derive(Clone, Copy, Debug)] -pub struct TableIndex { - pub(crate) aes_index: AesIndex, - pub(crate) byte_index: ByteIndex, -} - -impl TableIndex { - /// The first table index. - pub const FIRST: TableIndex = TableIndex { - aes_index: AesIndex(0), - byte_index: ByteIndex(0), - }; - - /// The second table index. - pub const SECOND: TableIndex = TableIndex { - aes_index: AesIndex(0), - byte_index: ByteIndex(1), - }; - - /// The last table index. 
- pub const LAST: TableIndex = TableIndex { - aes_index: AesIndex(u128::MAX), - byte_index: ByteIndex(15), - }; - - /// Creates a table index from an aes index and a byte index. - pub fn new(aes_index: AesIndex, byte_index: ByteIndex) -> Self { - assert!(byte_index.0 <= 15); - TableIndex { - aes_index, - byte_index, - } - } - - /// Shifts the table index forward of `shift` bytes. - pub fn increase(&mut self, shift: usize) { - let total = self.byte_index.0 + shift; - self.byte_index.0 = total % 16; - self.aes_index.0 = self.aes_index.0.wrapping_add(total as u128 / 16); - } - - /// Shifts the table index backward of `shift` bytes. - pub fn decrease(&mut self, shift: usize) { - let remainder = shift % 16; - if remainder <= self.byte_index.0 { - self.aes_index.0 = self.aes_index.0.wrapping_sub((shift / 16) as u128); - self.byte_index.0 -= remainder; - } else { - self.aes_index.0 = self.aes_index.0.wrapping_sub((shift / 16) as u128 + 1); - self.byte_index.0 += 16 - remainder; - } - } - - /// Shifts the table index forward of one byte. - pub fn increment(&mut self) { - self.increase(1) - } - - /// Shifts the table index backward of one byte. - pub fn decrement(&mut self) { - self.decrease(1) - } - - /// Returns the table index shifted forward by `shift` bytes. - pub fn increased(mut self, shift: usize) -> Self { - self.increase(shift); - self - } - - /// Returns the table index shifted backward by `shift` bytes. - pub fn decreased(mut self, shift: usize) -> Self { - self.decrease(shift); - self - } - - /// Returns the table index to the next byte. - pub fn incremented(mut self) -> Self { - self.increment(); - self - } - - /// Returns the table index to the previous byte. - pub fn decremented(mut self) -> Self { - self.decrement(); - self - } - - /// Returns the distance between two table indices in bytes. - /// - /// Note: - /// ----- - /// - /// This method assumes that the `larger` input is, well, larger than the `smaller` input. If - /// this is not the case, the method returns `None`. Also, note that `ByteCount` uses the - /// `u128` datatype to store the byte count. Unfortunately, the number of bytes between two - /// table indices is in ⟦0;2¹³² -1⟧. When the distance is greater than 2¹²⁸ - 1, we saturate - /// the count at 2¹²⁸ - 1. - pub fn distance(larger: &Self, smaller: &Self) -> Option { - match std::cmp::Ord::cmp(larger, smaller) { - Ordering::Less => None, - Ordering::Equal => Some(ByteCount(0)), - Ordering::Greater => { - let mut result = larger.aes_index.0 - smaller.aes_index.0; - result = result.saturating_mul(16); - result = result.saturating_add(larger.byte_index.0 as u128); - result = result.saturating_sub(smaller.byte_index.0 as u128); - Some(ByteCount(result)) - } - } - } -} - -impl Eq for TableIndex {} - -impl PartialEq for TableIndex { - fn eq(&self, other: &Self) -> bool { - matches!(self.partial_cmp(other), Some(Ordering::Equal)) - } -} - -impl PartialOrd for TableIndex { - fn partial_cmp(&self, other: &Self) -> Option { - match self.aes_index.partial_cmp(&other.aes_index) { - Some(Ordering::Equal) => self.byte_index.partial_cmp(&other.byte_index), - other => other, - } - } -} - -impl Ord for TableIndex { - fn cmp(&self, other: &Self) -> Ordering { - self.partial_cmp(other).unwrap() - } -} - -/// A pointer to the next byte to be outputted by the generator. -#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] -pub struct BufferPointer(pub usize); - -/// A structure representing the current state of generator using batched aes-ctr approach. 
-#[derive(Debug, Clone, Copy)] -pub struct State { - table_index: TableIndex, - buffer_pointer: BufferPointer, -} - -/// A structure representing the action to be taken by the generator after shifting its state. -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum ShiftAction { - /// Yield the byte pointed to by the 0-th field. - YieldByte(BufferPointer), - /// Refresh the buffer starting from the 0-th field, and yield the byte pointed to by the 0-th - /// field. - RefreshBatchAndYieldByte(AesIndex, BufferPointer), -} - -impl State { - /// Creates a new state from the initial table index. - pub fn new(table_index: TableIndex) -> Self { - State { - table_index: table_index.decremented(), - buffer_pointer: BufferPointer(127), - } - } - - /// Shifts the state forward of `shift` bytes. - pub fn increase(&mut self, shift: usize) -> ShiftAction { - self.table_index.increase(shift); - let total_batch_index = self.buffer_pointer.0 + shift; - if total_batch_index > 127 { - self.buffer_pointer.0 = self.table_index.byte_index.0; - ShiftAction::RefreshBatchAndYieldByte(self.table_index.aes_index, self.buffer_pointer) - } else { - self.buffer_pointer.0 = total_batch_index; - ShiftAction::YieldByte(self.buffer_pointer) - } - } - - /// Shifts the state forward of one byte. - pub fn increment(&mut self) -> ShiftAction { - self.increase(1) - } - - /// Returns the current table index. - pub fn table_index(&self) -> TableIndex { - self.table_index - } -} - -impl Default for State { - fn default() -> Self { - State::new(TableIndex::FIRST) - } -} - -#[cfg(test)] -mod test { - use super::*; - use rand::{thread_rng, Rng}; - - const REPEATS: usize = 1_000_000; - - fn any_table_index() -> impl Iterator { - std::iter::repeat_with(|| { - TableIndex::new( - AesIndex(thread_rng().gen()), - ByteIndex(thread_rng().gen::() % 16), - ) - }) - } - - fn any_usize() -> impl Iterator { - std::iter::repeat_with(|| thread_rng().gen()) - } - - #[test] - #[should_panic] - /// Verifies that the constructor of `TableIndex` panics when the byte index is too large. - fn test_table_index_new_panic() { - TableIndex::new(AesIndex(12), ByteIndex(144)); - } - - #[test] - /// Verifies that the `TableIndex` wraps nicely with predecessor - fn test_table_index_predecessor_edge() { - assert_eq!(TableIndex::FIRST.decremented(), TableIndex::LAST); - } - - #[test] - /// Verifies that the `TableIndex` wraps nicely with successor - fn test_table_index_successor_edge() { - assert_eq!(TableIndex::LAST.incremented(), TableIndex::FIRST); - } - - #[test] - /// Check that the table index distance saturates nicely. - fn prop_table_index_distance_saturates() { - assert_eq!( - TableIndex::distance(&TableIndex::LAST, &TableIndex::FIRST) - .unwrap() - .0, - u128::MAX - ) - } - - #[test] - /// Check the property: - /// For all table indices t, - /// distance(t, t) = Some(0). - fn prop_table_index_distance_zero() { - for _ in 0..REPEATS { - let t = any_table_index().next().unwrap(); - assert_eq!(TableIndex::distance(&t, &t), Some(ByteCount(0))); - } - } - - #[test] - /// Check the property: - /// For all table indices t1, t2 such that t1 < t2, - /// distance(t1, t2) = None. - fn prop_table_index_distance_wrong_order_none() { - for _ in 0..REPEATS { - let (t1, t2) = any_table_index() - .zip(any_table_index()) - .find(|(t1, t2)| t1 < t2) - .unwrap(); - assert_eq!(TableIndex::distance(&t1, &t2), None); - } - } - - #[test] - /// Check the property: - /// For all table indices t1, t2 such that t1 > t2, - /// distance(t1, t2) = Some(v) where v is strictly positive. 
- fn prop_table_index_distance_some_positive() { - for _ in 0..REPEATS { - let (t1, t2) = any_table_index() - .zip(any_table_index()) - .find(|(t1, t2)| t1 > t2) - .unwrap(); - assert!(matches!(TableIndex::distance(&t1, &t2), Some(ByteCount(v)) if v > 0)); - } - } - - #[test] - /// Check the property: - /// For all table indices t, positive i such that i < distance (MAX, t) with MAX the largest - /// table index, - /// distance(t.increased(i), t) = Some(i). - fn prop_table_index_distance_increase() { - for _ in 0..REPEATS { - let (t, inc) = any_table_index() - .zip(any_usize()) - .find(|(t, inc)| { - (*inc as u128) < TableIndex::distance(&TableIndex::LAST, t).unwrap().0 - }) - .unwrap(); - assert_eq!( - TableIndex::distance(&t.increased(inc), &t).unwrap().0 as usize, - inc - ); - } - } - - #[test] - /// Check the property: - /// For all table indices t, t =? t = true. - fn prop_table_index_equality() { - for _ in 0..REPEATS { - let t = any_table_index().next().unwrap(); - assert_eq!( - std::cmp::PartialOrd::partial_cmp(&t, &t), - Some(std::cmp::Ordering::Equal) - ); - } - } - - #[test] - /// Check the property: - /// For all table indices t, positive i such that i < distance (MAX, t) with MAX the largest - /// table index, - /// t.increased(i) >? t = true. - fn prop_table_index_greater() { - for _ in 0..REPEATS { - let (t, inc) = any_table_index() - .zip(any_usize()) - .find(|(t, inc)| { - (*inc as u128) < TableIndex::distance(&TableIndex::LAST, t).unwrap().0 - }) - .unwrap(); - assert_eq!( - std::cmp::PartialOrd::partial_cmp(&t.increased(inc), &t), - Some(std::cmp::Ordering::Greater), - ); - } - } - - #[test] - /// Check the property: - /// For all table indices t, positive i such that i < distance (t, 0) with MAX the largest - /// table index, - /// t.decreased(i) = 127, - /// if s = State::new(t), and s.increment() was executed, then - /// s.increase(i) = RefreshBatchAndYield( - /// t.increased(i).aes_index, - /// t.increased(i).byte_index). 
- fn prop_state_increase_large() { - for _ in 0..REPEATS { - let (t, mut s, i) = any_table_index() - .zip(any_usize()) - .map(|(t, i)| (t, State::new(t), i)) - .find(|(t, _, i)| t.byte_index.0 + i >= 127) - .unwrap(); - s.increment(); - assert!(matches!( - s.increase(i), - ShiftAction::RefreshBatchAndYieldByte(t_, BufferPointer(p_)) - if t_ == t.increased(i).aes_index && p_ == t.increased(i).byte_index.0 - )); - } - } -} diff --git a/concrete-csprng/src/counter/test.rs b/concrete-csprng/src/counter/test.rs deleted file mode 100644 index ff6ec8d86a..0000000000 --- a/concrete-csprng/src/counter/test.rs +++ /dev/null @@ -1,187 +0,0 @@ -use super::*; -use crate::AesKey; -use rand::{thread_rng, Rng}; - -const REPEATS: usize = 1_000_000; - -fn any_table_index() -> impl Iterator { - std::iter::repeat_with(|| { - TableIndex::new( - AesIndex(thread_rng().gen()), - ByteIndex(thread_rng().gen::() % 16), - ) - }) -} - -fn any_usize() -> impl Iterator { - std::iter::repeat_with(|| thread_rng().gen()) -} - -fn any_children_count() -> impl Iterator { - std::iter::repeat_with(|| ChildrenCount(thread_rng().gen::() % 2048 + 1)) -} - -fn any_bytes_per_child() -> impl Iterator { - std::iter::repeat_with(|| BytesPerChild(thread_rng().gen::() % 2048 + 1)) -} - -fn any_key() -> impl Iterator { - std::iter::repeat_with(|| AesKey(thread_rng().gen())) -} - -#[test] -fn prop_fork_first_state_table_index() { - for _ in 0..REPEATS { - let ((((t, nc), nb), k), i) = any_table_index() - .zip(any_children_count()) - .zip(any_bytes_per_child()) - .zip(any_key()) - .zip(any_usize()) - .find(|((((t, nc), nb), _k), i)| { - TableIndex::distance(&TableIndex::LAST, t).unwrap().0 > (nc.0 * nb.0 + i) as u128 - }) - .unwrap(); - let original_generator = - SoftAesCtrGenerator::new(Some(k), Some(t), Some(t.increased(nc.0 * nb.0 + i))); - let mut forked_generator = original_generator.clone(); - let first_child = forked_generator.try_fork(nc, nb).unwrap().next().unwrap(); - assert_eq!( - original_generator.last_table_index(), - first_child.last_table_index() - ); - } -} - -#[test] -fn prop_fork_last_bound_table_index() { - for _ in 0..REPEATS { - let ((((t, nc), nb), k), i) = any_table_index() - .zip(any_children_count()) - .zip(any_bytes_per_child()) - .zip(any_key()) - .zip(any_usize()) - .find(|((((t, nc), nb), _k), i)| { - TableIndex::distance(&TableIndex::LAST, t).unwrap().0 > (nc.0 * nb.0 + i) as u128 - }) - .unwrap(); - let mut parent_generator = - SoftAesCtrGenerator::new(Some(k), Some(t), Some(t.increased(nc.0 * nb.0 + i))); - let last_child = parent_generator.try_fork(nc, nb).unwrap().last().unwrap(); - assert_eq!( - parent_generator.last_table_index().incremented(), - last_child.get_bound() - ); - } -} - -#[test] -fn prop_fork_parent_bound_table_index() { - for _ in 0..REPEATS { - let ((((t, nc), nb), k), i) = any_table_index() - .zip(any_children_count()) - .zip(any_bytes_per_child()) - .zip(any_key()) - .zip(any_usize()) - .find(|((((t, nc), nb), _k), i)| { - TableIndex::distance(&TableIndex::LAST, t).unwrap().0 > (nc.0 * nb.0 + i) as u128 - }) - .unwrap(); - let original_generator = - SoftAesCtrGenerator::new(Some(k), Some(t), Some(t.increased(nc.0 * nb.0 + i))); - let mut forked_generator = original_generator.clone(); - forked_generator.try_fork(nc, nb).unwrap().last().unwrap(); - assert_eq!(original_generator.get_bound(), forked_generator.get_bound()); - } -} - -#[test] -fn prop_fork_parent_state_table_index() { - for _ in 0..REPEATS { - let ((((t, nc), nb), k), i) = any_table_index() - .zip(any_children_count()) - 
.zip(any_bytes_per_child()) - .zip(any_key()) - .zip(any_usize()) - .find(|((((t, nc), nb), _k), i)| { - TableIndex::distance(&TableIndex::LAST, t).unwrap().0 > (nc.0 * nb.0 + i) as u128 - }) - .unwrap(); - let original_generator = - SoftAesCtrGenerator::new(Some(k), Some(t), Some(t.increased(nc.0 * nb.0 + i))); - let mut forked_generator = original_generator.clone(); - forked_generator.try_fork(nc, nb).unwrap().last().unwrap(); - assert!(original_generator.last_table_index() < forked_generator.last_table_index()); - } -} - -#[test] -fn prop_fork() { - for _ in 0..REPEATS { - let ((((t, nc), nb), k), i) = any_table_index() - .zip(any_children_count()) - .zip(any_bytes_per_child()) - .zip(any_key()) - .zip(any_usize()) - .find(|((((t, nc), nb), _k), i)| { - TableIndex::distance(&TableIndex::LAST, t).unwrap().0 > (nc.0 * nb.0 + i) as u128 - }) - .unwrap(); - let bytes_to_go = nc.0 * nb.0; - let original_generator = - SoftAesCtrGenerator::new(Some(k), Some(t), Some(t.increased(nc.0 * nb.0 + i))); - let mut forked_generator = original_generator.clone(); - let initial_output: Vec = original_generator.take(bytes_to_go as usize).collect(); - let forked_output: Vec = forked_generator - .try_fork(nc, nb) - .unwrap() - .flat_map(|child| child.collect::>()) - .collect(); - assert_eq!(initial_output, forked_output); - } -} - -#[test] -fn prop_fork_children_remaining_bytes() { - for _ in 0..REPEATS { - let ((((t, nc), nb), k), i) = any_table_index() - .zip(any_children_count()) - .zip(any_bytes_per_child()) - .zip(any_key()) - .zip(any_usize()) - .find(|((((t, nc), nb), _k), i)| { - TableIndex::distance(&TableIndex::LAST, t).unwrap().0 > (nc.0 * nb.0 + i) as u128 - }) - .unwrap(); - let mut generator = - SoftAesCtrGenerator::new(Some(k), Some(t), Some(t.increased(nc.0 * nb.0 + i))); - assert!(generator - .try_fork(nc, nb) - .unwrap() - .all(|c| c.remaining_bytes().0 == nb.0 as u128)); - } -} - -#[test] -fn prop_fork_parent_remaining_bytes() { - for _ in 0..REPEATS { - let ((((t, nc), nb), k), i) = any_table_index() - .zip(any_children_count()) - .zip(any_bytes_per_child()) - .zip(any_key()) - .zip(any_usize()) - .find(|((((t, nc), nb), _k), i)| { - TableIndex::distance(&TableIndex::LAST, t).unwrap().0 > (nc.0 * nb.0 + i) as u128 - }) - .unwrap(); - let bytes_to_go = nc.0 * nb.0; - let mut generator = - SoftAesCtrGenerator::new(Some(k), Some(t), Some(t.increased(nc.0 * nb.0 + i))); - let before_remaining_bytes = generator.remaining_bytes(); - let _ = generator.try_fork(nc, nb).unwrap(); - let after_remaining_bytes = generator.remaining_bytes(); - assert_eq!( - before_remaining_bytes.0 - after_remaining_bytes.0, - bytes_to_go as u128 - ); - } -} diff --git a/concrete-csprng/src/counter/test_aes.rs b/concrete-csprng/src/counter/test_aes.rs deleted file mode 100644 index 7eec1a378b..0000000000 --- a/concrete-csprng/src/counter/test_aes.rs +++ /dev/null @@ -1,20 +0,0 @@ -use super::*; -use crate::counter::AesCtr; - -#[test] -fn test_soft_hard_eq() { - // Checks that both the software and hardware prng outputs the same values. 
- let mut soft = SoftAesCtrGenerator::new( - Some(AesKey(0)), - Some(State::from_aes_counter(AesCtr(0))), - None, - ); - let mut hard = HardAesCtrGenerator::new( - Some(AesKey(0)), - Some(State::from_aes_counter(AesCtr(0))), - None, - ); - for _ in 0..1000 { - assert_eq!(soft.generate_next(), hard.generate_next()); - } -} diff --git a/concrete-csprng/src/generators/aes_ctr/block_cipher.rs b/concrete-csprng/src/generators/aes_ctr/block_cipher.rs new file mode 100644 index 0000000000..76c651a2da --- /dev/null +++ b/concrete-csprng/src/generators/aes_ctr/block_cipher.rs @@ -0,0 +1,20 @@ +use crate::generators::aes_ctr::index::AesIndex; +use crate::generators::aes_ctr::BYTES_PER_BATCH; + +/// Represents a key used in the AES ciphertext. +#[derive(Clone, Copy)] +pub struct AesKey(pub u128); + +/// A trait for AES block ciphers. +/// +/// Note: +/// ----- +/// +/// The block cipher are used in a batched manner (to reduce amortized cost on special hardware). +/// For this reason we only expose a `generate_batch` method. +pub trait AesBlockCipher: Clone + Send + Sync { + /// Instantiate a new generator from a secret key. + fn new(key: AesKey) -> Self; + /// Generates the batch corresponding to the given index. + fn generate_batch(&mut self, index: AesIndex) -> [u8; BYTES_PER_BATCH]; +} diff --git a/concrete-csprng/src/generators/aes_ctr/generic.rs b/concrete-csprng/src/generators/aes_ctr/generic.rs new file mode 100644 index 0000000000..5b1528ad04 --- /dev/null +++ b/concrete-csprng/src/generators/aes_ctr/generic.rs @@ -0,0 +1,377 @@ +use crate::generators::aes_ctr::block_cipher::{AesBlockCipher, AesKey}; +use crate::generators::aes_ctr::index::TableIndex; +use crate::generators::aes_ctr::states::{BufferPointer, ShiftAction, State}; +use crate::generators::aes_ctr::BYTES_PER_BATCH; +use crate::generators::{ByteCount, BytesPerChild, ChildrenCount, ForkError}; + +// Usually, to work with iterators and parallel iterators, we would use opaque types such as +// `impl Iterator<..>`. Unfortunately, it is not yet possible to return existential types in +// traits, which we would need for `RandomGenerator`. For this reason, we have to use the +// full type name where needed. Hence the following trait aliases definition: + +/// A type alias for the children iterator closure type. +pub type ChildrenClosure = + fn((usize, (BlockCipher, TableIndex, BytesPerChild))) -> AesCtrGenerator; + +/// A type alias for the children iterator type. +pub type ChildrenIterator = std::iter::Map< + std::iter::Zip< + std::ops::Range, + std::iter::Repeat<(BlockCipher, TableIndex, BytesPerChild)>, + >, + ChildrenClosure, +>; + +/// A type implementing the `RandomGenerator` api using the AES block cipher in counter mode. +#[derive(Clone)] +pub struct AesCtrGenerator { + // The block cipher used in the background + pub(crate) block_cipher: BlockCipher, + // The state corresponding to the latest yielded byte. + pub(crate) state: State, + // The bound, that is the first illegal index. + pub(crate) bound: TableIndex, + // The last legal index. This makes bound check faster. + pub(crate) last: TableIndex, + // The buffer containing the current batch of aes calls. + pub(crate) buffer: [u8; BYTES_PER_BATCH], +} + +#[allow(unused)] // to please clippy when tests are not activated +impl AesCtrGenerator { + /// Generates a new csprng. + /// + /// Note : + /// ------ + /// + /// The `start_index` given as input, points to the first byte that will be outputted by the + /// generator. 
If not given, this one is automatically set to the second table index (the + /// first table index is not used to prevent an edge case from happening). + /// The `bound_index` given as input, points to the first byte that can __not__ be legally + /// outputted by the generator. If not give, the bound is automatically set to the last + /// table index. + pub fn new( + key: AesKey, + start_index: Option, + bound_index: Option, + ) -> AesCtrGenerator { + AesCtrGenerator::from_block_cipher( + BlockCipher::new(key), + start_index.unwrap_or(TableIndex::SECOND), + bound_index.unwrap_or(TableIndex::LAST), + ) + } + + /// Generates a csprng from an existing block cipher. + pub fn from_block_cipher( + block_cipher: BlockCipher, + start_index: TableIndex, + bound_index: TableIndex, + ) -> AesCtrGenerator { + assert!(start_index < bound_index); + let last = bound_index.decremented(); + let buffer = [0u8; BYTES_PER_BATCH]; + let state = State::new(start_index); + AesCtrGenerator { + block_cipher, + state, + bound: bound_index, + last, + buffer, + } + } + + /// Returns the table index related to the previous byte. + pub fn table_index(&self) -> TableIndex { + self.state.table_index() + } + + /// Returns the bound of the generator if any. + /// + /// The bound is the table index of the first byte that can not be outputted by the generator. + pub fn get_bound(&self) -> TableIndex { + self.bound + } + + /// Returns whether the generator is bounded or not. + pub fn is_bounded(&self) -> bool { + self.bound != TableIndex::LAST + } + + /// Computes the number of bytes that can still be outputted by the generator. + /// + /// Note : + /// ------ + /// + /// Note that `ByteCount` uses the `u128` datatype to store the byte count. Unfortunately, the + /// number of remaining bytes is in ⟦0;2¹³² -1⟧. When the number is greater than 2¹²⁸ - 1, + /// we saturate the count at 2¹²⁸ - 1. + pub fn remaining_bytes(&self) -> ByteCount { + TableIndex::distance(&self.last, &self.state.table_index()).unwrap() + } + + /// Yields the next random byte. + pub fn generate_next(&mut self) -> u8 { + self.next() + .expect("Tried to generate a byte after the bound.") + } + + /// Tries to fork the current generator into `n_child` generators each able to yield + /// `child_bytes` random bytes. + pub fn try_fork( + &mut self, + n_children: ChildrenCount, + n_bytes: BytesPerChild, + ) -> Result, ForkError> { + if n_children.0 == 0 { + return Err(ForkError::ZeroChildrenCount); + } + if n_bytes.0 == 0 { + return Err(ForkError::ZeroBytesPerChild); + } + if !self.is_fork_in_bound(n_children, n_bytes) { + return Err(ForkError::ForkTooLarge); + } + + // The state currently stored in the parent generator points to the table index of the last + // generated byte. The first index to be generated is the next one : + let first_index = self.state.table_index().incremented(); + let output = (0..n_children.0) + .zip(std::iter::repeat(( + self.block_cipher.clone(), + first_index, + n_bytes, + ))) + .map( + // This map is a little weird because we need to cast the closure to a fn pointer + // that matches the signature of `ChildrenIterator`. + // Unfortunately, the compiler does not manage to coerce this one + // automatically. + (|(i, (block_cipher, first_index, n_bytes))| { + // The first index to be outputted by the child is the `first_index` shifted by + // the proper amount of `child_bytes`. + let child_first_index = first_index.increased(n_bytes.0 * i); + // The bound of the child is the first index of its next sibling. 
+ let child_bound_index = first_index.increased(n_bytes.0 * (i + 1)); + AesCtrGenerator::from_block_cipher( + block_cipher, + child_first_index, + child_bound_index, + ) + }) as ChildrenClosure, + ); + // The parent next index is the bound of the last child. + let next_index = first_index.increased(n_bytes.0 * n_children.0); + self.state = State::new(next_index); + + Ok(output) + } + + pub(crate) fn is_fork_in_bound( + &self, + n_child: ChildrenCount, + child_bytes: BytesPerChild, + ) -> bool { + let mut end = self.state.table_index(); + end.increase(n_child.0 * child_bytes.0); + end < self.bound + } +} + +impl Iterator for AesCtrGenerator { + type Item = u8; + + fn next(&mut self) -> Option { + if self.state.table_index() >= self.last { + None + } else { + match self.state.increment() { + ShiftAction::YieldByte(BufferPointer(ptr)) => Some(self.buffer[ptr]), + ShiftAction::RefreshBatchAndYieldByte(aes_index, BufferPointer(ptr)) => { + self.buffer = self.block_cipher.generate_batch(aes_index); + Some(self.buffer[ptr]) + } + } + } + } +} + +#[cfg(test)] +pub mod aes_ctr_generic_test { + #![allow(unused)] // to please clippy when tests are not activated + + use super::*; + use crate::generators::aes_ctr::index::{AesIndex, ByteIndex}; + use crate::generators::aes_ctr::BYTES_PER_AES_CALL; + use rand::{thread_rng, Rng}; + + const REPEATS: usize = 1_000_000; + + pub fn any_table_index() -> impl Iterator { + std::iter::repeat_with(|| { + TableIndex::new( + AesIndex(thread_rng().gen()), + ByteIndex(thread_rng().gen::() % BYTES_PER_AES_CALL), + ) + }) + } + + pub fn any_usize() -> impl Iterator { + std::iter::repeat_with(|| thread_rng().gen()) + } + + pub fn any_children_count() -> impl Iterator { + std::iter::repeat_with(|| ChildrenCount(thread_rng().gen::() % 2048 + 1)) + } + + pub fn any_bytes_per_child() -> impl Iterator { + std::iter::repeat_with(|| BytesPerChild(thread_rng().gen::() % 2048 + 1)) + } + + pub fn any_key() -> impl Iterator { + std::iter::repeat_with(|| AesKey(thread_rng().gen())) + } + + /// Yields a valid fork: + /// a table index t, + /// a number of children nc, + /// a number of bytes per children nb + /// and a positive integer i such that: + /// increase(t, nc*nb+i) < MAX with MAX the largest table index. + /// Put differently, if we initialize a parent generator at t and fork it with (nc, nb), our + /// parent generator current index gets shifted to an index, distant of at least i bytes of + /// the max index. + pub fn any_valid_fork( + ) -> impl Iterator { + any_table_index() + .zip(any_children_count()) + .zip(any_bytes_per_child()) + .zip(any_usize()) + .map(|(((t, nc), nb), i)| (t, nc, nb, i)) + .filter(|(t, nc, nb, i)| { + TableIndex::distance(&TableIndex::LAST, t).unwrap().0 > (nc.0 * nb.0 + i) as u128 + }) + } + + /// Check the property: + /// On a valid fork, the table index of the first child is the same as the table index of + /// the parent before the fork. 
+ pub fn prop_fork_first_state_table_index() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let original_generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let mut forked_generator = original_generator.clone(); + let first_child = forked_generator.try_fork(nc, nb).unwrap().next().unwrap(); + assert_eq!(original_generator.table_index(), first_child.table_index()); + } + } + + /// Check the property: + /// On a valid fork, the table index of the first byte yielded by the parent after the fork, + /// is the bound of the last child of the fork. + pub fn prop_fork_last_bound_table_index() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let mut parent_generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let last_child = parent_generator.try_fork(nc, nb).unwrap().last().unwrap(); + assert_eq!( + parent_generator.table_index().incremented(), + last_child.get_bound() + ); + } + } + + /// Check the property: + /// On a valid fork, the bound of the parent does not change. + pub fn prop_fork_parent_bound_table_index() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let original_generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let mut forked_generator = original_generator.clone(); + forked_generator.try_fork(nc, nb).unwrap().last().unwrap(); + assert_eq!(original_generator.get_bound(), forked_generator.get_bound()); + } + } + + /// Check the property: + /// On a valid fork, the parent table index is increased of the number of children + /// multiplied by the number of bytes per child. + pub fn prop_fork_parent_state_table_index() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let original_generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let mut forked_generator = original_generator.clone(); + forked_generator.try_fork(nc, nb).unwrap().last().unwrap(); + assert_eq!( + forked_generator.table_index(), + // Decrement accounts for the fact that the table index stored is the previous one + t.increased(nc.0 * nb.0).decremented() + ); + } + } + + /// Check the property: + /// On a valid fork, the bytes yielded by the children in the fork order form the same + /// sequence the parent would have had yielded no fork had happened. + pub fn prop_fork() { + for _ in 0..1000 { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let bytes_to_go = nc.0 * nb.0; + let original_generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let mut forked_generator = original_generator.clone(); + let initial_output: Vec = original_generator.take(bytes_to_go as usize).collect(); + let forked_output: Vec = forked_generator + .try_fork(nc, nb) + .unwrap() + .flat_map(|child| child.collect::>()) + .collect(); + assert_eq!(initial_output, forked_output); + } + } + + /// Check the property: + /// On a valid fork, all children got a number of remaining bytes equals to the number of + /// bytes per child given as fork input. 
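
A minimal usage sketch of the fork behaviour these properties describe, seen from the public API exercised in the updated `benches/benchmark.rs`. It assumes an x86_64 CPU with the `seeder_x86_64_rdseed` and `generator_x86_64_aesni` features enabled; the 4-children / 16-bytes split is an arbitrary illustration.

```rust
use concrete_csprng::generators::{
    AesniRandomGenerator, BytesPerChild, ChildrenCount, RandomGenerator,
};
use concrete_csprng::seeders::{RdseedSeeder, Seeder};

fn main() {
    // Seed a parent generator from the rdseed-based hardware seeder.
    let mut seeder = RdseedSeeder;
    let mut parent = AesniRandomGenerator::new(seeder.seed());

    // Fork the parent into 4 children of 16 bytes each. Read in fork order,
    // the children reproduce the 64 bytes the parent would have produced
    // next, and the parent resumes right after that range.
    let children = parent
        .try_fork(ChildrenCount(4), BytesPerChild(16))
        .unwrap();
    for mut child in children {
        for _ in 0..16 {
            let _byte = child.next();
        }
    }
}
```
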
+ pub fn prop_fork_children_remaining_bytes() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let mut generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + assert!(generator + .try_fork(nc, nb) + .unwrap() + .all(|c| c.remaining_bytes().0 == nb.0 as u128)); + } + } + + /// Check the property: + /// On a valid fork, the number of remaining bybtes of the parent is reduced by the number + /// of children multiplied by the number of bytes per child. + pub fn prop_fork_parent_remaining_bytes() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let bytes_to_go = nc.0 * nb.0; + let mut generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let before_remaining_bytes = generator.remaining_bytes(); + let _ = generator.try_fork(nc, nb).unwrap(); + let after_remaining_bytes = generator.remaining_bytes(); + assert_eq!( + before_remaining_bytes.0 - after_remaining_bytes.0, + bytes_to_go as u128 + ); + } + } +} diff --git a/concrete-csprng/src/generators/aes_ctr/index.rs b/concrete-csprng/src/generators/aes_ctr/index.rs new file mode 100644 index 0000000000..e8a47af1cf --- /dev/null +++ b/concrete-csprng/src/generators/aes_ctr/index.rs @@ -0,0 +1,365 @@ +use crate::generators::aes_ctr::BYTES_PER_AES_CALL; +use crate::generators::ByteCount; +use std::cmp::Ordering; + +/// A structure representing an [aes index](#coarse-grained-pseudo-random-table-lookup). +#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] +pub struct AesIndex(pub u128); + +/// A structure representing a [byte index](#fine-grained-pseudo-random-table-lookup). +#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] +pub struct ByteIndex(pub usize); + +/// A structure representing a [table index](#fine-grained-pseudo-random-table-lookup) +#[derive(Clone, Copy, Debug)] +pub struct TableIndex { + pub(crate) aes_index: AesIndex, + pub(crate) byte_index: ByteIndex, +} + +impl TableIndex { + /// The first table index. + pub const FIRST: TableIndex = TableIndex { + aes_index: AesIndex(0), + byte_index: ByteIndex(0), + }; + + /// The second table index. + pub const SECOND: TableIndex = TableIndex { + aes_index: AesIndex(0), + byte_index: ByteIndex(1), + }; + + /// The last table index. + pub const LAST: TableIndex = TableIndex { + aes_index: AesIndex(u128::MAX), + byte_index: ByteIndex(BYTES_PER_AES_CALL - 1), + }; + + /// Creates a table index from an aes index and a byte index. + #[allow(unused)] // to please clippy when tests are not activated + pub fn new(aes_index: AesIndex, byte_index: ByteIndex) -> Self { + assert!(byte_index.0 < BYTES_PER_AES_CALL); + TableIndex { + aes_index, + byte_index, + } + } + + /// Shifts the table index forward of `shift` bytes. + pub fn increase(&mut self, shift: usize) { + let total = self.byte_index.0 + shift; + self.byte_index.0 = total % BYTES_PER_AES_CALL; + self.aes_index.0 = self + .aes_index + .0 + .wrapping_add(total as u128 / BYTES_PER_AES_CALL as u128); + } + + /// Shifts the table index backward of `shift` bytes. 
+ pub fn decrease(&mut self, shift: usize) { + let remainder = shift % BYTES_PER_AES_CALL; + if remainder <= self.byte_index.0 { + self.aes_index.0 = self + .aes_index + .0 + .wrapping_sub((shift / BYTES_PER_AES_CALL) as u128); + self.byte_index.0 -= remainder; + } else { + self.aes_index.0 = self + .aes_index + .0 + .wrapping_sub((shift / BYTES_PER_AES_CALL) as u128 + 1); + self.byte_index.0 += BYTES_PER_AES_CALL - remainder; + } + } + + /// Shifts the table index forward of one byte. + pub fn increment(&mut self) { + self.increase(1) + } + + /// Shifts the table index backward of one byte. + pub fn decrement(&mut self) { + self.decrease(1) + } + + /// Returns the table index shifted forward by `shift` bytes. + pub fn increased(mut self, shift: usize) -> Self { + self.increase(shift); + self + } + + /// Returns the table index shifted backward by `shift` bytes. + #[allow(unused)] // to please clippy when tests are not activated + pub fn decreased(mut self, shift: usize) -> Self { + self.decrease(shift); + self + } + + /// Returns the table index to the next byte. + pub fn incremented(mut self) -> Self { + self.increment(); + self + } + + /// Returns the table index to the previous byte. + pub fn decremented(mut self) -> Self { + self.decrement(); + self + } + + /// Returns the distance between two table indices in bytes. + /// + /// Note: + /// ----- + /// + /// This method assumes that the `larger` input is, well, larger than the `smaller` input. If + /// this is not the case, the method returns `None`. Also, note that `ByteCount` uses the + /// `u128` datatype to store the byte count. Unfortunately, the number of bytes between two + /// table indices is in ⟦0;2¹³² -1⟧. When the distance is greater than 2¹²⁸ - 1, we saturate + /// the count at 2¹²⁸ - 1. + pub fn distance(larger: &Self, smaller: &Self) -> Option { + match std::cmp::Ord::cmp(larger, smaller) { + Ordering::Less => None, + Ordering::Equal => Some(ByteCount(0)), + Ordering::Greater => { + let mut result = larger.aes_index.0 - smaller.aes_index.0; + result = result.saturating_mul(BYTES_PER_AES_CALL as u128); + result = result.saturating_add(larger.byte_index.0 as u128); + result = result.saturating_sub(smaller.byte_index.0 as u128); + Some(ByteCount(result)) + } + } + } +} + +impl Eq for TableIndex {} + +impl PartialEq for TableIndex { + fn eq(&self, other: &Self) -> bool { + matches!(self.partial_cmp(other), Some(Ordering::Equal)) + } +} + +impl PartialOrd for TableIndex { + fn partial_cmp(&self, other: &Self) -> Option { + match self.aes_index.partial_cmp(&other.aes_index) { + Some(Ordering::Equal) => self.byte_index.partial_cmp(&other.byte_index), + other => other, + } + } +} + +impl Ord for TableIndex { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap() + } +} + +#[cfg(test)] +mod test { + use super::*; + use rand::{thread_rng, Rng}; + + const REPEATS: usize = 1_000_000; + + fn any_table_index() -> impl Iterator { + std::iter::repeat_with(|| { + TableIndex::new( + AesIndex(thread_rng().gen()), + ByteIndex(thread_rng().gen::() % BYTES_PER_AES_CALL), + ) + }) + } + + fn any_usize() -> impl Iterator { + std::iter::repeat_with(|| thread_rng().gen()) + } + + #[test] + #[should_panic] + /// Verifies that the constructor of `TableIndex` panics when the byte index is too large. 
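
The `increase`/`decrease` methods above are plain carry arithmetic on the `(aes_index, byte_index)` pair. A self-contained sketch of the forward shift, using bare integers instead of the crate's `TableIndex` and writing the 16-byte AES block size out explicitly:

```rust
const BYTES_PER_AES_CALL: usize = 16;

/// Standalone mirror of `TableIndex::increase`: move an (aes_index, byte_index)
/// pair forward by `shift` bytes, carrying whole 16-byte blocks into the aes
/// index and keeping the byte index in 0..16.
fn increase(aes_index: u128, byte_index: usize, shift: usize) -> (u128, usize) {
    let total = byte_index + shift;
    (
        aes_index.wrapping_add((total / BYTES_PER_AES_CALL) as u128),
        total % BYTES_PER_AES_CALL,
    )
}

fn main() {
    // Byte 14 of block 3, shifted forward by 5 bytes, lands on byte 3 of
    // block 4: 14 + 5 = 19 = 1 * 16 + 3.
    assert_eq!(increase(3, 14, 5), (4, 3));
}
```
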
+ fn test_table_index_new_panic() { + TableIndex::new(AesIndex(12), ByteIndex(144)); + } + + #[test] + /// Verifies that the `TableIndex` wraps nicely with predecessor + fn test_table_index_predecessor_edge() { + assert_eq!(TableIndex::FIRST.decremented(), TableIndex::LAST); + } + + #[test] + /// Verifies that the `TableIndex` wraps nicely with successor + fn test_table_index_successor_edge() { + assert_eq!(TableIndex::LAST.incremented(), TableIndex::FIRST); + } + + #[test] + /// Check that the table index distance saturates nicely. + fn prop_table_index_distance_saturates() { + assert_eq!( + TableIndex::distance(&TableIndex::LAST, &TableIndex::FIRST) + .unwrap() + .0, + u128::MAX + ) + } + + #[test] + /// Check the property: + /// For all table indices t, + /// distance(t, t) = Some(0). + fn prop_table_index_distance_zero() { + for _ in 0..REPEATS { + let t = any_table_index().next().unwrap(); + assert_eq!(TableIndex::distance(&t, &t), Some(ByteCount(0))); + } + } + + #[test] + /// Check the property: + /// For all table indices t1, t2 such that t1 < t2, + /// distance(t1, t2) = None. + fn prop_table_index_distance_wrong_order_none() { + for _ in 0..REPEATS { + let (t1, t2) = any_table_index() + .zip(any_table_index()) + .find(|(t1, t2)| t1 < t2) + .unwrap(); + assert_eq!(TableIndex::distance(&t1, &t2), None); + } + } + + #[test] + /// Check the property: + /// For all table indices t1, t2 such that t1 > t2, + /// distance(t1, t2) = Some(v) where v is strictly positive. + fn prop_table_index_distance_some_positive() { + for _ in 0..REPEATS { + let (t1, t2) = any_table_index() + .zip(any_table_index()) + .find(|(t1, t2)| t1 > t2) + .unwrap(); + assert!(matches!(TableIndex::distance(&t1, &t2), Some(ByteCount(v)) if v > 0)); + } + } + + #[test] + /// Check the property: + /// For all table indices t, positive i such that i < distance (MAX, t) with MAX the largest + /// table index, + /// distance(t.increased(i), t) = Some(i). + fn prop_table_index_distance_increase() { + for _ in 0..REPEATS { + let (t, inc) = any_table_index() + .zip(any_usize()) + .find(|(t, inc)| { + (*inc as u128) < TableIndex::distance(&TableIndex::LAST, t).unwrap().0 + }) + .unwrap(); + assert_eq!( + TableIndex::distance(&t.increased(inc), &t).unwrap().0 as usize, + inc + ); + } + } + + #[test] + /// Check the property: + /// For all table indices t, t =? t = true. + fn prop_table_index_equality() { + for _ in 0..REPEATS { + let t = any_table_index().next().unwrap(); + assert_eq!( + std::cmp::PartialOrd::partial_cmp(&t, &t), + Some(std::cmp::Ordering::Equal) + ); + } + } + + #[test] + /// Check the property: + /// For all table indices t, positive i such that i < distance (MAX, t) with MAX the largest + /// table index, + /// t.increased(i) >? t = true. 
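For readers of this patch, here is a standalone sketch of the wrap-around arithmetic that the `increase`/`decrease` methods above describe, assuming `BYTES_PER_AES_CALL` is 16 (one AES block). The `ToyTableIndex` type below is an illustrative stand-in, not the crate's actual `TableIndex`.

```rust
// Illustrative stand-in for the table index arithmetic; assumes 16 bytes per AES call.
const BYTES_PER_AES_CALL: usize = 16;

#[derive(Debug, PartialEq)]
struct ToyTableIndex {
    aes_index: u128,   // which AES call (coarse-grained index)
    byte_index: usize, // which byte inside that call's output (fine-grained index)
}

impl ToyTableIndex {
    fn increase(&mut self, shift: usize) {
        // Same rule as above: overflowing bytes carry into the AES index.
        let total = self.byte_index + shift;
        self.byte_index = total % BYTES_PER_AES_CALL;
        self.aes_index = self
            .aes_index
            .wrapping_add((total / BYTES_PER_AES_CALL) as u128);
    }
}

fn main() {
    let mut idx = ToyTableIndex { aes_index: 0, byte_index: 15 };
    // Moving forward by one byte crosses into the next AES block.
    idx.increase(1);
    assert_eq!(idx, ToyTableIndex { aes_index: 1, byte_index: 0 });
    // Moving forward by 40 bytes advances 2 full blocks plus 8 bytes.
    idx.increase(40);
    assert_eq!(idx, ToyTableIndex { aes_index: 3, byte_index: 8 });
    println!("{:?}", idx);
}
```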
+ fn prop_table_index_greater() { + for _ in 0..REPEATS { + let (t, inc) = any_table_index() + .zip(any_usize()) + .find(|(t, inc)| { + (*inc as u128) < TableIndex::distance(&TableIndex::LAST, t).unwrap().0 + }) + .unwrap(); + assert_eq!( + std::cmp::PartialOrd::partial_cmp(&t.increased(inc), &t), + Some(std::cmp::Ordering::Greater), + ); + } + } + + #[test] + /// Check the property: + /// For all table indices t, positive i such that i < distance (t, 0) with MAX the largest + /// table index, + /// t.decreased(i) = rayon::iter::Map< + rayon::iter::Zip< + rayon::range::Iter, + rayon::iter::RepeatN<(BlockCipher, TableIndex, BytesPerChild)>, + >, + fn((usize, (BlockCipher, TableIndex, BytesPerChild))) -> AesCtrGenerator, +>; + +impl AesCtrGenerator { + /// Tries to fork the current generator into `n_child` generators each able to yield + /// `child_bytes` random bytes as a parallel iterator. + /// + /// # Notes + /// + /// This method necessitate the "multithread" feature. + pub fn par_try_fork( + &mut self, + n_children: ChildrenCount, + n_bytes: BytesPerChild, + ) -> Result, ForkError> + where + BlockCipher: Send + Sync, + { + use rayon::prelude::*; + + if n_children.0 == 0 { + return Err(ForkError::ZeroChildrenCount); + } + if n_bytes.0 == 0 { + return Err(ForkError::ZeroBytesPerChild); + } + if !self.is_fork_in_bound(n_children, n_bytes) { + return Err(ForkError::ForkTooLarge); + } + + // The state currently stored in the parent generator points to the table index of the last + // generated byte. The first index to be generated is the next one : + let first_index = self.state.table_index().incremented(); + let output = (0..n_children.0) + .into_par_iter() + .zip(rayon::iter::repeatn( + (self.block_cipher.clone(), first_index, n_bytes), + n_children.0, + )) + .map( + // This map is a little weird because we need to cast the closure to a fn pointer + // that matches the signature of `ChildrenIterator`. Unfortunately, + // the compiler does not manage to coerce this one automatically. + (|(i, (block_cipher, first_index, n_bytes))| { + // The first index to be outputted by the child is the `first_index` shifted by + // the proper amount of `child_bytes`. + let child_first_index = first_index.increased(n_bytes.0 * i); + // The bound of the child is the first index of its next sibling. + let child_bound_index = first_index.increased(n_bytes.0 * (i + 1)); + AesCtrGenerator::from_block_cipher( + block_cipher, + child_first_index, + child_bound_index, + ) + }) as ChildrenClosure, + ); + // The parent next index is the bound of the last child. + let next_index = first_index.increased(n_bytes.0 * n_children.0); + self.state = State::new(next_index); + + Ok(output) + } +} + +#[cfg(test)] +pub mod aes_ctr_parallel_generic_tests { + + use super::*; + use crate::generators::aes_ctr::aes_ctr_generic_test::{any_key, any_valid_fork}; + use rayon::prelude::*; + + const REPEATS: usize = 1_000_000; + + /// Check the property: + /// On a valid fork, the table index of the first child is the same as the table index of + /// the parent before the fork. 
+ pub fn prop_fork_first_state_table_index() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let original_generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let mut forked_generator = original_generator.clone(); + let first_child = forked_generator + .par_try_fork(nc, nb) + .unwrap() + .find_first(|_| true) + .unwrap(); + assert_eq!(original_generator.table_index(), first_child.table_index()); + } + } + + /// Check the property: + /// On a valid fork, the table index of the first byte yielded by the parent after the fork, + /// is the bound of the last child of the fork. + pub fn prop_fork_last_bound_table_index() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let mut parent_generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let last_child = parent_generator + .par_try_fork(nc, nb) + .unwrap() + .find_last(|_| true) + .unwrap(); + assert_eq!( + parent_generator.table_index().incremented(), + last_child.get_bound() + ); + } + } + + /// Check the property: + /// On a valid fork, the bound of the parent does not change. + pub fn prop_fork_parent_bound_table_index() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let original_generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let mut forked_generator = original_generator.clone(); + forked_generator + .par_try_fork(nc, nb) + .unwrap() + .find_last(|_| true) + .unwrap(); + assert_eq!(original_generator.get_bound(), forked_generator.get_bound()); + } + } + + /// Check the property: + /// On a valid fork, the parent table index is increased of the number of children + /// multiplied by the number of bytes per child. + pub fn prop_fork_parent_state_table_index() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let original_generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let mut forked_generator = original_generator.clone(); + forked_generator + .par_try_fork(nc, nb) + .unwrap() + .find_last(|_| true) + .unwrap(); + assert_eq!( + forked_generator.table_index(), + // Decrement accounts for the fact that the table index stored is the previous one + t.increased(nc.0 * nb.0).decremented() + ); + } + } + + /// Check the property: + /// On a valid fork, the bytes yielded by the children in the fork order form the same + /// sequence the parent would have had yielded no fork had happened. + pub fn prop_fork() { + for _ in 0..1000 { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let bytes_to_go = nc.0 * nb.0; + let original_generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let mut forked_generator = original_generator.clone(); + let initial_output: Vec = original_generator.take(bytes_to_go as usize).collect(); + let forked_output: Vec = forked_generator + .par_try_fork(nc, nb) + .unwrap() + .flat_map(|child| child.collect::>()) + .collect(); + assert_eq!(initial_output, forked_output); + } + } + + /// Check the property: + /// On a valid fork, all children got a number of remaining bytes equals to the number of + /// bytes per child given as fork input. 
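As a usage-level illustration of the `par_try_fork` entry point exercised by these property tests, the sketch below forks a generator and fills one buffer per child on the rayon thread pool. It assumes a build with the `parallel` and `generator_soft` features and uses the public `SoftwareRandomGenerator`/`ParallelRandomGenerator` API introduced later in this patch; the seed value is arbitrary.

```rust
// Sketch only: assumes the `parallel` and `generator_soft` features are enabled.
use concrete_csprng::generators::{
    BytesPerChild, ChildrenCount, ParallelRandomGenerator, RandomGenerator,
    SoftwareRandomGenerator,
};
use concrete_csprng::seeders::Seed;
use rayon::prelude::*;

fn main() {
    let mut parent = SoftwareRandomGenerator::new(Seed(0));
    // Fork into 4 children, each allowed to yield 1024 bytes, and let rayon
    // drive each child on its own task.
    let buffers: Vec<Vec<u8>> = parent
        .par_try_fork(ChildrenCount(4), BytesPerChild(1024))
        .unwrap()
        .map(|child| child.take(1024).collect())
        .collect();
    assert_eq!(buffers.len(), 4);
    assert!(buffers.iter().all(|b| b.len() == 1024));
}
```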
+ pub fn prop_fork_children_remaining_bytes() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let mut generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + assert!(generator + .par_try_fork(nc, nb) + .unwrap() + .all(|c| c.remaining_bytes().0 == nb.0 as u128)); + } + } + + /// Check the property: + /// On a valid fork, the number of remaining bybtes of the parent is reduced by the number + /// of children multiplied by the number of bytes per child. + pub fn prop_fork_parent_remaining_bytes() { + for _ in 0..REPEATS { + let (t, nc, nb, i) = any_valid_fork().next().unwrap(); + let k = any_key().next().unwrap(); + let bytes_to_go = nc.0 * nb.0; + let mut generator = + AesCtrGenerator::::new(k, Some(t), Some(t.increased(nc.0 * nb.0 + i))); + let before_remaining_bytes = generator.remaining_bytes(); + let _ = generator.par_try_fork(nc, nb).unwrap(); + let after_remaining_bytes = generator.remaining_bytes(); + assert_eq!( + before_remaining_bytes.0 - after_remaining_bytes.0, + bytes_to_go as u128 + ); + } + } +} diff --git a/concrete-csprng/src/generators/aes_ctr/states.rs b/concrete-csprng/src/generators/aes_ctr/states.rs new file mode 100644 index 0000000000..bb9f2d91a6 --- /dev/null +++ b/concrete-csprng/src/generators/aes_ctr/states.rs @@ -0,0 +1,176 @@ +use crate::generators::aes_ctr::index::{AesIndex, TableIndex}; +use crate::generators::aes_ctr::BYTES_PER_BATCH; + +/// A pointer to the next byte to be outputted by the generator. +#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] +pub struct BufferPointer(pub usize); + +/// A structure representing the current state of generator using batched aes-ctr approach. +#[derive(Debug, Clone, Copy)] +pub struct State { + table_index: TableIndex, + buffer_pointer: BufferPointer, +} + +/// A structure representing the action to be taken by the generator after shifting its state. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ShiftAction { + /// Yield the byte pointed to by the 0-th field. + YieldByte(BufferPointer), + /// Refresh the buffer starting from the 0-th field, and yield the byte pointed to by the 0-th + /// field. + RefreshBatchAndYieldByte(AesIndex, BufferPointer), +} + +impl State { + /// Creates a new state from the initial table index. + /// + /// Note : + /// ------ + /// + /// The `table_index` input, is the __first__ table index that will be outputted on the next + /// call to `increment`. Put differently, the current table index of the newly created state + /// is the predecessor of this one. + pub fn new(table_index: TableIndex) -> Self { + // We ensure that the table index is not the first one, to prevent wrapping on `decrement`, + // and yielding `RefreshBatchAndYield(AesIndex::MAX, ...)` on the first increment + // (which would lead to loading a non continuous batch). + assert_ne!(table_index, TableIndex::FIRST); + State { + // To ensure that the first yielded table index is the proper one, we decrement the + // table index. + table_index: table_index.decremented(), + // To ensure that the first `ShiftAction` will be a `RefreshBatchAndYieldByte`, we set + // the buffer to the last allowed value. + buffer_pointer: BufferPointer(BYTES_PER_BATCH - 1), + } + } + + /// Shifts the state forward of `shift` bytes. 
+ pub fn increase(&mut self, shift: usize) -> ShiftAction { + self.table_index.increase(shift); + let total_batch_index = self.buffer_pointer.0 + shift; + if total_batch_index > BYTES_PER_BATCH - 1 { + self.buffer_pointer.0 = self.table_index.byte_index.0; + ShiftAction::RefreshBatchAndYieldByte(self.table_index.aes_index, self.buffer_pointer) + } else { + self.buffer_pointer.0 = total_batch_index; + ShiftAction::YieldByte(self.buffer_pointer) + } + } + + /// Shifts the state forward of one byte. + pub fn increment(&mut self) -> ShiftAction { + self.increase(1) + } + + /// Returns the current table index. + pub fn table_index(&self) -> TableIndex { + self.table_index + } +} + +impl Default for State { + fn default() -> Self { + State::new(TableIndex::FIRST) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::generators::aes_ctr::index::ByteIndex; + use crate::generators::aes_ctr::BYTES_PER_AES_CALL; + use rand::{thread_rng, Rng}; + + const REPEATS: usize = 1_000_000; + + fn any_table_index() -> impl Iterator { + std::iter::repeat_with(|| { + TableIndex::new( + AesIndex(thread_rng().gen()), + ByteIndex(thread_rng().gen::() % BYTES_PER_AES_CALL), + ) + }) + } + + fn any_usize() -> impl Iterator { + std::iter::repeat_with(|| thread_rng().gen()) + } + + #[test] + /// Check the property: + /// For all table indices t, + /// State::new(t).increment() = RefreshBatchAndYield(t.aes_index, t.byte_index) + fn prop_state_new_increment() { + for _ in 0..REPEATS { + let (t, mut s) = any_table_index() + .map(|t| (t, State::new(t))) + .next() + .unwrap(); + assert!(matches!( + s.increment(), + ShiftAction::RefreshBatchAndYieldByte(t_, BufferPointer(p_)) if t_ == t.aes_index && p_ == t.byte_index.0 + )) + } + } + + #[test] + /// Check the property: + /// For all states s, table indices t, positive integer i + /// if s = State::new(t), then t.increased(i) = s.increased(i-1).table_index(). + fn prop_state_increase_table_index() { + for _ in 0..REPEATS { + let (t, mut s, i) = any_table_index() + .zip(any_usize()) + .map(|(t, i)| (t, State::new(t), i)) + .next() + .unwrap(); + s.increase(i); + assert_eq!(s.table_index(), t.increased(i - 1)) + } + } + + #[test] + /// Check the property: + /// For all table indices t, positive integer i such as t.byte_index + i < 127, + /// if s = State::new(t), and s.increment() was executed, then + /// s.increase(i) = YieldByte(t.byte_index + i). + fn prop_state_increase_small() { + for _ in 0..REPEATS { + let (t, mut s, i) = any_table_index() + .zip(any_usize()) + .map(|(t, i)| (t, State::new(t), i % BYTES_PER_BATCH)) + .find(|(t, _, i)| t.byte_index.0 + i < BYTES_PER_BATCH - 1) + .unwrap(); + s.increment(); + assert!(matches!( + s.increase(i), + ShiftAction::YieldByte(BufferPointer(p_)) if p_ == t.byte_index.0 + i + )); + } + } + + #[test] + /// Check the property: + /// For all table indices t, positive integer i such as t.byte_index + i >= 127, + /// if s = State::new(t), and s.increment() was executed, then + /// s.increase(i) = RefreshBatchAndYield( + /// t.increased(i).aes_index, + /// t.increased(i).byte_index). 
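To make the batching contract above concrete, here is a standalone sketch of how a driver loop is expected to consume `ShiftAction`s: refresh the 128-byte buffer only when the state says so, otherwise serve the byte already sitting in the buffer. The types and the `refresh` closure below are illustrative stand-ins for the real state machine and block cipher, not the crate's code.

```rust
// Illustrative driver loop for the ShiftAction protocol described above.
const BYTES_PER_BATCH: usize = 128;

enum ShiftAction {
    YieldByte(usize),
    RefreshBatchAndYieldByte(u128, usize),
}

fn next_byte(
    buffer: &mut [u8; BYTES_PER_BATCH],
    action: ShiftAction,
    refresh: impl Fn(u128) -> [u8; BYTES_PER_BATCH],
) -> u8 {
    match action {
        // The pointed byte is already in the buffer: just serve it.
        ShiftAction::YieldByte(ptr) => buffer[ptr],
        // A batch boundary was crossed: regenerate the buffer first.
        ShiftAction::RefreshBatchAndYieldByte(aes_index, ptr) => {
            *buffer = refresh(aes_index);
            buffer[ptr]
        }
    }
}

fn main() {
    let mut buffer = [0u8; BYTES_PER_BATCH];
    // Dummy "block cipher": fills the batch with the low bits of the AES index.
    let refresh = |aes_index: u128| [aes_index as u8; BYTES_PER_BATCH];
    let b = next_byte(&mut buffer, ShiftAction::RefreshBatchAndYieldByte(3, 0), refresh);
    assert_eq!(b, 3);
    let b = next_byte(&mut buffer, ShiftAction::YieldByte(1), refresh);
    assert_eq!(b, 3);
}
```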
+ fn prop_state_increase_large() { + for _ in 0..REPEATS { + let (t, mut s, i) = any_table_index() + .zip(any_usize()) + .map(|(t, i)| (t, State::new(t), i)) + .find(|(t, _, i)| t.byte_index.0 + i >= BYTES_PER_BATCH - 1) + .unwrap(); + s.increment(); + assert!(matches!( + s.increase(i), + ShiftAction::RefreshBatchAndYieldByte(t_, BufferPointer(p_)) + if t_ == t.increased(i).aes_index && p_ == t.increased(i).byte_index.0 + )); + } + } +} diff --git a/concrete-csprng/src/aesni.rs b/concrete-csprng/src/generators/implem/aesni/block_cipher.rs similarity index 76% rename from concrete-csprng/src/aesni.rs rename to concrete-csprng/src/generators/implem/aesni/block_cipher.rs index 45cf449aae..1fec97babc 100644 --- a/concrete-csprng/src/aesni.rs +++ b/concrete-csprng/src/generators/implem/aesni/block_cipher.rs @@ -1,39 +1,31 @@ -//! A module implementing an `aes128-counter` random number generator, using `aesni` instructions. -//! -//! This module implements a cryptographically secure pseudorandom number generator -//! (CS-PRNG), using a fast streamcipher: aes128 in counter-mode (CTR). The implementation -//! is based on the [intel aesni white paper 323641-001 revision 3.0](https://www.intel.com/content/dam/doc/white-paper/advanced-encryption-standard-new-instructions-set-paper.pdf). -use crate::counter::{AesBatchedGenerator, AesIndex, AesKey}; +use crate::generators::aes_ctr::{AesBlockCipher, AesIndex, AesKey, BYTES_PER_BATCH}; use std::arch::x86_64::{ __m128i, _mm_aesenc_si128, _mm_aesenclast_si128, _mm_aeskeygenassist_si128, _mm_load_si128, _mm_shuffle_epi32, _mm_slli_si128, _mm_store_si128, _mm_xor_si128, }; use std::mem::transmute; +/// An aes block cipher implemenation which uses `aesni` instructions. #[derive(Clone)] -pub struct Generator { +pub struct AesniBlockCipher { // The set of round keys used for the aes encryption round_keys: [__m128i; 11], } -impl AesBatchedGenerator for Generator { - fn new(key: Option) -> Generator { +impl AesBlockCipher for AesniBlockCipher { + fn new(key: AesKey) -> AesniBlockCipher { if is_x86_feature_detected!("aes") && is_x86_feature_detected!("rdseed") && is_x86_feature_detected!("sse2") { - let round_keys = - generate_round_keys(key.unwrap_or_else(generate_initialization_vector)); - Generator { round_keys } + let round_keys = generate_round_keys(key); + AesniBlockCipher { round_keys } } else { - panic!( - "One of the `aes`, `rdseed`, or `sse2` instructions set was not fount. It is \ - currently mandatory to use `concrete-csprng`." - ) + panic!("One of the `aes`, `rdseed`, or `sse2` instructions set was not found") } } - fn generate_batch(&mut self, AesIndex(aes_ctr): AesIndex) -> [u8; 128] { + fn generate_batch(&mut self, AesIndex(aes_ctr): AesIndex) -> [u8; BYTES_PER_BATCH] { si128arr_to_u8arr(aes_encrypt_many( &u128_to_si128(aes_ctr), &u128_to_si128(aes_ctr + 1), @@ -48,11 +40,6 @@ impl AesBatchedGenerator for Generator { } } -fn generate_initialization_vector() -> AesKey { - // The initialization vector is a random value from rdseed - AesKey(si128_to_u128(rdseed_random_m128())) -} - fn generate_round_keys(key: AesKey) -> [__m128i; 11] { // The secret key is a random value from rdseed. 
let key = u128_to_si128(key.0); @@ -61,27 +48,6 @@ fn generate_round_keys(key: AesKey) -> [__m128i; 11] { keys } -// Generates a random 128 bits value from rdseed -fn rdseed_random_m128() -> __m128i { - let mut rand1: u64 = 0; - let mut rand2: u64 = 0; - unsafe { - loop { - if core::arch::x86_64::_rdseed64_step(&mut rand1) == 1 { - break; - } - } - loop { - if core::arch::x86_64::_rdseed64_step(&mut rand2) == 1 { - break; - } - } - #[repr(C)] - struct _tuple(u64, u64); - std::mem::transmute::<_tuple, __m128i>(_tuple(rand1, rand2)) - } -} - // Uses aes to encrypt many values at once. This allows a substantial speedup (around 30%) // compared to the naive approach. #[allow(clippy::too_many_arguments)] @@ -198,21 +164,16 @@ fn u128_to_si128(input: u128) -> __m128i { unsafe { transmute(input) } } +#[allow(unused)] // to please clippy when tests are not activated fn si128_to_u128(input: __m128i) -> u128 { unsafe { transmute(input) } } -fn si128arr_to_u8arr(input: [__m128i; 8]) -> [u8; 128] { +fn si128arr_to_u8arr(input: [__m128i; 8]) -> [u8; BYTES_PER_BATCH] { unsafe { transmute(input) } } -#[cfg(all( - test, - target_arch = "x86_64", - target_feature = "aes", - target_feature = "sse2", - target_feature = "rdseed" -))] +#[cfg(test)] mod test { use super::*; @@ -259,24 +220,4 @@ mod test { assert_eq!(CIPHERTEXT, si128_to_u128(*ct)); } } - - #[test] - fn test_uniformity() { - // Checks that the PRNG generates uniform numbers - let precision = 10f64.powi(-4); - let n_samples = 10_000_000_usize; - let mut generator = Generator::new(None); - let mut counts = [0usize; 256]; - let expected_prob: f64 = 1. / 256.; - for counter in 0..n_samples { - let generated = generator.generate_batch(AesIndex(counter as u128)); - for i in 0..128 { - counts[generated[i] as usize] += 1; - } - } - counts - .iter() - .map(|a| (*a as f64) / ((n_samples * 128) as f64)) - .for_each(|a| assert!((a - expected_prob) < precision)) - } } diff --git a/concrete-csprng/src/generators/implem/aesni/generator.rs b/concrete-csprng/src/generators/implem/aesni/generator.rs new file mode 100644 index 0000000000..651c865720 --- /dev/null +++ b/concrete-csprng/src/generators/implem/aesni/generator.rs @@ -0,0 +1,110 @@ +use crate::generators::aes_ctr::{AesCtrGenerator, AesKey, ChildrenIterator}; +use crate::generators::implem::aesni::block_cipher::AesniBlockCipher; +use crate::generators::{ByteCount, BytesPerChild, ChildrenCount, ForkError, RandomGenerator}; +use crate::seeders::Seed; + +/// A random number generator using the `aesni` instructions. +pub struct AesniRandomGenerator(pub(super) AesCtrGenerator); + +/// The children iterator used by [`AesniRandomGenerator`]. +/// +/// Yields children generators one by one. 
+pub struct AesniChildrenIterator(ChildrenIterator); + +impl Iterator for AesniChildrenIterator { + type Item = AesniRandomGenerator; + + fn next(&mut self) -> Option { + self.0.next().map(AesniRandomGenerator) + } +} + +impl RandomGenerator for AesniRandomGenerator { + type ChildrenIter = AesniChildrenIterator; + fn new(seed: Seed) -> Self { + AesniRandomGenerator(AesCtrGenerator::new(AesKey(seed.0), None, None)) + } + fn remaining_bytes(&self) -> ByteCount { + self.0.remaining_bytes() + } + fn try_fork( + &mut self, + n_children: ChildrenCount, + n_bytes: BytesPerChild, + ) -> Result { + self.0 + .try_fork(n_children, n_bytes) + .map(AesniChildrenIterator) + } +} + +impl Iterator for AesniRandomGenerator { + type Item = u8; + + fn next(&mut self) -> Option { + self.0.next() + } +} + +#[cfg(test)] +mod test { + use crate::generators::aes_ctr::aes_ctr_generic_test; + use crate::generators::implem::aesni::block_cipher::AesniBlockCipher; + use crate::generators::{generator_generic_test, AesniRandomGenerator}; + + #[test] + fn prop_fork_first_state_table_index() { + aes_ctr_generic_test::prop_fork_first_state_table_index::(); + } + + #[test] + fn prop_fork_last_bound_table_index() { + aes_ctr_generic_test::prop_fork_last_bound_table_index::(); + } + + #[test] + fn prop_fork_parent_bound_table_index() { + aes_ctr_generic_test::prop_fork_parent_bound_table_index::(); + } + + #[test] + fn prop_fork_parent_state_table_index() { + aes_ctr_generic_test::prop_fork_parent_state_table_index::(); + } + + #[test] + fn prop_fork() { + aes_ctr_generic_test::prop_fork::(); + } + + #[test] + fn prop_fork_children_remaining_bytes() { + aes_ctr_generic_test::prop_fork_children_remaining_bytes::(); + } + + #[test] + fn prop_fork_parent_remaining_bytes() { + aes_ctr_generic_test::prop_fork_parent_remaining_bytes::(); + } + + #[test] + fn test_uniformity() { + generator_generic_test::test_uniformity::(); + } + + #[test] + fn test_generator_determinism() { + generator_generic_test::test_generator_determinism::(); + } + + #[test] + fn test_fork() { + generator_generic_test::test_fork_children::(); + } + + #[test] + #[should_panic] + fn test_bounded_panic() { + generator_generic_test::test_bounded_none_should_panic::(); + } +} diff --git a/concrete-csprng/src/generators/implem/aesni/mod.rs b/concrete-csprng/src/generators/implem/aesni/mod.rs new file mode 100644 index 0000000000..4e43177f48 --- /dev/null +++ b/concrete-csprng/src/generators/implem/aesni/mod.rs @@ -0,0 +1,16 @@ +//! A module implementing a random number generator, using the x86_64 `aesni` instructions. +//! +//! This module implements a cryptographically secure pseudorandom number generator +//! (CS-PRNG), using a fast block cipher. The implementation is based on the +//! [intel aesni white paper 323641-001 revision 3.0](https://www.intel.com/content/dam/doc/white-paper/advanced-encryption-standard-new-instructions-set-paper.pdf). 
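A short usage sketch for the generator defined in this module. It assumes an x86_64 CPU exposing the required instruction sets and a build with the `generator_x86_64_aesni` and `seeder_x86_64_rdseed` features; error handling is kept minimal.

```rust
// Sketch only: requires the `generator_x86_64_aesni` and `seeder_x86_64_rdseed`
// features and a CPU with the corresponding instruction sets.
use concrete_csprng::generators::{
    AesniRandomGenerator, BytesPerChild, ChildrenCount, RandomGenerator,
};
use concrete_csprng::seeders::{RdseedSeeder, Seeder};

fn main() {
    let mut seeder = RdseedSeeder;
    let mut parent = AesniRandomGenerator::new(seeder.seed());

    // Fork off two bounded children; each may yield at most 32 bytes.
    let children: Vec<AesniRandomGenerator> = parent
        .try_fork(ChildrenCount(2), BytesPerChild(32))
        .unwrap()
        .collect();

    for mut child in children {
        let bytes: Vec<u8> = child.by_ref().take(32).collect();
        assert_eq!(bytes.len(), 32);
        // The child is exhausted once its 32-byte budget is consumed.
        assert!(child.next().is_none());
    }
}
```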
+ +mod block_cipher; +use block_cipher::*; + +mod generator; +pub use generator::*; + +#[cfg(feature = "parallel")] +mod parallel; +#[cfg(feature = "parallel")] +pub use parallel::*; diff --git a/concrete-csprng/src/generators/implem/aesni/parallel.rs b/concrete-csprng/src/generators/implem/aesni/parallel.rs new file mode 100644 index 0000000000..06ed8c2902 --- /dev/null +++ b/concrete-csprng/src/generators/implem/aesni/parallel.rs @@ -0,0 +1,94 @@ +use super::*; +use crate::generators::aes_ctr::{AesCtrGenerator, ParallelChildrenIterator}; +use crate::generators::{BytesPerChild, ChildrenCount, ForkError, ParallelRandomGenerator}; +use rayon::iter::plumbing::{Consumer, ProducerCallback, UnindexedConsumer}; +use rayon::prelude::*; + +/// The parallel children iterator used by [`AesniRandomGenerator`]. +/// +/// Yields the children generators one by one. +#[allow(clippy::type_complexity)] +pub struct ParallelAesniChildrenIterator( + rayon::iter::Map< + ParallelChildrenIterator, + fn(AesCtrGenerator) -> AesniRandomGenerator, + >, +); + +impl ParallelIterator for ParallelAesniChildrenIterator { + type Item = AesniRandomGenerator; + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: UnindexedConsumer, + { + self.0.drive_unindexed(consumer) + } +} + +impl IndexedParallelIterator for ParallelAesniChildrenIterator { + fn len(&self) -> usize { + self.0.len() + } + fn drive>(self, consumer: C) -> C::Result { + self.0.drive(consumer) + } + fn with_producer>(self, callback: CB) -> CB::Output { + self.0.with_producer(callback) + } +} + +impl ParallelRandomGenerator for AesniRandomGenerator { + type ParChildrenIter = ParallelAesniChildrenIterator; + + fn par_try_fork( + &mut self, + n_children: ChildrenCount, + n_bytes: BytesPerChild, + ) -> Result { + self.0 + .par_try_fork(n_children, n_bytes) + .map(|iterator| ParallelAesniChildrenIterator(iterator.map(AesniRandomGenerator))) + } +} + +#[cfg(test)] + +mod test { + use crate::generators::aes_ctr::aes_ctr_parallel_generic_tests; + use crate::generators::implem::aesni::block_cipher::AesniBlockCipher; + + #[test] + fn prop_fork_first_state_table_index() { + aes_ctr_parallel_generic_tests::prop_fork_first_state_table_index::(); + } + + #[test] + fn prop_fork_last_bound_table_index() { + aes_ctr_parallel_generic_tests::prop_fork_last_bound_table_index::(); + } + + #[test] + fn prop_fork_parent_bound_table_index() { + aes_ctr_parallel_generic_tests::prop_fork_parent_bound_table_index::(); + } + + #[test] + fn prop_fork_parent_state_table_index() { + aes_ctr_parallel_generic_tests::prop_fork_parent_state_table_index::(); + } + + #[test] + fn prop_fork_ttt() { + aes_ctr_parallel_generic_tests::prop_fork::(); + } + + #[test] + fn prop_fork_children_remaining_bytes() { + aes_ctr_parallel_generic_tests::prop_fork_children_remaining_bytes::(); + } + + #[test] + fn prop_fork_parent_remaining_bytes() { + aes_ctr_parallel_generic_tests::prop_fork_parent_remaining_bytes::(); + } +} diff --git a/concrete-csprng/src/generators/implem/mod.rs b/concrete-csprng/src/generators/implem/mod.rs new file mode 100644 index 0000000000..49e62a3034 --- /dev/null +++ b/concrete-csprng/src/generators/implem/mod.rs @@ -0,0 +1,9 @@ +#[cfg(feature = "generator_x86_64_aesni")] +mod aesni; +#[cfg(feature = "generator_x86_64_aesni")] +pub use aesni::*; + +#[cfg(feature = "generator_soft")] +mod soft; +#[cfg(feature = "generator_soft")] +pub use soft::*; diff --git a/concrete-csprng/src/generators/implem/soft/block_cipher.rs 
b/concrete-csprng/src/generators/implem/soft/block_cipher.rs new file mode 100644 index 0000000000..aa763658b4 --- /dev/null +++ b/concrete-csprng/src/generators/implem/soft/block_cipher.rs @@ -0,0 +1,114 @@ +use crate::generators::aes_ctr::{ + AesBlockCipher, AesIndex, AesKey, AES_CALLS_PER_BATCH, BYTES_PER_AES_CALL, BYTES_PER_BATCH, +}; +use aes_soft::cipher::generic_array::GenericArray; +use aes_soft::cipher::{BlockCipher, NewBlockCipher}; +use aes_soft::Aes128; + +#[derive(Clone)] +pub struct SoftwareBlockCipher { + // Aes structure + aes: Aes128, +} + +impl AesBlockCipher for SoftwareBlockCipher { + fn new(key: AesKey) -> SoftwareBlockCipher { + let key: [u8; BYTES_PER_AES_CALL] = key.0.to_ne_bytes(); + let key = GenericArray::clone_from_slice(&key[..]); + let aes = Aes128::new(&key); + SoftwareBlockCipher { aes } + } + + fn generate_batch(&mut self, AesIndex(aes_ctr): AesIndex) -> [u8; BYTES_PER_BATCH] { + aes_encrypt_many( + aes_ctr, + aes_ctr + 1, + aes_ctr + 2, + aes_ctr + 3, + aes_ctr + 4, + aes_ctr + 5, + aes_ctr + 6, + aes_ctr + 7, + &self.aes, + ) + } +} + +// Uses aes to encrypt many values at once. This allows a substantial speedup (around 30%) +// compared to the naive approach. +#[allow(clippy::too_many_arguments)] +fn aes_encrypt_many( + message_1: u128, + message_2: u128, + message_3: u128, + message_4: u128, + message_5: u128, + message_6: u128, + message_7: u128, + message_8: u128, + cipher: &Aes128, +) -> [u8; BYTES_PER_BATCH] { + let mut b1 = GenericArray::clone_from_slice(&message_1.to_ne_bytes()[..]); + let mut b2 = GenericArray::clone_from_slice(&message_2.to_ne_bytes()[..]); + let mut b3 = GenericArray::clone_from_slice(&message_3.to_ne_bytes()[..]); + let mut b4 = GenericArray::clone_from_slice(&message_4.to_ne_bytes()[..]); + let mut b5 = GenericArray::clone_from_slice(&message_5.to_ne_bytes()[..]); + let mut b6 = GenericArray::clone_from_slice(&message_6.to_ne_bytes()[..]); + let mut b7 = GenericArray::clone_from_slice(&message_7.to_ne_bytes()[..]); + let mut b8 = GenericArray::clone_from_slice(&message_8.to_ne_bytes()[..]); + + cipher.encrypt_block(&mut b1); + cipher.encrypt_block(&mut b2); + cipher.encrypt_block(&mut b3); + cipher.encrypt_block(&mut b4); + cipher.encrypt_block(&mut b5); + cipher.encrypt_block(&mut b6); + cipher.encrypt_block(&mut b7); + cipher.encrypt_block(&mut b8); + + let output_array: [[u8; BYTES_PER_AES_CALL]; AES_CALLS_PER_BATCH] = [ + b1.into(), + b2.into(), + b3.into(), + b4.into(), + b5.into(), + b6.into(), + b7.into(), + b8.into(), + ]; + + unsafe { *{ output_array.as_ptr() as *const [u8; BYTES_PER_BATCH] } } +} + +#[cfg(test)] +mod test { + use super::*; + use std::convert::TryInto; + + // Test vector for aes128, from the FIPS publication 197 + const CIPHER_KEY: u128 = u128::from_be(0x000102030405060708090a0b0c0d0e0f); + const PLAINTEXT: u128 = u128::from_be(0x00112233445566778899aabbccddeeff); + const CIPHERTEXT: u128 = u128::from_be(0x69c4e0d86a7b0430d8cdb78070b4c55a); + + #[test] + fn test_encrypt_many_messages() { + // Checks that encrypting many plaintext at the same time gives the correct output. 
+ let key: [u8; BYTES_PER_AES_CALL] = CIPHER_KEY.to_ne_bytes(); + let aes = Aes128::new(&GenericArray::from(key)); + let ciphertexts = aes_encrypt_many( + PLAINTEXT, PLAINTEXT, PLAINTEXT, PLAINTEXT, PLAINTEXT, PLAINTEXT, PLAINTEXT, PLAINTEXT, + &aes, + ); + let ciphertexts: [u8; BYTES_PER_BATCH] = ciphertexts[..].try_into().unwrap(); + for i in 0..8 { + assert_eq!( + u128::from_ne_bytes( + ciphertexts[BYTES_PER_AES_CALL * i..BYTES_PER_AES_CALL * (i + 1)] + .try_into() + .unwrap() + ), + CIPHERTEXT + ); + } + } +} diff --git a/concrete-csprng/src/generators/implem/soft/generator.rs b/concrete-csprng/src/generators/implem/soft/generator.rs new file mode 100644 index 0000000000..c1abc040eb --- /dev/null +++ b/concrete-csprng/src/generators/implem/soft/generator.rs @@ -0,0 +1,110 @@ +use crate::generators::aes_ctr::{AesCtrGenerator, AesKey, ChildrenIterator}; +use crate::generators::implem::soft::block_cipher::SoftwareBlockCipher; +use crate::generators::{ByteCount, BytesPerChild, ChildrenCount, ForkError, RandomGenerator}; +use crate::seeders::Seed; + +/// A random number generator using a software implementation. +pub struct SoftwareRandomGenerator(pub(super) AesCtrGenerator); + +/// The children iterator used by [`SoftwareRandomGenerator`]. +/// +/// Yields children generators one by one. +pub struct SoftwareChildrenIterator(ChildrenIterator); + +impl Iterator for SoftwareChildrenIterator { + type Item = SoftwareRandomGenerator; + + fn next(&mut self) -> Option { + self.0.next().map(SoftwareRandomGenerator) + } +} + +impl RandomGenerator for SoftwareRandomGenerator { + type ChildrenIter = SoftwareChildrenIterator; + fn new(seed: Seed) -> Self { + SoftwareRandomGenerator(AesCtrGenerator::new(AesKey(seed.0), None, None)) + } + fn remaining_bytes(&self) -> ByteCount { + self.0.remaining_bytes() + } + fn try_fork( + &mut self, + n_children: ChildrenCount, + n_bytes: BytesPerChild, + ) -> Result { + self.0 + .try_fork(n_children, n_bytes) + .map(SoftwareChildrenIterator) + } +} + +impl Iterator for SoftwareRandomGenerator { + type Item = u8; + + fn next(&mut self) -> Option { + self.0.next() + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::generators::aes_ctr::aes_ctr_generic_test; + use crate::generators::generator_generic_test; + + #[test] + fn prop_fork_first_state_table_index() { + aes_ctr_generic_test::prop_fork_first_state_table_index::(); + } + + #[test] + fn prop_fork_last_bound_table_index() { + aes_ctr_generic_test::prop_fork_last_bound_table_index::(); + } + + #[test] + fn prop_fork_parent_bound_table_index() { + aes_ctr_generic_test::prop_fork_parent_bound_table_index::(); + } + + #[test] + fn prop_fork_parent_state_table_index() { + aes_ctr_generic_test::prop_fork_parent_state_table_index::(); + } + + #[test] + fn prop_fork() { + aes_ctr_generic_test::prop_fork::(); + } + + #[test] + fn prop_fork_children_remaining_bytes() { + aes_ctr_generic_test::prop_fork_children_remaining_bytes::(); + } + + #[test] + fn prop_fork_parent_remaining_bytes() { + aes_ctr_generic_test::prop_fork_parent_remaining_bytes::(); + } + + #[test] + fn test_uniformity() { + generator_generic_test::test_uniformity::(); + } + + #[test] + fn test_fork() { + generator_generic_test::test_fork_children::(); + } + + #[test] + fn test_generator_determinism() { + generator_generic_test::test_generator_determinism::(); + } + + #[test] + #[should_panic] + fn test_bounded_panic() { + generator_generic_test::test_bounded_none_should_panic::(); + } +} diff --git 
a/concrete-csprng/src/generators/implem/soft/mod.rs b/concrete-csprng/src/generators/implem/soft/mod.rs new file mode 100644 index 0000000000..e6bfd3b7b7 --- /dev/null +++ b/concrete-csprng/src/generators/implem/soft/mod.rs @@ -0,0 +1,12 @@ +//! A module using a software fallback implementation of random number generator. + +mod block_cipher; +use block_cipher::*; + +mod generator; +pub use generator::*; + +#[cfg(feature = "parallel")] +mod parallel; +#[cfg(feature = "parallel")] +pub use parallel::*; diff --git a/concrete-csprng/src/generators/implem/soft/parallel.rs b/concrete-csprng/src/generators/implem/soft/parallel.rs new file mode 100644 index 0000000000..3d2a763b0b --- /dev/null +++ b/concrete-csprng/src/generators/implem/soft/parallel.rs @@ -0,0 +1,93 @@ +use super::*; +use crate::generators::aes_ctr::{AesCtrGenerator, ParallelChildrenIterator}; +use crate::generators::{BytesPerChild, ChildrenCount, ForkError, ParallelRandomGenerator}; +use rayon::iter::plumbing::{Consumer, ProducerCallback, UnindexedConsumer}; +use rayon::prelude::*; + +/// The parallel children iterator used by [`SoftwareRandomGenerator`]. +/// +/// Yields the children generators one by one. +#[allow(clippy::type_complexity)] +pub struct ParallelSoftwareChildrenIterator( + rayon::iter::Map< + ParallelChildrenIterator, + fn(AesCtrGenerator) -> SoftwareRandomGenerator, + >, +); + +impl ParallelIterator for ParallelSoftwareChildrenIterator { + type Item = SoftwareRandomGenerator; + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: UnindexedConsumer, + { + self.0.drive_unindexed(consumer) + } +} + +impl IndexedParallelIterator for ParallelSoftwareChildrenIterator { + fn len(&self) -> usize { + self.0.len() + } + fn drive>(self, consumer: C) -> C::Result { + self.0.drive(consumer) + } + fn with_producer>(self, callback: CB) -> CB::Output { + self.0.with_producer(callback) + } +} + +impl ParallelRandomGenerator for SoftwareRandomGenerator { + type ParChildrenIter = ParallelSoftwareChildrenIterator; + + fn par_try_fork( + &mut self, + n_children: ChildrenCount, + n_bytes: BytesPerChild, + ) -> Result { + self.0 + .par_try_fork(n_children, n_bytes) + .map(|iterator| ParallelSoftwareChildrenIterator(iterator.map(SoftwareRandomGenerator))) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::generators::aes_ctr::aes_ctr_parallel_generic_tests; + + #[test] + fn prop_fork_first_state_table_index() { + aes_ctr_parallel_generic_tests::prop_fork_first_state_table_index::(); + } + + #[test] + fn prop_fork_last_bound_table_index() { + aes_ctr_parallel_generic_tests::prop_fork_last_bound_table_index::(); + } + + #[test] + fn prop_fork_parent_bound_table_index() { + aes_ctr_parallel_generic_tests::prop_fork_parent_bound_table_index::(); + } + + #[test] + fn prop_fork_parent_state_table_index() { + aes_ctr_parallel_generic_tests::prop_fork_parent_state_table_index::(); + } + + #[test] + fn prop_fork() { + aes_ctr_parallel_generic_tests::prop_fork::(); + } + + #[test] + fn prop_fork_children_remaining_bytes() { + aes_ctr_parallel_generic_tests::prop_fork_children_remaining_bytes::(); + } + + #[test] + fn prop_fork_parent_remaining_bytes() { + aes_ctr_parallel_generic_tests::prop_fork_parent_remaining_bytes::(); + } +} diff --git a/concrete-csprng/src/generators/mod.rs b/concrete-csprng/src/generators/mod.rs new file mode 100644 index 0000000000..c02fcef55d --- /dev/null +++ b/concrete-csprng/src/generators/mod.rs @@ -0,0 +1,196 @@ +//! A module containing random generators objects. +//! +//! 
See [crate-level](`crate`) explanations. +use crate::seeders::Seed; +use std::error::Error; +use std::fmt::{Display, Formatter}; + +/// The number of children created when a generator is forked. +#[derive(Debug, Copy, Clone)] +pub struct ChildrenCount(pub usize); + +/// The number of bytes each child can generate, when a generator is forked. +#[derive(Debug, Copy, Clone)] +pub struct BytesPerChild(pub usize); + +/// A structure representing the number of bytes between two table indices. +#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] +pub struct ByteCount(pub u128); + +/// An error occurring during a generator fork. +#[derive(Debug)] +pub enum ForkError { + ForkTooLarge, + ZeroChildrenCount, + ZeroBytesPerChild, +} + +impl Display for ForkError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ForkError::ForkTooLarge => { + write!( + f, + "The children generators would output bytes after the parent bound. " + ) + } + ForkError::ZeroChildrenCount => { + write!( + f, + "The number of children in the fork must be greater than zero." + ) + } + ForkError::ZeroBytesPerChild => { + write!( + f, + "The number of bytes per child must be greater than zero." + ) + } + } + } +} +impl Error for ForkError {} + +/// A trait for cryptographically secure pseudo-random generators. +/// +/// See the [crate-level](#crate) documentation for details. +pub trait RandomGenerator: Iterator { + /// The iterator over children generators, returned by `try_fork` in case of success. + type ChildrenIter: Iterator; + + /// Creates a new generator from a seed. + /// + /// This operation is usually costly to perform, as the round keys need to be generated from the + /// seed. + fn new(seed: Seed) -> Self; + + /// Returns the number of bytes that can still be yielded by the generator before reaching its + /// bound. + /// + /// Note: + /// ----- + /// + /// A fresh generator can generate 2¹³² bytes. Unfortunately, no rust integer type is able + /// to encode such a large number. Consequently [`ByteCount`] uses the largest integer type + /// available to encode this value: the `u128` type. For this reason, this method does not + /// effectively return the number of remaining bytes, but instead + /// `min(2¹²⁸-1, remaining_bytes)`. + fn remaining_bytes(&self) -> ByteCount; + + /// Returns the next byte of the stream, if the generator has not yet reached its bound. + fn next_byte(&mut self) -> Option { + self.next() + } + + /// Tries to fork the generator into an iterator of `n_children` new generators, each able to + /// yield `n_bytes` bytes. + /// + /// Note: + /// ----- + /// + /// To be successful, the number of remaining bytes for the parent generator must be larger than + /// `n_children*n_bytes`. + fn try_fork( + &mut self, + n_children: ChildrenCount, + n_bytes: BytesPerChild, + ) -> Result; +} + +/// A trait extending [`RandomGenerator`] to the parallel iterators of `rayon`. +#[cfg(feature = "parallel")] +pub trait ParallelRandomGenerator: RandomGenerator { + /// The iterator over children generators, returned by `par_try_fork` in case of success. + type ParChildrenIter: rayon::prelude::IndexedParallelIterator; + + /// Tries to fork the generator into a parallel iterator of `n_children` new generators, each + /// able to yield `n_bytes` bytes. + /// + /// Note: + /// ----- + /// + /// To be successful, the number of remaining bytes for the parent generator must be larger than + /// `n_children*n_bytes`.
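As a complement to the trait contract above, here is a small sketch of how a caller might surface `ForkError` values and track the remaining-bytes budget. It assumes the `generator_soft` feature, an arbitrary seed, and that the sequential `try_fork` performs the same zero-children validation shown for the parallel fork earlier in this patch.

```rust
// Sketch only: assumes the `generator_soft` feature is enabled.
use concrete_csprng::generators::{
    BytesPerChild, ChildrenCount, ForkError, RandomGenerator, SoftwareRandomGenerator,
};
use concrete_csprng::seeders::Seed;

fn main() {
    let mut generator = SoftwareRandomGenerator::new(Seed(42));
    println!("fresh generator budget: {:?}", generator.remaining_bytes());

    // A zero-sized fork is rejected rather than silently producing empty children.
    match generator.try_fork(ChildrenCount(0), BytesPerChild(1)) {
        Err(ForkError::ZeroChildrenCount) => println!("rejected: no children requested"),
        Err(e) => println!("rejected: {}", e),
        Ok(_) => unreachable!("a zero-children fork should not succeed"),
    }

    // A valid fork consumes part of the parent's budget.
    let _children = generator
        .try_fork(ChildrenCount(8), BytesPerChild(16))
        .unwrap();
    println!("budget after fork: {:?}", generator.remaining_bytes());
}
```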
+ fn par_try_fork( + &mut self, + n_children: ChildrenCount, + n_bytes: BytesPerChild, + ) -> Result; +} + +mod aes_ctr; + +mod implem; +pub use implem::*; + +#[cfg(test)] +pub mod generator_generic_test { + #![allow(unused)] // to please clippy when tests are not activated + use super::*; + use rand::Rng; + + const REPEATS: usize = 1_000; + + fn any_seed() -> impl Iterator { + std::iter::repeat_with(|| Seed(rand::thread_rng().gen())) + } + + /// Checks that the PRNG roughly generates uniform numbers + pub fn test_uniformity() { + for _ in 0..REPEATS { + let seed = any_seed().next().unwrap(); + let precision = 10f64.powi(-4); + let n_samples = 10_000_000_usize; + let mut generator = G::new(seed); + let mut counts = [0usize; 256]; + let expected_prob: f64 = 1. / 256.; + for _ in 0..n_samples { + counts[generator.next_byte().unwrap() as usize] += 1; + } + counts + .iter() + .map(|a| (*a as f64) / (n_samples as f64)) + .for_each(|a| assert!((a - expected_prob) < precision)) + } + } + + /// Checks that given a state and a key, the PRNG is determinist. + pub fn test_generator_determinism() { + for _ in 0..REPEATS { + let seed = any_seed().next().unwrap(); + let mut first_generator = G::new(seed); + let mut second_generator = G::new(seed); + for _ in 0..1024 { + assert_eq!(first_generator.next(), second_generator.next()); + } + } + } + + /// Checks that forks returns a bounded child, and that the proper number of bytes can be + /// generated. + pub fn test_fork_children() { + let mut gen = G::new(any_seed().next().unwrap()); + let mut bounded = gen + .try_fork(ChildrenCount(1), BytesPerChild(10)) + .unwrap() + .next() + .unwrap(); + assert_eq!(bounded.remaining_bytes(), ByteCount(10)); + for _ in 0..10 { + bounded.next(); + } + } + + // Checks that a bounded prng returns none when exceeding the allowed number of bytes. + pub fn test_bounded_none_should_panic() { + let mut gen = G::new(any_seed().next().unwrap()); + let mut bounded = gen + .try_fork(ChildrenCount(1), BytesPerChild(10)) + .unwrap() + .next() + .unwrap(); + for _ in 0..11 { + assert!(bounded.next().is_some()); + } + } +} diff --git a/concrete-csprng/src/lib.rs b/concrete-csprng/src/lib.rs index 13f41dc56a..8e2ce42063 100644 --- a/concrete-csprng/src/lib.rs +++ b/concrete-csprng/src/lib.rs @@ -1,305 +1,114 @@ #![deny(rustdoc::broken_intra_doc_links)] -//! Cryptographically secure pseudo random number generator, that uses AES in CTR mode. +//! Cryptographically secure pseudo random number generator. //! //! Welcome to the `concrete-csprng` documentation. //! -//! This crate contains a reasonably fast cryptographically secure pseudo-random number generator. +//! This crate provides a reasonably fast cryptographically secure pseudo-random number generator, +//! suited to work in a multithreaded setting. +//! +//! Random Generators +//! ================= +//! +//! The central abstraction of this crate is the [`RandomGenerator`](generators::RandomGenerator) +//! trait, which is implemented by different types, each supporting a different platform. In +//! essence, a type implementing [`RandomGenerator`](generators::RandomGenerator) is a type that +//! yields a new pseudo-random byte at each call to +//! [`next_byte`](generators::RandomGenerator::next_byte). Such a generator `g` can be seen as +//! enclosing a growing index into an imaginary array of pseudo-random bytes: +//! ```ascii +//! 0 1 2 3 4 5 6 7 8 9 M │ +//! ┏━┯━┯━┯━┯━┯━┯━┯━┯━┯━┯━━━┯━┓ │ +//! ┃ │ │ │ │ │ │ │ │ │ │...│ ┃ │ +//! ┗↥┷━┷━┷━┷━┷━┷━┷━┷━┷━┷━━━┷━┛ │ +//! g │ +//! │ +//! 
g.next_byte() │ +//! │ +//! 0 1 2 3 4 5 6 7 8 9 M │ +//! ┏━┯━┯━┯━┯━┯━┯━┯━┯━┯━┯━━━┯━┓ │ +//! ┃╳│ │ │ │ │ │ │ │ │ │...│ ┃ │ +//! ┗━┷↥┷━┷━┷━┷━┷━┷━┷━┷━┷━━━┷━┛ │ +//! g │ +//! │ +//! g.next_byte() │ legend: +//! │ ------- +//! 0 1 2 3 4 5 6 7 8 9 M │ ↥ : next byte to be yielded by g +//! ┏━┯━┯━┯━┯━┯━┯━┯━┯━┯━┯━━━┯━┓ │ │ │: byte not yet yielded by g +//! ┃╳│╳│ │ │ │ │ │ │ │ │...│ ┃ │ │╳│: byte already yielded by g +//! ┗━┷━┷↥┷━┷━┷━┷━┷━┷━┷━┷━━━┷━┛ │ +//! g 🭭 +//! ``` +//! +//! While being large, this imaginary array is still bounded to 2¹³² bytes. Consequently, a +//! generator is always bounded to a maximal index. That is, there is always a max amount of +//! elements of this array that can be yielded by the generator. By default, generators created via +//! [`new`](generators::RandomGenerator::new) are always bounded to M-1. +//! +//! Tree partition of the pseudo-random stream +//! ========================================== +//! +//! One particularity of this implementation is that you can use the +//! [`try_fork`](generators::RandomGenerator::try_fork) method to create an arbitrary partition tree +//! of a region of this array. Indeed, calling `try_fork(nc, nb)` yields `nc` new generators, each +//! able to yield `nb` bytes. The `try_fork` method ensures that the states and bounds of the parent +//! and children generators are set so as to prevent the same substream to be outputted +//! twice: +//! ```ascii +//! 0 1 2 3 4 5 6 7 8 9 M │ +//! ┏━┯━┯━┯━┯━┯━┯━┯━┯━┯━┯━━━┯━┓ │ +//! ┃P│P│P│P│P│P│P│P│P│P│...│P┃ │ +//! ┗↥┷━┷━┷━┷━┷━┷━┷━┷━┷━┷━━━┷━┛ │ +//! p │ +//! │ +//! (a,b) = p.fork(2,4) │ +//! │ +//! 0 1 2 3 4 5 6 7 8 9 M │ +//! ┏━┯━┯━┯━┯━┯━┯━┯━┯━┯━┯━━━┯━┓ │ +//! ┃A│A│A│A│B│B│B│B│P│P│...│P┃ │ +//! ┗↥┷━┷━┷━┷↥┷━┷━┷━┷↥┷━┷━━━┷━┛ │ +//! a b p │ +//! │ legend: +//! (c,d) = b.fork(2, 1) │ ------- +//! │ ↥ : next byte to be yielded by p +//! 0 1 2 3 4 5 6 7 8 9 M │ p +//! ┏━┯━┯━┯━┯━┯━┯━┯━┯━┯━┯━━━┯━┓ │ │P│: byte to be yielded by p +//! ┃A│A│A│A│C│D│B│B│P│P│...│P┃ │ │╳│: byte already yielded +//! ┗↥┷━┷━┷━┷↥┷↥┷↥┷━┷↥┷━┷━━━┷━┛ │ +//! a c d b p 🭭 +//! ``` +//! +//! This makes it possible to consume the stream at different places. This is particularly useful in +//! a multithreaded setting, in which we want to use the same generator from different independent +//! threads: +//! +//! ```ascii +//! 0 1 2 3 4 5 6 7 8 9 M │ +//! ┏━┯━┯━┯━┯━┯━┯━┯━┯━┯━┯━━━┯━┓ │ +//! ┃A│A│A│A│C│D│B│B│P│P│...│P┃ │ +//! ┗↥┷━┷━┷━┷↥┷↥┷↥┷━┷↥┷━┷━━━┷━┛ │ +//! a c d b p │ +//! │ +//! a.next_byte() │ +//! │ +//! 0 1 2 3 4 5 6 7 8 9 M │ +//! ┏━┯━┯━┯━┯━┯━┯━┯━┯━┯━┯━━━┯━┓ │ +//! ┃╳│A│A│A│C│D│B│B│P│P│...│P┃ │ +//! ┗━┷↥┷━┷━┷↥┷↥┷↥┷━┷↥┷━┷━━━┷━┛ │ +//! a c d b p │ +//! │ legend: +//! b.next_byte() │ ------- +//! │ ↥ : next byte to be yielded by p +//! 0 1 2 3 4 5 6 7 8 9 M │ p +//! ┏━┯━┯━┯━┯━┯━┯━┯━┯━┯━┯━━━┯━┓ │ │P│: byte to be yielded by p +//! ┃╳│A│A│A│C│D│╳│B│P│P│...│P┃ │ │╳│: byte already yielded +//! ┗━┷↥┷━┷━┷↥┷↥┷━┷↥┷↥┷━┷━━━┷━┛ │ +//! a c d b p 🭭 +//! ``` +//! +//! Implementation +//! ============== +//! //! The implementation is based on the AES blockcipher used in counter (CTR) mode, as presented //! in the ISO/IEC 18033-4 document. - -#[cfg(feature = "multithread")] -use rayon::prelude::*; -use std::fmt::{Debug, Display, Formatter}; - -mod aesni; -mod counter; -mod software; -use crate::counter::{ - AesKey, ByteCount, BytesPerChild, ChildrenCount, ForkError, HardAesCtrGenerator, - SoftAesCtrGenerator, -}; -pub use software::set_soft_rdseed_secret; - -/// The pseudorandom number generator. 
-/// -/// If the correct instructions sets are available on the machine, an hardware accelerated version -/// of the generator can be used. -#[allow(clippy::large_enum_variant)] -#[derive(Clone)] -pub enum RandomGenerator { - #[doc(hidden)] - Software(SoftAesCtrGenerator), - #[doc(hidden)] - Hardware(HardAesCtrGenerator), -} - -impl RandomGenerator { - /// Builds a new random generator, selecting the hardware implementation if available. - /// Optionally, a seed can be provided. - /// - /// # Note - /// - /// If using the `slow` feature, this function will return the non-accelerated variant, even - /// though the right instructions are available. - pub fn new(seed: Option) -> RandomGenerator { - if cfg!(feature = "slow") { - return RandomGenerator::new_software(seed); - } - RandomGenerator::new_hardware(seed).unwrap_or_else(|| RandomGenerator::new_software(seed)) - } - - /// Builds a new software random generator, optionally seeding it with a given value. - pub fn new_software(seed: Option) -> RandomGenerator { - RandomGenerator::Software(SoftAesCtrGenerator::new(seed.map(AesKey), None, None)) - } - - /// Tries to build a new hardware random generator, optionally seeding it with a given value. - pub fn new_hardware(seed: Option) -> Option { - if !is_x86_feature_detected!("aes") - || !is_x86_feature_detected!("rdseed") - || !is_x86_feature_detected!("sse2") - { - return None; - } - Some(RandomGenerator::Hardware(HardAesCtrGenerator::new( - seed.map(AesKey), - None, - None, - ))) - } - - /// Yields the next byte from the generator. - pub fn generate_next(&mut self) -> u8 { - match self { - Self::Hardware(ref mut rand) => rand.generate_next(), - Self::Software(ref mut rand) => rand.generate_next(), - } - } - - pub fn is_bounded(&self) -> bool { - match self { - Self::Hardware(ref rand) => rand.is_bounded(), - Self::Software(ref rand) => rand.is_bounded(), - } - } - - /// Returns the number of remaining bytes, if the generator is bounded. - pub fn remaining_bytes(&self) -> ByteCount { - match self { - Self::Hardware(rand) => rand.remaining_bytes(), - Self::Software(rand) => rand.remaining_bytes(), - } - } - - /// Tries to fork the current generator into `n_child` generators each able to yield - /// `child_bytes` random bytes. - /// - /// If the total number of bytes to be generated exceeds the bound of the current generator, - /// `None` is returned. Otherwise, we return an iterator over the children generators. - pub fn try_fork( - &mut self, - n_child: usize, - child_bytes: usize, - ) -> Result, ForkError> { - enum GeneratorChildIter - where - HardIter: Iterator, - SoftIter: Iterator, - { - Hardware(HardIter), - Software(SoftIter), - } - - impl Iterator for GeneratorChildIter - where - HardIter: Iterator, - SoftIter: Iterator, - { - type Item = RandomGenerator; - - fn next(&mut self) -> Option { - match self { - GeneratorChildIter::Hardware(ref mut iter) => { - iter.next().map(RandomGenerator::Hardware) - } - GeneratorChildIter::Software(ref mut iter) => { - iter.next().map(RandomGenerator::Software) - } - } - } - } - match self { - Self::Hardware(ref mut rand) => rand - .try_fork(ChildrenCount(n_child), BytesPerChild(child_bytes)) - .map(GeneratorChildIter::Hardware), - Self::Software(ref mut rand) => rand - .try_fork(ChildrenCount(n_child), BytesPerChild(child_bytes)) - .map(GeneratorChildIter::Software), - } - } - - /// Tries to fork the current generator into `n_child` generators each able to yield - /// `child_bytes` random bytes as a parallel iterator. 
- /// - /// If the total number of bytes to be generated exceeds the bound of the current generator, - /// `None` is returned. Otherwise, we return a parallel iterator over the children generators. - /// - /// # Notes - /// - /// This method necessitates the "multithread" feature. - #[cfg(feature = "multithread")] - pub fn par_try_fork( - &mut self, - n_child: usize, - child_bytes: usize, - ) -> Result, ForkError> { - use rayon::iter::plumbing::{Consumer, ProducerCallback, UnindexedConsumer}; - enum GeneratorChildIter - where - HardIter: IndexedParallelIterator + Send + Sync, - SoftIter: IndexedParallelIterator + Send + Sync, - { - Hardware(HardIter), - Software(SoftIter), - } - impl ParallelIterator for GeneratorChildIter - where - HardIter: IndexedParallelIterator + Send + Sync, - SoftIter: IndexedParallelIterator + Send + Sync, - { - type Item = RandomGenerator; - fn drive_unindexed(self, consumer: C) -> >::Result - where - C: UnindexedConsumer, - { - match self { - Self::Hardware(iter) => iter - .map(RandomGenerator::Hardware) - .drive_unindexed(consumer), - Self::Software(iter) => iter - .map(RandomGenerator::Software) - .drive_unindexed(consumer), - } - } - } - impl IndexedParallelIterator for GeneratorChildIter - where - HardIter: IndexedParallelIterator + Send + Sync, - SoftIter: IndexedParallelIterator + Send + Sync, - { - fn len(&self) -> usize { - match self { - Self::Software(iter) => iter.len(), - Self::Hardware(iter) => iter.len(), - } - } - fn drive>( - self, - consumer: C, - ) -> >::Result { - match self { - Self::Software(iter) => iter.map(RandomGenerator::Software).drive(consumer), - Self::Hardware(iter) => iter.map(RandomGenerator::Hardware).drive(consumer), - } - } - fn with_producer>( - self, - callback: CB, - ) -> >::Output { - match self { - Self::Software(iter) => { - iter.map(RandomGenerator::Software).with_producer(callback) - } - Self::Hardware(iter) => { - iter.map(RandomGenerator::Hardware).with_producer(callback) - } - } - } - } - - match self { - Self::Hardware(ref mut rand) => rand - .par_try_fork(ChildrenCount(n_child), BytesPerChild(child_bytes)) - .map(GeneratorChildIter::Hardware), - Self::Software(ref mut rand) => rand - .par_try_fork(ChildrenCount(n_child), BytesPerChild(child_bytes)) - .map(GeneratorChildIter::Software), - } - } -} - -impl Debug for RandomGenerator { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "RandomGenerator") - } -} - -impl Display for RandomGenerator { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "RandomGenerator") - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_uniformity() { - // Checks that the PRNG generates uniform numbers - let precision = 10f64.powi(-4); - let n_samples = 10_000_000_usize; - let mut generator = RandomGenerator::new(None); - let mut counts = [0usize; 256]; - let expected_prob: f64 = 1. / 256.; - for _ in 0..n_samples { - counts[generator.generate_next() as usize] += 1; - } - counts - .iter() - .map(|a| (*a as f64) / (n_samples as f64)) - .for_each(|a| assert!((a - expected_prob) < precision)) - } - - #[test] - fn test_generator_determinism() { - // Checks that given a state and a key, the PRNG is determinist. 
- for _ in 0..100 { - let key = software::dev_random(); - let mut first_generator = RandomGenerator::new(Some(key)); - let mut second_generator = RandomGenerator::new(Some(key)); - for _ in 0..128 { - assert_eq!( - first_generator.generate_next(), - second_generator.generate_next() - ); - } - } - } - - #[test] - fn test_fork() { - // Checks that forks returns a bounded child, and that the proper number of bytes can be - // generated. - let mut gen = RandomGenerator::new(None); - let mut bounded = gen.try_fork(1, 10).unwrap().next().unwrap(); - assert!(bounded.is_bounded()); - assert!(!gen.is_bounded()); - for _ in 0..10 { - bounded.generate_next(); - } - } - - #[test] - #[should_panic] - fn test_bounded_panic() { - // Checks that a bounded prng panics when exceeding the allowed number of bytes. - let mut gen = RandomGenerator::new(None); - let mut bounded = gen.try_fork(1, 10).unwrap().next().unwrap(); - assert!(bounded.is_bounded()); - assert!(!gen.is_bounded()); - for _ in 0..11 { - bounded.generate_next(); - } - } -} +pub mod generators; +pub mod seeders; diff --git a/concrete-csprng/src/generate_random.rs b/concrete-csprng/src/main.rs similarity index 54% rename from concrete-csprng/src/generate_random.rs rename to concrete-csprng/src/main.rs index e911ae8965..37aae4b949 100644 --- a/concrete-csprng/src/generate_random.rs +++ b/concrete-csprng/src/main.rs @@ -1,18 +1,20 @@ //! This program uses the concrete csprng to generate an infinite stream of random bytes on //! the program stdout. For testing purpose. +use concrete_csprng::generators::{AesniRandomGenerator, RandomGenerator}; +use concrete_csprng::seeders::{RdseedSeeder, Seeder}; use std::io::prelude::*; use std::io::stdout; -use concrete_csprng::RandomGenerator; - pub fn main() { - let mut generator = RandomGenerator::new(None); + let mut seeder = RdseedSeeder; + let mut generator = AesniRandomGenerator::new(seeder.seed()); let mut stdout = stdout(); let mut buffer = [0u8; 16]; loop { buffer .iter_mut() - .for_each(|a| *a = generator.generate_next()); + .zip(&mut generator) + .for_each(|(b, g)| *b = g); stdout.write_all(&buffer).unwrap(); } } diff --git a/concrete-csprng/src/seeders/implem/linux.rs b/concrete-csprng/src/seeders/implem/linux.rs new file mode 100644 index 0000000000..8408a989c6 --- /dev/null +++ b/concrete-csprng/src/seeders/implem/linux.rs @@ -0,0 +1,59 @@ +use crate::seeders::{Seed, Seeder}; +use std::fs::File; +use std::io::Read; + +/// A seeder which uses the linux `/dev/random` source. +pub struct LinuxSeeder { + counter: u128, + secret: u128, + file: File, +} + +impl LinuxSeeder { + /// Creates a new seeder from a user defined secret. + /// + /// Important: + /// ---------- + /// + /// This secret is used to ensure the quality of the seed in scenarios where `/dev/random` may + /// be compromised. 
+    pub fn new(secret: u128) -> LinuxSeeder {
+        let file = std::fs::File::open("/dev/random").expect("Failed to open /dev/random.");
+        let counter = std::time::UNIX_EPOCH
+            .elapsed()
+            .expect("Failed to initialize software rdseed counter.")
+            .as_nanos();
+        LinuxSeeder {
+            secret,
+            counter,
+            file,
+        }
+    }
+}
+
+impl Seeder for LinuxSeeder {
+    fn seed(&mut self) -> Seed {
+        let output = self.secret ^ self.counter ^ dev_random(&mut self.file);
+        self.counter = self.counter.wrapping_add(1);
+        Seed(output)
+    }
+}
+
+fn dev_random(random: &mut File) -> u128 {
+    let mut buf = [0u8; 16];
+    random
+        .read_exact(&mut buf[..])
+        .expect("Failed to read from /dev/random.");
+    u128::from_ne_bytes(buf)
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::seeders::generic_tests::check_seeder_fixed_sequences_different;
+
+    #[test]
+    fn check_bounded_sequence_difference() {
+        check_seeder_fixed_sequences_different(LinuxSeeder::new);
+    }
+}
diff --git a/concrete-csprng/src/seeders/implem/mod.rs b/concrete-csprng/src/seeders/implem/mod.rs
new file mode 100644
index 0000000000..2b3174648e
--- /dev/null
+++ b/concrete-csprng/src/seeders/implem/mod.rs
@@ -0,0 +1,9 @@
+#[cfg(feature = "seeder_x86_64_rdseed")]
+mod rdseed;
+#[cfg(feature = "seeder_x86_64_rdseed")]
+pub use rdseed::RdseedSeeder;
+
+#[cfg(feature = "seeder_linux")]
+mod linux;
+#[cfg(feature = "seeder_linux")]
+pub use linux::LinuxSeeder;
diff --git a/concrete-csprng/src/seeders/implem/rdseed.rs b/concrete-csprng/src/seeders/implem/rdseed.rs
new file mode 100644
index 0000000000..6c317a3ac9
--- /dev/null
+++ b/concrete-csprng/src/seeders/implem/rdseed.rs
@@ -0,0 +1,46 @@
+use crate::seeders::{Seed, Seeder};
+
+/// A seeder which uses the `rdseed` x86_64 instruction.
+///
+/// The `rdseed` instruction allows seeds to be delivered from a hardware source of entropy; see
+/// .
+pub struct RdseedSeeder;
+
+impl Seeder for RdseedSeeder {
+    fn seed(&mut self) -> Seed {
+        Seed(rdseed_random_m128())
+    }
+}
+
+// Generates a random 128-bit value from rdseed.
+fn rdseed_random_m128() -> u128 {
+    let mut rand1: u64 = 0;
+    let mut rand2: u64 = 0;
+    let mut output_bytes = [0u8; 16];
+    unsafe {
+        loop {
+            if core::arch::x86_64::_rdseed64_step(&mut rand1) == 1 {
+                break;
+            }
+        }
+        loop {
+            if core::arch::x86_64::_rdseed64_step(&mut rand2) == 1 {
+                break;
+            }
+        }
+    }
+    output_bytes[0..8].copy_from_slice(&rand1.to_ne_bytes());
+    output_bytes[8..16].copy_from_slice(&rand2.to_ne_bytes());
+    u128::from_ne_bytes(output_bytes)
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::seeders::generic_tests::check_seeder_fixed_sequences_different;
+
+    #[test]
+    fn check_bounded_sequence_difference() {
+        check_seeder_fixed_sequences_different(|_| RdseedSeeder);
+    }
+}
diff --git a/concrete-csprng/src/seeders/mod.rs b/concrete-csprng/src/seeders/mod.rs
new file mode 100644
index 0000000000..bcd4f0e79c
--- /dev/null
+++ b/concrete-csprng/src/seeders/mod.rs
@@ -0,0 +1,43 @@
+//! A module containing seeder objects.
+//!
+//! When initializing a generator, one needs to provide a [`Seed`], which is then used as the key
+//! to the AES blockcipher. As a consequence, the quality of the outputs of the generator depends
+//! directly on the quality of this seed. This module provides different mechanisms to deliver
+//! seeds that accommodate varying scenarios.
+
+/// A seed value, used to initialize a generator.
+#[derive(Debug, Copy, Clone)]
+pub struct Seed(pub u128);
+
+/// A trait representing a seeding strategy.
+pub trait Seeder {
+    /// Generates a new seed.
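+    ///
+    /// Illustrative sketch of how the returned seed is typically consumed (assumes the
+    /// `seeder_x86_64_rdseed` and `generator_x86_64_aesni` features are enabled):
+    ///
+    /// ```ignore
+    /// use concrete_csprng::generators::{AesniRandomGenerator, RandomGenerator};
+    /// use concrete_csprng::seeders::{RdseedSeeder, Seeder};
+    ///
+    /// // `RdseedSeeder` and `AesniRandomGenerator` require their respective cargo features.
+    /// let mut seeder = RdseedSeeder;
+    /// // The seed is used as the key of the AES blockcipher backing the generator.
+    /// let mut generator = AesniRandomGenerator::new(seeder.seed());
+    /// let _byte = generator.next();
+    /// ```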
+    fn seed(&mut self) -> Seed;
+}
+
+mod implem;
+pub use implem::*;
+
+#[cfg(test)]
+mod generic_tests {
+    use crate::seeders::Seeder;
+
+    /// Naively verifies that two fixed-size sequences generated by repeatedly calling the seeder
+    /// are different.
+    #[allow(unused)] // to please clippy when tests are not activated
+    pub fn check_seeder_fixed_sequences_different<S: Seeder, F: Fn(u128) -> S>(
+        construct_seeder: F,
+    ) {
+        const SEQUENCE_SIZE: usize = 500;
+        const REPEATS: usize = 10_000;
+        let mut sequence = [0u128; SEQUENCE_SIZE * 2];
+        for i in 0..REPEATS {
+            let mut seeder = construct_seeder(i as u128);
+            sequence.iter_mut().for_each(|v| *v = seeder.seed().0);
+            assert_ne!(
+                sequence[0..SEQUENCE_SIZE],
+                sequence[SEQUENCE_SIZE..SEQUENCE_SIZE * 2]
+            );
+        }
+    }
+}
diff --git a/concrete-csprng/src/software.rs b/concrete-csprng/src/software.rs
deleted file mode 100644
index 72c8103557..0000000000
--- a/concrete-csprng/src/software.rs
+++ /dev/null
@@ -1,193 +0,0 @@
-//! A module using a software fallback implementation of `aes128-counter` random number generator.
-use crate::counter::{AesBatchedGenerator, AesIndex, AesKey};
-use aes_soft::cipher::generic_array::GenericArray;
-use aes_soft::cipher::{BlockCipher, NewBlockCipher};
-use aes_soft::Aes128;
-use std::cell::UnsafeCell;
-use std::io::Read;
-
-thread_local! {
-    static RDSEED_COUNTER: UnsafeCell = UnsafeCell::new(
-        std::time::UNIX_EPOCH
-            .elapsed()
-            .expect("Failed to initialized software rdseed counter.")
-            .as_nanos()
-    );
-    static RDSEED_SECRET: UnsafeCell = UnsafeCell::new(0);
-    static RDSEED_SEEDED: UnsafeCell = UnsafeCell::new(false);
-}
-
-/// Sets the secret used to seed the software version of the prng.
-///
-/// When using the software variant of the CSPRNG, we do not have access to the (trusted)
-/// hardware source of randomness to seed the generator. Instead, we use a value from
-/// `/dev/random`, which can be easy to temper with. To mitigate this risk, the user can provide
-/// a secret value that is included in the seed of the prng. Note that to ensure maximal
-/// security, this value should be different each time a new application using concrete is started.
-pub fn set_soft_rdseed_secret(secret: u128) {
-    RDSEED_SECRET.with(|f| {
-        let _secret = unsafe { &mut *{ f.get() } };
-        *_secret = secret;
-    });
-    RDSEED_SEEDED.with(|f| {
-        let _seeded = unsafe { &mut *{ f.get() } };
-        *_seeded = true;
-    })
-}
-
-fn rdseed() -> u128 {
-    RDSEED_SEEDED.with(|f| {
-        let is_seeded = unsafe { &*{ f.get() } };
-        if !*is_seeded {
-            println!(
-                "WARNING: You are currently using the software variant of concrete-csprng \
-                 which does not have access to a hardware source of randomness. To ensure the \
-                 security of your application, please arrange to provide a secret by using the \
-                 `concrete_csprng::set_soft_rdseed_secret` function."
-            );
-        }
-    });
-    let mut output: u128 = 0;
-    RDSEED_SECRET.with(|f| {
-        let secret = unsafe { &*{ f.get() } };
-        RDSEED_COUNTER.with(|f| {
-            let counter = unsafe { &mut *{ f.get() } };
-            output = *secret ^ *counter ^ dev_random();
-            *counter = counter.wrapping_add(1);
-        })
-    });
-    output
-}
-
-#[derive(Clone)]
-pub struct Generator {
-    // Aes structure
-    aes: Aes128,
-}
-
-impl AesBatchedGenerator for Generator {
-    fn new(key: Option) -> Generator {
-        let key: [u8; 16] = key.map(|AesKey(k)| k).unwrap_or_else(rdseed).to_ne_bytes();
-        let key = GenericArray::clone_from_slice(&key[..]);
-        let aes = Aes128::new(&key);
-        Generator { aes }
-    }
-
-    fn generate_batch(&mut self, AesIndex(aes_ctr): AesIndex) -> [u8; 128] {
-        aes_encrypt_many(
-            aes_ctr,
-            aes_ctr + 1,
-            aes_ctr + 2,
-            aes_ctr + 3,
-            aes_ctr + 4,
-            aes_ctr + 5,
-            aes_ctr + 6,
-            aes_ctr + 7,
-            &self.aes,
-        )
-    }
-}
-
-pub fn dev_random() -> u128 {
-    let mut random = std::fs::File::open("/dev/random").expect("Failed to open /dev/random .");
-    let mut buf = [0u8; 16];
-    random
-        .read_exact(&mut buf[..])
-        .expect("Failed to read from /dev/random .");
-    u128::from_ne_bytes(buf)
-}
-
-// Uses aes to encrypt many values at once. This allows a substantial speedup (around 30%)
-// compared to the naive approach.
-#[allow(clippy::too_many_arguments)]
-fn aes_encrypt_many(
-    message_1: u128,
-    message_2: u128,
-    message_3: u128,
-    message_4: u128,
-    message_5: u128,
-    message_6: u128,
-    message_7: u128,
-    message_8: u128,
-    cipher: &Aes128,
-) -> [u8; 128] {
-    let mut b1 = GenericArray::clone_from_slice(&message_1.to_ne_bytes()[..]);
-    let mut b2 = GenericArray::clone_from_slice(&message_2.to_ne_bytes()[..]);
-    let mut b3 = GenericArray::clone_from_slice(&message_3.to_ne_bytes()[..]);
-    let mut b4 = GenericArray::clone_from_slice(&message_4.to_ne_bytes()[..]);
-    let mut b5 = GenericArray::clone_from_slice(&message_5.to_ne_bytes()[..]);
-    let mut b6 = GenericArray::clone_from_slice(&message_6.to_ne_bytes()[..]);
-    let mut b7 = GenericArray::clone_from_slice(&message_7.to_ne_bytes()[..]);
-    let mut b8 = GenericArray::clone_from_slice(&message_8.to_ne_bytes()[..]);
-
-    cipher.encrypt_block(&mut b1);
-    cipher.encrypt_block(&mut b2);
-    cipher.encrypt_block(&mut b3);
-    cipher.encrypt_block(&mut b4);
-    cipher.encrypt_block(&mut b5);
-    cipher.encrypt_block(&mut b6);
-    cipher.encrypt_block(&mut b7);
-    cipher.encrypt_block(&mut b8);
-
-    let output_array: [[u8; 16]; 8] = [
-        b1.into(),
-        b2.into(),
-        b3.into(),
-        b4.into(),
-        b5.into(),
-        b6.into(),
-        b7.into(),
-        b8.into(),
-    ];
-
-    unsafe { *{ output_array.as_ptr() as *const [u8; 128] } }
-}
-
-#[cfg(test)]
-mod test {
-    use super::*;
-    use std::convert::TryInto;
-
-    // Test vector for aes128, from the FIPS publication 197
-    const CIPHER_KEY: u128 = u128::from_be(0x000102030405060708090a0b0c0d0e0f);
-    const PLAINTEXT: u128 = u128::from_be(0x00112233445566778899aabbccddeeff);
-    const CIPHERTEXT: u128 = u128::from_be(0x69c4e0d86a7b0430d8cdb78070b4c55a);
-
-    #[test]
-    fn test_encrypt_many_messages() {
-        // Checks that encrypting many plaintext at the same time gives the correct output.
-        let key: [u8; 16] = CIPHER_KEY.to_ne_bytes();
-        let aes = Aes128::new(&GenericArray::from(key));
-        let ciphertexts = aes_encrypt_many(
-            PLAINTEXT, PLAINTEXT, PLAINTEXT, PLAINTEXT, PLAINTEXT, PLAINTEXT, PLAINTEXT, PLAINTEXT,
-            &aes,
-        );
-        let ciphertexts: [u8; 128] = ciphertexts[..].try_into().unwrap();
-        for i in 0..8 {
-            assert_eq!(
-                u128::from_ne_bytes(ciphertexts[16 * i..16 * (i + 1)].try_into().unwrap()),
-                CIPHERTEXT
-            );
-        }
-    }
-
-    #[test]
-    fn test_uniformity() {
-        // Checks that the PRNG generates uniform numbers
-        let precision = 10f64.powi(-4);
-        let n_samples = 1_000_000_usize;
-        let mut generator = Generator::new(None);
-        let mut counts = [0usize; 256];
-        let expected_prob: f64 = 1. / 256.;
-        for counter in 0..n_samples {
-            let batch = generator.generate_batch(AesIndex(counter as u128));
-            for i in 0..128 {
-                counts[batch[i] as usize] += 1;
-            }
-        }
-        counts
-            .iter()
-            .map(|a| (*a as f64) / ((n_samples * 128) as f64))
-            .for_each(|a| assert!((a - expected_prob) < precision))
-    }
-}
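The removed `software.rs` relied on a user-provided secret, registered through `set_soft_rdseed_secret`, to strengthen its seeding. Under the new layout, that idea is expressed by implementing the `Seeder` trait directly. The sketch below is illustrative only: `FixedSecretSeeder` is a hypothetical type that is not part of this patch, and `AesniRandomGenerator` stands in for whichever generator backend is enabled.

```rust
use concrete_csprng::generators::{AesniRandomGenerator, RandomGenerator};
use concrete_csprng::seeders::{Seed, Seeder};

// Hypothetical seeder mirroring the role of the removed `set_soft_rdseed_secret` mechanism.
// For illustration only: a real seeder should also mix in an entropy source, as `LinuxSeeder`
// does with `/dev/random`.
struct FixedSecretSeeder {
    secret: u128,
    counter: u128,
}

impl Seeder for FixedSecretSeeder {
    fn seed(&mut self) -> Seed {
        // Combine the secret with a counter so that successive seeds differ.
        self.counter = self.counter.wrapping_add(1);
        Seed(self.secret ^ self.counter)
    }
}

fn main() {
    let mut seeder = FixedSecretSeeder { secret: 0, counter: 0 };
    let mut generator = AesniRandomGenerator::new(seeder.seed());
    let _byte = generator.next();
}
```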
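Similarly, the bounded forking behaviour exercised by the removed `test_fork` and `test_bounded_panic` tests is reached through `try_fork`, which now takes explicit `ChildrenCount` and `BytesPerChild` arguments. A minimal sketch, assuming both types are exported from `concrete_csprng::generators` alongside the generator:

```rust
use concrete_csprng::generators::{
    AesniRandomGenerator, BytesPerChild, ChildrenCount, RandomGenerator,
};
use concrete_csprng::seeders::{RdseedSeeder, Seeder};

fn main() {
    let mut seeder = RdseedSeeder;
    let mut parent = AesniRandomGenerator::new(seeder.seed());
    // Fork one child allowed to produce at most 10 bytes, as in the removed `test_fork`.
    let mut child = parent
        .try_fork(ChildrenCount(1), BytesPerChild(10))
        .unwrap()
        .next()
        .unwrap();
    for _ in 0..10 {
        child.next();
    }
}
```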