From fa372e9f5008350bd61289340019ce2222425a7c Mon Sep 17 00:00:00 2001 From: Frank Laub Date: Wed, 13 Nov 2024 11:19:22 -0800 Subject: [PATCH] Use risc0-bigint2 (#3) * Use risc0-bigint2 * Use num-bigint-dig feature * Update lockfile * Update ref * Update ref * Update git ref --- Cargo.lock | 16 +++++++++++++ Cargo.toml | 3 +++ src/algorithms/rsa.rs | 56 +++---------------------------------------- 3 files changed, 22 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c20e26fa..e7e53f7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -207,6 +207,12 @@ dependencies = [ "digest", ] +[[package]] +name = "include_bytes_aligned" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee796ad498c8d9a1d68e477df8f754ed784ef875de1414ebdaf169f70a6a784" + [[package]] name = "inout" version = "0.1.3" @@ -463,6 +469,15 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "risc0-bigint2" +version = "1.2.0-alpha.1" +source = "git+https://github.com/risc0/risc0?rev=8fc8437633f08a66e0fbacce947f41d01b074774#8fc8437633f08a66e0fbacce947f41d01b074774" +dependencies = [ + "include_bytes_aligned", + "num-bigint-dig", +] + [[package]] name = "rsa" version = "0.9.6" @@ -481,6 +496,7 @@ dependencies = [ "rand_chacha", "rand_core", "rand_xorshift", + "risc0-bigint2", "serde", "serde_test", "sha1", diff --git a/Cargo.toml b/Cargo.toml index e4b22655..67ed6d2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,9 @@ sha1 = { version = "0.10.5", optional = true, default-features = false, features sha2 = { version = "0.10.6", optional = true, default-features = false, features = ["oid"] } serde = { version = "1.0.184", optional = true, default-features = false, features = ["derive"] } +[target.'cfg(target_os = "zkvm")'.dependencies] +risc0-bigint2 = { git = "https://github.com/risc0/risc0", rev = "8fc8437633f08a66e0fbacce947f41d01b074774", default-features = false, features = ["num-bigint-dig"] } + [dev-dependencies] base64ct = { version = "1", features = ["alloc"] } hex-literal = "0.4.1" diff --git a/src/algorithms/rsa.rs b/src/algorithms/rsa.rs index dbe4941c..e2305220 100644 --- a/src/algorithms/rsa.rs +++ b/src/algorithms/rsa.rs @@ -11,56 +11,6 @@ use zeroize::{Zeroize, Zeroizing}; use crate::errors::{Error, Result}; use crate::traits::{PrivateKeyParts, PublicKeyParts}; -// The number of 32-bit words per element in the risc0 RSA syscalls -// Must match risc0_zkvm_platform::syscall::rsa::WIDTH_WORDS -#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] -const WIDTH_WORDS: usize = 96; - -#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] -extern "C" { - fn modpow_65537( - recv_buf: *mut [u32; WIDTH_WORDS], - in_base: *const [u32; WIDTH_WORDS], - in_modulus: *const [u32; WIDTH_WORDS], - ); -} - -#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] -/// Provides acceleration for ModPow (exponent 65537) in the RISC Zero zkVM -/// -/// Note that to use this, a dependency on `risc0-circuit-bigint` must be added -/// to the RISC Zero zkVM guest code calling this even if it is not otherwise -/// necessary. -fn risc0_modpow_65537(base: &BigUint, modulus: &BigUint) -> BigUint { - // Ensure inputs fill an even number of words - let mut base = base.to_bytes_le(); - if base.len() % 4 != 0 { - base.resize(base.len() + (4 - (base.len() % 4)), 0); - } - let mut modulus = modulus.to_bytes_le(); - if modulus.len() % 4 != 0 { - modulus.resize(modulus.len() + (4 - (modulus.len() % 4)), 0); - } - let base: [u32; WIDTH_WORDS] = base - .chunks(4) - .map(|word| u32::from_le_bytes(word.try_into().unwrap())) - .collect::>() - .try_into() - .unwrap(); - let modulus: [u32; WIDTH_WORDS] = modulus - .chunks(4) - .map(|word| u32::from_le_bytes(word.try_into().unwrap())) - .collect::>() - .try_into() - .unwrap(); - let mut result = [0u32; WIDTH_WORDS]; - // Safety: Parameters are dereferenceable & aligned - unsafe { - modpow_65537(&mut result, &base, &modulus); - } - return BigUint::from_slice(&result); -} - /// ⚠️ Raw RSA encryption of m with the public key. No padding is performed. /// /// # ☢️️ WARNING: HAZARDOUS API ☢️ @@ -69,11 +19,11 @@ fn risc0_modpow_65537(base: &BigUint, modulus: &BigUint) -> BigUint { /// or signature scheme. See the [module-level documentation][crate::hazmat] for more information. #[inline] pub fn rsa_encrypt(key: &K, m: &BigUint) -> Result { - #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + #[cfg(target_os = "zkvm")] { - // If we're in the RISC Zero zkVM, try to use its RSA accelerator circuit + // If we're in the RISC Zero zkVM, try to use an accelerated version. if *key.e() == BigUint::new(vec![65537]) { - return Ok(risc0_modpow_65537(m, key.n())); + return Ok(risc0_bigint2::rsa::modpow_65537(m, key.n())); } // Fall through when the exponent does not match the accelerator }