diff --git a/Cargo.lock b/Cargo.lock index 6301484..44516b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -48,6 +48,16 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloy-rlp" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f542548a609dca89fcd72b3b9f355928cf844d4363c5eed9c5273a3dd225e097" +dependencies = [ + "arrayvec", + "bytes", +] + [[package]] name = "anes" version = "0.1.6" @@ -103,6 +113,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "anyhow" +version = "1.0.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7" + [[package]] name = "ark-bn254" version = "0.4.0" @@ -110,8 +126,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a22f4561524cd949590d78d7d4c5df8f592430d221f7f3c9497bbafd8972120f" dependencies = [ "ark-ec", - "ark-ff", - "ark-std", + "ark-ff 0.4.1", + "ark-std 0.4.0", ] [[package]] @@ -123,12 +139,12 @@ dependencies = [ "ark-bn254", "ark-crypto-primitives", "ark-ec", - "ark-ff", + "ark-ff 0.4.1", "ark-groth16", "ark-poly", "ark-relations", - "ark-serialize", - "ark-std", + "ark-serialize 0.4.1", + "ark-std 0.4.0", "byteorder", "cfg-if", "color-eyre", @@ -150,14 +166,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3a13b34da09176a8baba701233fdffbaa7c1b1192ce031a3da4e55ce1f1a56" dependencies = [ "ark-ec", - "ark-ff", + "ark-ff 0.4.1", "ark-relations", - "ark-serialize", + "ark-serialize 0.4.1", "ark-snark", - "ark-std", + "ark-std 0.4.0", "blake2", "derivative", - "digest", + "digest 0.10.6", "rayon", "sha2", ] @@ -168,10 +184,10 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c60370a92f8e1a5f053cad73a862e1b99bc642333cd676fa11c0c39f80f4ac2" dependencies = [ - "ark-ff", + "ark-ff 0.4.1", "ark-poly", - "ark-serialize", - "ark-std", + "ark-serialize 0.4.1", + "ark-std 0.4.0", "derivative", "hashbrown 0.13.2", "itertools", @@ -180,27 +196,55 @@ dependencies = [ "zeroize", ] +[[package]] +name = "ark-ff" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b3235cc41ee7a12aaaf2c575a2ad7b46713a8a50bda2fc3b003a04845c05dd6" +dependencies = [ + "ark-ff-asm 0.3.0", + "ark-ff-macros 0.3.0", + "ark-serialize 0.3.0", + "ark-std 0.3.0", + "derivative", + "num-bigint", + "num-traits", + "paste", + "rustc_version 0.3.3", + "zeroize", +] + [[package]] name = "ark-ff" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c2d42532524bee1da5a4f6f733eb4907301baa480829557adcff5dfaeee1d9a" dependencies = [ - "ark-ff-asm", - "ark-ff-macros", - "ark-serialize", - "ark-std", + "ark-ff-asm 0.4.2", + "ark-ff-macros 0.4.2", + "ark-serialize 0.4.1", + "ark-std 0.4.0", "derivative", - "digest", + "digest 0.10.6", "itertools", "num-bigint", "num-traits", "paste", "rayon", - "rustc_version", + "rustc_version 0.4.0", "zeroize", ] +[[package]] +name = "ark-ff-asm" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db02d390bf6643fb404d3d22d31aee1c4bc4459600aef9113833d17e786c6e44" +dependencies = [ + "quote", + "syn 1.0.109", +] + [[package]] name = "ark-ff-asm" version = "0.4.2" @@ -211,6 +255,18 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ark-ff-macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fd794a08ccb318058009eefdf15bcaaaaf6f8161eb3345f907222bac38b20" +dependencies = [ + "num-bigint", + "num-traits", + "quote", + "syn 1.0.109", +] + [[package]] name = "ark-ff-macros" version = "0.4.2" @@ -232,11 +288,11 @@ checksum = "20ceafa83848c3e390f1cbf124bc3193b3e639b3f02009e0e290809a501b95fc" dependencies = [ "ark-crypto-primitives", "ark-ec", - "ark-ff", + "ark-ff 0.4.1", "ark-poly", "ark-relations", - "ark-serialize", - "ark-std", + "ark-serialize 0.4.1", + "ark-std 0.4.0", "rayon", ] @@ -246,9 +302,9 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f6ec811462cabe265cfe1b102fcfe3df79d7d2929c2425673648ee9abfd0272" dependencies = [ - "ark-ff", - "ark-serialize", - "ark-std", + "ark-ff 0.4.1", + "ark-serialize 0.4.1", + "ark-std 0.4.0", "derivative", "hashbrown 0.13.2", "rayon", @@ -260,12 +316,22 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00796b6efc05a3f48225e59cb6a2cda78881e7c390872d5786aaf112f31fb4f0" dependencies = [ - "ark-ff", - "ark-std", + "ark-ff 0.4.1", + "ark-std 0.4.0", "tracing", "tracing-subscriber 0.2.25", ] +[[package]] +name = "ark-serialize" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d6c2b318ee6e10f8c2853e73a83adc0ccb88995aa978d8a3408d492ab2ee671" +dependencies = [ + "ark-std 0.3.0", + "digest 0.9.0", +] + [[package]] name = "ark-serialize" version = "0.4.1" @@ -273,8 +339,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e7e735959bc173ea4baf13327b19c22d452b8e9e8e8f7b7fc34e6bf0e316c33e" dependencies = [ "ark-serialize-derive", - "ark-std", - "digest", + "ark-std 0.4.0", + "digest 0.10.6", "num-bigint", ] @@ -295,10 +361,20 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84d3cc6833a335bb8a600241889ead68ee89a3cf8448081fb7694c0fe503da63" dependencies = [ - "ark-ff", + "ark-ff 0.4.1", "ark-relations", - "ark-serialize", - "ark-std", + "ark-serialize 0.4.1", + "ark-std 0.4.0", +] + +[[package]] +name = "ark-std" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1df2c09229cbc5a028b1d70e00fdb2acee28b1055dfb5ca73eea49c5a25c4e7c" +dependencies = [ + "num-traits", + "rand", ] [[package]] @@ -321,10 +397,10 @@ dependencies = [ "ark-bn254", "ark-circom", "ark-ec", - "ark-ff", + "ark-ff 0.4.1", "ark-groth16", "ark-relations", - "ark-serialize", + "ark-serialize 0.4.1", "color-eyre", "flame", "flamer", @@ -422,7 +498,7 @@ version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" dependencies = [ - "digest", + "digest 0.10.6", ] [[package]] @@ -981,6 +1057,15 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array", +] + [[package]] name = "digest" version = "0.10.6" @@ -1009,7 +1094,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0997c976637b606099b9985693efa3581e84e41f5c11ba5255f88711058ad428" dependencies = [ "der", - "digest", + "digest 0.10.6", "elliptic-curve", "rfc6979", "signature", @@ -1030,7 +1115,7 @@ checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" dependencies = [ "base16ct", "crypto-bigint", - "digest", + "digest 0.10.6", "ff", "generic-array", "group", @@ -1206,6 +1291,17 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +[[package]] +name = "fastrlp" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139834ddba373bbdd213dffe02c8d110508dcf1726c2be27e8d1f7d7e1856418" +dependencies = [ + "arrayvec", + "auto_impl", + "bytes", +] + [[package]] name = "ff" version = "0.13.0" @@ -1422,7 +1518,7 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" dependencies = [ - "digest", + "digest 0.10.6", ] [[package]] @@ -1965,6 +2061,17 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +[[package]] +name = "pest" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879952a81a83930934cbf1786752d6dedc3b1f29e8f8fb2ad1d0a36f377cf442" +dependencies = [ + "memchr", + "thiserror", + "ucd-trie", +] + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -2082,6 +2189,45 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14cae93065090804185d3b75f0bf93b8eeda30c7a9b4a33d3bdb3988d6229e50" +dependencies = [ + "bitflags 2.4.1", + "lazy_static 1.4.0", + "num-traits", + "rand", + "rand_chacha", + "rand_xorshift", + "regex-syntax 0.8.5", + "unarray", +] + +[[package]] +name = "prost" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c0fef6c4230e4ccf618a35c59d7ede15dea37de8427500f50aff708806e42ec" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn 2.0.16", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -2147,6 +2293,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_xorshift" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" +dependencies = [ + "rand_core", +] + [[package]] name = "rayon" version = "1.7.0" @@ -2212,7 +2367,7 @@ checksum = "af83e617f331cc6ae2da5443c602dfa5af81e517212d9d611a5b3ba1777b5370" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.7.1", ] [[package]] @@ -2221,6 +2376,12 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c" +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + [[package]] name = "region" version = "3.0.0" @@ -2287,12 +2448,13 @@ dependencies = [ "ark-bn254", "ark-circom", "ark-ec", - "ark-ff", + "ark-ff 0.4.1", "ark-groth16", "ark-relations", - "ark-serialize", - "ark-std", + "ark-serialize 0.4.1", + "ark-std 0.4.0", "ark-zkey", + "byteorder", "cfg-if", "color-eyre", "criterion 0.4.0", @@ -2301,8 +2463,10 @@ dependencies = [ "num-bigint", "num-traits", "once_cell", + "prost", "rand", "rand_chacha", + "ruint", "serde", "serde_json", "sled", @@ -2365,6 +2529,35 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ruint" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95294d6e3a6192f3aabf91c38f56505a625aa495533442744185a36d75a790c4" +dependencies = [ + "alloy-rlp", + "ark-ff 0.3.0", + "ark-ff 0.4.1", + "bytes", + "fastrlp", + "num-bigint", + "parity-scale-codec", + "primitive-types", + "proptest", + "rand", + "rlp", + "ruint-macro", + "serde", + "valuable", + "zeroize", +] + +[[package]] +name = "ruint-macro" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48fd7bd8a6377e15ad9d42a8ec25371b94ddc67abe7c8b9127bec79bebaaae18" + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -2383,13 +2576,22 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e75f6a532d0fd9f7f13144f392b6ad56a32696bfcd9c78f797f16bbb6f072d6" +[[package]] +name = "rustc_version" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0dfe2087c51c460008730de8b57e6a320782fbfb312e1f4d520e6c6fae155ee" +dependencies = [ + "semver 0.11.0", +] + [[package]] name = "rustc_version" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" dependencies = [ - "semver", + "semver 1.0.17", ] [[package]] @@ -2496,12 +2698,30 @@ dependencies = [ "zeroize", ] +[[package]] +name = "semver" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6" +dependencies = [ + "semver-parser", +] + [[package]] name = "semver" version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" +[[package]] +name = "semver-parser" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9900206b54a3527fdc7b8a938bffd94a568bac4f4aa8113b209df75a09c0dec2" +dependencies = [ + "pest", +] + [[package]] name = "serde" version = "1.0.163" @@ -2571,7 +2791,7 @@ checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" dependencies = [ "cfg-if", "cpufeatures", - "digest", + "digest 0.10.6", ] [[package]] @@ -2580,7 +2800,7 @@ version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" dependencies = [ - "digest", + "digest 0.10.6", "keccak", ] @@ -2599,7 +2819,7 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e1788eed21689f9cf370582dfc467ef36ed9c707f073528ddafa8d83e3b8500" dependencies = [ - "digest", + "digest 0.10.6", "rand_core", ] @@ -2922,6 +3142,12 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + [[package]] name = "uint" version = "0.9.5" @@ -2934,6 +3160,12 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicode-ident" version = "1.0.8" @@ -3562,7 +3794,7 @@ name = "zerokit_utils" version = "0.5.1" dependencies = [ "ark-bn254", - "ark-ff", + "ark-ff 0.4.1", "color-eyre", "criterion 0.4.0", "hex", diff --git a/README.md b/README.md index 1b74561..b6941f8 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,8 @@ in Rust. - [semaphore-rs](https://github.com/worldcoin/semaphore-rs) written in Rust based on ark-circom. +- The circom witness calculation code of the rln crate is based on [circom-witnesscalc](https://github.com/iden3/circom-witnesscalc) by iden3. The execution graph file used by this code has been generated by means of the same iden3 software. + ## Users Zerokit is used by - diff --git a/rln-cli/Cargo.toml b/rln-cli/Cargo.toml index 86f168e..9d1db53 100644 --- a/rln-cli/Cargo.toml +++ b/rln-cli/Cargo.toml @@ -11,3 +11,6 @@ color-eyre = "=0.6.2" # serialization serde_json = "1.0.48" serde = { version = "1.0.130", features = ["derive"] } + +[features] +arkzkey = [] diff --git a/rln-cli/src/main.rs b/rln-cli/src/main.rs index c8931ff..dd7563e 100644 --- a/rln-cli/src/main.rs +++ b/rln-cli/src/main.rs @@ -27,7 +27,7 @@ fn main() -> Result<()> { tree_height, config, }) => { - let resources = File::open(&config)?; + let resources = File::open(config)?; state.rln = Some(RLN::new(*tree_height, resources)?); Ok(()) } @@ -38,9 +38,9 @@ fn main() -> Result<()> { }) => { let mut resources: Vec> = Vec::new(); #[cfg(feature = "arkzkey")] - let filenames = ["rln.wasm", "rln_final.arkzkey", "verification_key.arkvkey"]; + let filenames = ["rln_final.arkzkey", "verification_key.arkvkey"]; #[cfg(not(feature = "arkzkey"))] - let filenames = ["rln.wasm", "rln_final.zkey", "verification_key.arkvkey"]; + let filenames = ["rln_final.zkey", "verification_key.arkvkey"]; for filename in filenames { let fullpath = config.join(Path::new(filename)); let mut file = File::open(&fullpath)?; @@ -49,12 +49,11 @@ fn main() -> Result<()> { file.read_exact(&mut buffer)?; resources.push(buffer); } - let tree_config_input_file = File::open(&tree_config_input)?; + let tree_config_input_file = File::open(tree_config_input)?; state.rln = Some(RLN::new_with_params( *tree_height, resources[0].clone(), resources[1].clone(), - resources[2].clone(), tree_config_input_file, )?); Ok(()) @@ -67,7 +66,7 @@ fn main() -> Result<()> { Ok(()) } Some(Commands::SetLeaf { index, file }) => { - let input_data = File::open(&file)?; + let input_data = File::open(file)?; state .rln .ok_or(Report::msg("no RLN instance initialized"))? @@ -75,7 +74,7 @@ fn main() -> Result<()> { Ok(()) } Some(Commands::SetMultipleLeaves { index, file }) => { - let input_data = File::open(&file)?; + let input_data = File::open(file)?; state .rln .ok_or(Report::msg("no RLN instance initialized"))? @@ -83,7 +82,7 @@ fn main() -> Result<()> { Ok(()) } Some(Commands::ResetMultipleLeaves { file }) => { - let input_data = File::open(&file)?; + let input_data = File::open(file)?; state .rln .ok_or(Report::msg("no RLN instance initialized"))? @@ -91,7 +90,7 @@ fn main() -> Result<()> { Ok(()) } Some(Commands::SetNextLeaf { file }) => { - let input_data = File::open(&file)?; + let input_data = File::open(file)?; state .rln .ok_or(Report::msg("no RLN instance initialized"))? @@ -122,7 +121,7 @@ fn main() -> Result<()> { Ok(()) } Some(Commands::Prove { input }) => { - let input_data = File::open(&input)?; + let input_data = File::open(input)?; let writer = std::io::stdout(); state .rln @@ -131,7 +130,7 @@ fn main() -> Result<()> { Ok(()) } Some(Commands::Verify { file }) => { - let input_data = File::open(&file)?; + let input_data = File::open(file)?; state .rln .ok_or(Report::msg("no RLN instance initialized"))? @@ -139,7 +138,7 @@ fn main() -> Result<()> { Ok(()) } Some(Commands::GenerateProof { input }) => { - let input_data = File::open(&input)?; + let input_data = File::open(input)?; let writer = std::io::stdout(); state .rln @@ -148,8 +147,8 @@ fn main() -> Result<()> { Ok(()) } Some(Commands::VerifyWithRoots { input, roots }) => { - let input_data = File::open(&input)?; - let roots_data = File::open(&roots)?; + let input_data = File::open(input)?; + let roots_data = File::open(roots)?; state .rln .ok_or(Report::msg("no RLN instance initialized"))? diff --git a/rln/Cargo.toml b/rln/Cargo.toml index c24461f..6288aca 100644 --- a/rln/Cargo.toml +++ b/rln/Cargo.toml @@ -42,6 +42,7 @@ color-eyre = "=0.6.2" thiserror = "=1.0.39" # utilities +byteorder = "1.4.3" cfg-if = "=1.0" num-bigint = { version = "=0.4.3", default-features = false, features = [ "rand", @@ -51,11 +52,13 @@ once_cell = "=1.17.1" lazy_static = "=1.4.0" rand = "=0.8.5" rand_chacha = "=0.3.1" +ruint = { version = "1.10.0", features = ["rand", "serde", "ark-ff-04", "num-bigint"] } tiny-keccak = { version = "=2.0.2", features = ["keccak"] } utils = { package = "zerokit_utils", version = "=0.5.1", path = "../utils/", default-features = false } # serialization +prost = "0.13.1" serde_json = "=1.0.96" serde = { version = "=1.0.163", features = ["derive"] } diff --git a/rln/Makefile.toml b/rln/Makefile.toml index 9c7aba1..e353c14 100644 --- a/rln/Makefile.toml +++ b/rln/Makefile.toml @@ -4,7 +4,7 @@ args = ["build", "--release"] [tasks.test_default] command = "cargo" -args = ["test", "--release"] +args = ["test", "--release", "--", "--nocapture"] [tasks.test_stateless] command = "cargo" diff --git a/rln/README.md b/rln/README.md index ae38ee4..77aaf37 100644 --- a/rln/README.md +++ b/rln/README.md @@ -12,14 +12,18 @@ git clone https://github.com/vacp2p/zerokit.git cd zerokit/rln ``` -### Build and Test + ### Build and Test -To build and test, run the following commands within the module folder + To build and test, run the following commands within the module folder -```bash -cargo make build -cargo make test +``` bash + cargo make build + cargo make test_{mode} ``` +The {mode} placeholder should be replaced with +* **default** for the default tests; +* **arkzkey** for the tests with the arkzkey feature; +* **stateless** for the tests with the stateless feature. ### Compile ZK circuits diff --git a/rln/resources/tree_height_20/graph.bin b/rln/resources/tree_height_20/graph.bin new file mode 100644 index 0000000..1e335b7 Binary files /dev/null and b/rln/resources/tree_height_20/graph.bin differ diff --git a/rln/src/circuit.rs b/rln/src/circuit.rs index 9cc2d7f..caf7065 100644 --- a/rln/src/circuit.rs +++ b/rln/src/circuit.rs @@ -1,5 +1,6 @@ // This crate provides interfaces for the zero-knowledge circuit and keys +use crate::iden3calc::calc_witness; use ark_bn254::{ Bn254, Fq as ArkFq, Fq2 as ArkFq2, Fr as ArkFr, G1Affine as ArkG1Affine, G1Projective as ArkG1Projective, G2Affine as ArkG2Affine, G2Projective as ArkG2Projective, @@ -9,14 +10,10 @@ use ark_relations::r1cs::ConstraintMatrices; use ark_serialize::CanonicalDeserialize; use cfg_if::cfg_if; use color_eyre::{Report, Result}; +use num_bigint::BigInt; #[cfg(not(target_arch = "wasm32"))] -use { - ark_circom::WitnessCalculator, - lazy_static::lazy_static, - std::sync::{Arc, Mutex}, - wasmer::{Module, Store}, -}; +use ::lazy_static::lazy_static; #[cfg(feature = "arkzkey")] use { @@ -35,7 +32,7 @@ pub const ARKZKEY_BYTES_UNCOMPR: &[u8] = pub const ZKEY_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/rln_final.zkey"); pub const VK_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/verification_key.arkvkey"); -const WASM_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/rln.wasm"); +const GRAPH_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/graph.bin"); #[cfg(not(target_arch = "wasm32"))] lazy_static! { @@ -53,11 +50,6 @@ lazy_static! { #[cfg(not(target_arch = "wasm32"))] static ref VK: VerifyingKey = vk_from_ark_serialized(VK_BYTES).expect("Failed to read vk"); - - #[cfg(not(target_arch = "wasm32"))] - static ref WITNESS_CALCULATOR: Arc> = { - circom_from_raw(WASM_BYTES).expect("Failed to create witness calculator") - }; } pub const TEST_TREE_HEIGHT: usize = 20; @@ -92,6 +84,10 @@ pub fn zkey_from_raw(zkey_data: &[u8]) -> Result<(ProvingKey, ConstraintM Ok(proving_key_and_matrices) } +pub fn calculate_rln_witness)>>(inputs: I) -> Vec { + calc_witness(inputs, GRAPH_BYTES) +} + // Loads the proving key #[cfg(not(target_arch = "wasm32"))] pub fn zkey_from_folder() -> &'static (ProvingKey, ConstraintMatrices) { @@ -118,20 +114,6 @@ pub fn vk_from_folder() -> &'static VerifyingKey { &VK } -// Initializes the witness calculator using a bytes vector -#[cfg(not(target_arch = "wasm32"))] -pub fn circom_from_raw(wasm_buffer: &[u8]) -> Result>> { - let module = Module::new(&Store::default(), wasm_buffer)?; - let result = WitnessCalculator::from_module(module)?; - Ok(Arc::new(Mutex::new(result))) -} - -// Initializes the witness calculator -#[cfg(not(target_arch = "wasm32"))] -pub fn circom_from_folder() -> &'static Arc> { - &WITNESS_CALCULATOR -} - // Computes the verification key from a bytes vector containing pre-processed ark-serialized verification key // uncompressed, unchecked pub fn vk_from_ark_serialized(data: &[u8]) -> Result> { diff --git a/rln/src/ffi.rs b/rln/src/ffi.rs index 0d2cd70..da1660a 100644 --- a/rln/src/ffi.rs +++ b/rln/src/ffi.rs @@ -228,7 +228,6 @@ pub extern "C" fn new(ctx: *mut *mut RLN) -> bool { #[no_mangle] pub extern "C" fn new_with_params( tree_height: usize, - circom_buffer: *const Buffer, zkey_buffer: *const Buffer, vk_buffer: *const Buffer, tree_config: *const Buffer, @@ -236,7 +235,6 @@ pub extern "C" fn new_with_params( ) -> bool { match RLN::new_with_params( tree_height, - circom_buffer.process().to_vec(), zkey_buffer.process().to_vec(), vk_buffer.process().to_vec(), tree_config.process(), @@ -256,16 +254,11 @@ pub extern "C" fn new_with_params( #[cfg(feature = "stateless")] #[no_mangle] pub extern "C" fn new_with_params( - circom_buffer: *const Buffer, zkey_buffer: *const Buffer, vk_buffer: *const Buffer, ctx: *mut *mut RLN, ) -> bool { - match RLN::new_with_params( - circom_buffer.process().to_vec(), - zkey_buffer.process().to_vec(), - vk_buffer.process().to_vec(), - ) { + match RLN::new_with_params(zkey_buffer.process().to_vec(), vk_buffer.process().to_vec()) { Ok(rln) => { unsafe { *ctx = Box::into_raw(Box::new(rln)) }; true diff --git a/rln/src/iden3calc.rs b/rln/src/iden3calc.rs new file mode 100644 index 0000000..056582d --- /dev/null +++ b/rln/src/iden3calc.rs @@ -0,0 +1,73 @@ +// This file is based on the code by iden3. Its preimage can be found here: +// https://github.com/iden3/circom-witnesscalc/blob/5cb365b6e4d9052ecc69d4567fcf5bc061c20e94/src/lib.rs + +pub mod graph; +pub mod proto; +pub mod storage; + +use ark_bn254::Fr; +use graph::Node; +use num_bigint::BigInt; +use ruint::aliases::U256; +use std::collections::HashMap; +use storage::deserialize_witnesscalc_graph; + +pub type InputSignalsInfo = HashMap; + +pub fn calc_witness)>>( + inputs: I, + graph_data: &[u8], +) -> Vec { + let inputs: HashMap> = inputs + .into_iter() + .map(|(key, value)| (key, value.iter().map(|v| U256::from(v)).collect())) + .collect(); + + let (nodes, signals, input_mapping): (Vec, Vec, InputSignalsInfo) = + deserialize_witnesscalc_graph(std::io::Cursor::new(graph_data)).unwrap(); + + let mut inputs_buffer = get_inputs_buffer(get_inputs_size(&nodes)); + populate_inputs(&inputs, &input_mapping, &mut inputs_buffer); + + graph::evaluate(&nodes, inputs_buffer.as_slice(), &signals) +} + +fn get_inputs_size(nodes: &[Node]) -> usize { + let mut start = false; + let mut max_index = 0usize; + for &node in nodes.iter() { + if let Node::Input(i) = node { + if i > max_index { + max_index = i; + } + start = true + } else if start { + break; + } + } + max_index + 1 +} + +fn populate_inputs( + input_list: &HashMap>, + inputs_info: &InputSignalsInfo, + input_buffer: &mut [U256], +) { + for (key, value) in input_list { + let (offset, len) = inputs_info[key]; + if len != value.len() { + panic!("Invalid input length for {}", key); + } + + for (i, v) in value.iter().enumerate() { + input_buffer[offset + i] = *v; + } + } +} + +/// Allocates inputs vec with position 0 set to 1 +fn get_inputs_buffer(size: usize) -> Vec { + let mut inputs = vec![U256::ZERO; size]; + inputs[0] = U256::from(1); + inputs +} diff --git a/rln/src/iden3calc/graph.rs b/rln/src/iden3calc/graph.rs new file mode 100644 index 0000000..ab66b8e --- /dev/null +++ b/rln/src/iden3calc/graph.rs @@ -0,0 +1,947 @@ +// This file is based on the code by iden3. Its preimage can be found here: +// https://github.com/iden3/circom-witnesscalc/blob/5cb365b6e4d9052ecc69d4567fcf5bc061c20e94/src/graph.rs + +use crate::iden3calc::proto; +use ark_bn254::Fr; +use ark_ff::{BigInt, BigInteger, One, PrimeField, Zero}; +use rand::Rng; +use ruint::aliases::U256; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::error::Error; +use std::ops::{BitOr, BitXor, Deref}; +use std::{ + collections::HashMap, + ops::{BitAnd, Shl, Shr}, +}; + +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; +use ruint::uint; + +pub const M: U256 = + uint!(21888242871839275222246405745257275088548364400416034343698204186575808495617_U256); + +fn ark_se(a: &A, s: S) -> Result +where + S: serde::Serializer, +{ + let mut bytes = vec![]; + a.serialize_with_mode(&mut bytes, Compress::Yes) + .map_err(serde::ser::Error::custom)?; + s.serialize_bytes(&bytes) +} + +fn ark_de<'de, D, A: CanonicalDeserialize>(data: D) -> Result +where + D: serde::de::Deserializer<'de>, +{ + let s: Vec = serde::de::Deserialize::deserialize(data)?; + let a = A::deserialize_with_mode(s.as_slice(), Compress::Yes, Validate::Yes); + a.map_err(serde::de::Error::custom) +} + +#[derive(Hash, PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize)] +pub enum Operation { + Mul, + Div, + Add, + Sub, + Pow, + Idiv, + Mod, + Eq, + Neq, + Lt, + Gt, + Leq, + Geq, + Land, + Lor, + Shl, + Shr, + Bor, + Band, + Bxor, +} + +impl Operation { + // TODO: rewrite to &U256 type + pub fn eval(&self, a: U256, b: U256) -> U256 { + use Operation::*; + match self { + Mul => a.mul_mod(b, M), + Div => { + if b == U256::ZERO { + // as we are simulating a circuit execution with signals + // values all equal to 0, just return 0 here in case of + // division by zero + U256::ZERO + } else { + a.mul_mod(b.inv_mod(M).unwrap(), M) + } + } + Add => a.add_mod(b, M), + Sub => a.add_mod(M - b, M), + Pow => a.pow_mod(b, M), + Mod => a.div_rem(b).1, + Eq => U256::from(a == b), + Neq => U256::from(a != b), + Lt => u_lt(&a, &b), + Gt => u_gt(&a, &b), + Leq => u_lte(&a, &b), + Geq => u_gte(&a, &b), + Land => U256::from(a != U256::ZERO && b != U256::ZERO), + Lor => U256::from(a != U256::ZERO || b != U256::ZERO), + Shl => compute_shl_uint(a, b), + Shr => compute_shr_uint(a, b), + // TODO test with conner case when it is possible to get the number + // bigger then modulus + Bor => a.bitor(b), + Band => a.bitand(b), + // TODO test with conner case when it is possible to get the number + // bigger then modulus + Bxor => a.bitxor(b), + Idiv => a / b, + } + } + + pub fn eval_fr(&self, a: Fr, b: Fr) -> Fr { + use Operation::*; + match self { + Mul => a * b, + // We always should return something on the circuit execution. + // So in case of division by 0 we would return 0. And the proof + // should be invalid in the end. + Div => { + if b.is_zero() { + Fr::zero() + } else { + a / b + } + } + Add => a + b, + Sub => a - b, + Idiv => { + if b.is_zero() { + Fr::zero() + } else { + Fr::new((Into::::into(a) / Into::::into(b)).into()) + } + } + Mod => { + if b.is_zero() { + Fr::zero() + } else { + Fr::new((Into::::into(a) % Into::::into(b)).into()) + } + } + Eq => match a.cmp(&b) { + Ordering::Equal => Fr::one(), + _ => Fr::zero(), + }, + Neq => match a.cmp(&b) { + Ordering::Equal => Fr::zero(), + _ => Fr::one(), + }, + Lt => Fr::new(u_lt(&a.into(), &b.into()).into()), + Gt => Fr::new(u_gt(&a.into(), &b.into()).into()), + Leq => Fr::new(u_lte(&a.into(), &b.into()).into()), + Geq => Fr::new(u_gte(&a.into(), &b.into()).into()), + Land => { + if a.is_zero() || b.is_zero() { + Fr::zero() + } else { + Fr::one() + } + } + Lor => { + if a.is_zero() && b.is_zero() { + Fr::zero() + } else { + Fr::one() + } + } + Shl => shl(a, b), + Shr => shr(a, b), + Bor => bit_or(a, b), + Band => bit_and(a, b), + Bxor => bit_xor(a, b), + // TODO implement other operators + _ => unimplemented!("operator {:?} not implemented for Montgomery", self), + } + } +} + +impl From<&Operation> for proto::DuoOp { + fn from(v: &Operation) -> Self { + match v { + Operation::Mul => proto::DuoOp::Mul, + Operation::Div => proto::DuoOp::Div, + Operation::Add => proto::DuoOp::Add, + Operation::Sub => proto::DuoOp::Sub, + Operation::Pow => proto::DuoOp::Pow, + Operation::Idiv => proto::DuoOp::Idiv, + Operation::Mod => proto::DuoOp::Mod, + Operation::Eq => proto::DuoOp::Eq, + Operation::Neq => proto::DuoOp::Neq, + Operation::Lt => proto::DuoOp::Lt, + Operation::Gt => proto::DuoOp::Gt, + Operation::Leq => proto::DuoOp::Leq, + Operation::Geq => proto::DuoOp::Geq, + Operation::Land => proto::DuoOp::Land, + Operation::Lor => proto::DuoOp::Lor, + Operation::Shl => proto::DuoOp::Shl, + Operation::Shr => proto::DuoOp::Shr, + Operation::Bor => proto::DuoOp::Bor, + Operation::Band => proto::DuoOp::Band, + Operation::Bxor => proto::DuoOp::Bxor, + } + } +} + +#[derive(Hash, PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize)] +pub enum UnoOperation { + Neg, + Id, // identity - just return self +} + +impl UnoOperation { + pub fn eval(&self, a: U256) -> U256 { + match self { + UnoOperation::Neg => { + if a == U256::ZERO { + U256::ZERO + } else { + M - a + } + } + UnoOperation::Id => a, + } + } + + pub fn eval_fr(&self, a: Fr) -> Fr { + match self { + UnoOperation::Neg => { + if a.is_zero() { + Fr::zero() + } else { + let mut x = Fr::MODULUS; + x.sub_with_borrow(&a.into_bigint()); + Fr::from_bigint(x).unwrap() + } + } + _ => unimplemented!("uno operator {:?} not implemented for Montgomery", self), + } + } +} + +impl From<&UnoOperation> for proto::UnoOp { + fn from(v: &UnoOperation) -> Self { + match v { + UnoOperation::Neg => proto::UnoOp::Neg, + UnoOperation::Id => proto::UnoOp::Id, + } + } +} + +#[derive(Hash, PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize)] +pub enum TresOperation { + TernCond, +} + +impl TresOperation { + pub fn eval(&self, a: U256, b: U256, c: U256) -> U256 { + match self { + TresOperation::TernCond => { + if a == U256::ZERO { + c + } else { + b + } + } + } + } + + pub fn eval_fr(&self, a: Fr, b: Fr, c: Fr) -> Fr { + match self { + TresOperation::TernCond => { + if a.is_zero() { + c + } else { + b + } + } + } + } +} + +impl From<&TresOperation> for proto::TresOp { + fn from(v: &TresOperation) -> Self { + match v { + TresOperation::TernCond => proto::TresOp::TernCond, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum Node { + Input(usize), + Constant(U256), + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + MontConstant(Fr), + UnoOp(UnoOperation, usize), + Op(Operation, usize, usize), + TresOp(TresOperation, usize, usize, usize), +} + +// TODO remove pub from Vec +#[derive(Default)] +pub struct Nodes(pub Vec); + +impl Nodes { + pub fn new() -> Self { + Nodes(Vec::new()) + } + + pub fn to_const(&self, idx: NodeIdx) -> Result { + let me = self.0.get(idx.0).ok_or(NodeConstErr::EmptyNode(idx))?; + match me { + Node::Constant(v) => Ok(*v), + Node::UnoOp(op, a) => Ok(op.eval(self.to_const(NodeIdx(*a))?)), + Node::Op(op, a, b) => { + Ok(op.eval(self.to_const(NodeIdx(*a))?, self.to_const(NodeIdx(*b))?)) + } + Node::TresOp(op, a, b, c) => Ok(op.eval( + self.to_const(NodeIdx(*a))?, + self.to_const(NodeIdx(*b))?, + self.to_const(NodeIdx(*c))?, + )), + Node::Input(_) => Err(NodeConstErr::InputSignal), + Node::MontConstant(_) => { + panic!("MontConstant should not be used here") + } + } + } + + pub fn push(&mut self, n: Node) -> NodeIdx { + self.0.push(n); + NodeIdx(self.0.len() - 1) + } + + pub fn get(&self, idx: NodeIdx) -> Option<&Node> { + self.0.get(idx.0) + } +} + +impl Deref for Nodes { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Copy, Clone)] +pub struct NodeIdx(pub usize); + +impl From for NodeIdx { + fn from(v: usize) -> Self { + NodeIdx(v) + } +} + +#[derive(Debug)] +pub enum NodeConstErr { + EmptyNode(NodeIdx), + InputSignal, +} + +impl std::fmt::Display for NodeConstErr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NodeConstErr::EmptyNode(idx) => { + write!(f, "empty node at index {}", idx.0) + } + NodeConstErr::InputSignal => { + write!(f, "input signal is not a constant") + } + } + } +} + +impl Error for NodeConstErr {} + +fn compute_shl_uint(a: U256, b: U256) -> U256 { + debug_assert!(b.lt(&U256::from(256))); + let ls_limb = b.as_limbs()[0]; + a.shl(ls_limb as usize) +} + +fn compute_shr_uint(a: U256, b: U256) -> U256 { + debug_assert!(b.lt(&U256::from(256))); + let ls_limb = b.as_limbs()[0]; + a.shr(ls_limb as usize) +} + +/// All references must be backwards. +fn assert_valid(nodes: &[Node]) { + for (i, &node) in nodes.iter().enumerate() { + if let Node::Op(_, a, b) = node { + assert!(a < i); + assert!(b < i); + } else if let Node::UnoOp(_, a) = node { + assert!(a < i); + } else if let Node::TresOp(_, a, b, c) = node { + assert!(a < i); + assert!(b < i); + assert!(c < i); + } + } +} + +pub fn optimize(nodes: &mut Vec, outputs: &mut [usize]) { + tree_shake(nodes, outputs); + propagate(nodes); + value_numbering(nodes, outputs); + constants(nodes); + tree_shake(nodes, outputs); + montgomery_form(nodes); +} + +pub fn evaluate(nodes: &[Node], inputs: &[U256], outputs: &[usize]) -> Vec { + // assert_valid(nodes); + + // Evaluate the graph. + let mut values = Vec::with_capacity(nodes.len()); + for &node in nodes.iter() { + let value = match node { + Node::Constant(c) => Fr::new(c.into()), + Node::MontConstant(c) => c, + Node::Input(i) => Fr::new(inputs[i].into()), + Node::Op(op, a, b) => op.eval_fr(values[a], values[b]), + Node::UnoOp(op, a) => op.eval_fr(values[a]), + Node::TresOp(op, a, b, c) => op.eval_fr(values[a], values[b], values[c]), + }; + values.push(value); + } + + // Convert from Montgomery form and return the outputs. + let mut out = vec![Fr::from(0); outputs.len()]; + for i in 0..outputs.len() { + out[i] = values[outputs[i]]; + } + + out +} + +/// Constant propagation +pub fn propagate(nodes: &mut [Node]) { + assert_valid(nodes); + for i in 0..nodes.len() { + if let Node::Op(op, a, b) = nodes[i] { + if let (Node::Constant(va), Node::Constant(vb)) = (nodes[a], nodes[b]) { + nodes[i] = Node::Constant(op.eval(va, vb)); + } else if a == b { + // Not constant but equal + use Operation::*; + if let Some(c) = match op { + Eq | Leq | Geq => Some(true), + Neq | Lt | Gt => Some(false), + _ => None, + } { + nodes[i] = Node::Constant(U256::from(c)); + } + } + } else if let Node::UnoOp(op, a) = nodes[i] { + if let Node::Constant(va) = nodes[a] { + nodes[i] = Node::Constant(op.eval(va)); + } + } else if let Node::TresOp(op, a, b, c) = nodes[i] { + if let (Node::Constant(va), Node::Constant(vb), Node::Constant(vc)) = + (nodes[a], nodes[b], nodes[c]) + { + nodes[i] = Node::Constant(op.eval(va, vb, vc)); + } + } + } +} + +/// Remove unused nodes +pub fn tree_shake(nodes: &mut Vec, outputs: &mut [usize]) { + assert_valid(nodes); + + // Mark all nodes that are used. + let mut used = vec![false; nodes.len()]; + for &i in outputs.iter() { + used[i] = true; + } + + // Work backwards from end as all references are backwards. + for i in (0..nodes.len()).rev() { + if used[i] { + if let Node::Op(_, a, b) = nodes[i] { + used[a] = true; + used[b] = true; + } + if let Node::UnoOp(_, a) = nodes[i] { + used[a] = true; + } + if let Node::TresOp(_, a, b, c) = nodes[i] { + used[a] = true; + used[b] = true; + used[c] = true; + } + } + } + + // Remove unused nodes + let n = nodes.len(); + let mut retain = used.iter(); + nodes.retain(|_| *retain.next().unwrap()); + + // Renumber references. + let mut renumber = vec![None; n]; + let mut index = 0; + for (i, &used) in used.iter().enumerate() { + if used { + renumber[i] = Some(index); + index += 1; + } + } + assert_eq!(index, nodes.len()); + for (&used, renumber) in used.iter().zip(renumber.iter()) { + assert_eq!(used, renumber.is_some()); + } + + // Renumber references. + for node in nodes.iter_mut() { + if let Node::Op(_, a, b) = node { + *a = renumber[*a].unwrap(); + *b = renumber[*b].unwrap(); + } + if let Node::UnoOp(_, a) = node { + *a = renumber[*a].unwrap(); + } + if let Node::TresOp(_, a, b, c) = node { + *a = renumber[*a].unwrap(); + *b = renumber[*b].unwrap(); + *c = renumber[*c].unwrap(); + } + } + for output in outputs.iter_mut() { + *output = renumber[*output].unwrap(); + } +} + +/// Randomly evaluate the graph +fn random_eval(nodes: &mut [Node]) -> Vec { + let mut rng = rand::thread_rng(); + let mut values = Vec::with_capacity(nodes.len()); + let mut inputs = HashMap::new(); + let mut prfs = HashMap::new(); + let mut prfs_uno = HashMap::new(); + let mut prfs_tres = HashMap::new(); + for node in nodes.iter() { + use Operation::*; + let value = match node { + // Constants evaluate to themselves + Node::Constant(c) => *c, + + Node::MontConstant(_) => unimplemented!("should not be used"), + + // Algebraic Ops are evaluated directly + // Since the field is large, by Swartz-Zippel if + // two values are the same then they are likely algebraically equal. + Node::Op(op @ (Add | Sub | Mul), a, b) => op.eval(values[*a], values[*b]), + + // Input and non-algebraic ops are random functions + // TODO: https://github.com/recmo/uint/issues/95 and use .gen_range(..M) + Node::Input(i) => *inputs.entry(*i).or_insert_with(|| rng.gen::() % M), + Node::Op(op, a, b) => *prfs + .entry((*op, values[*a], values[*b])) + .or_insert_with(|| rng.gen::() % M), + Node::UnoOp(op, a) => *prfs_uno + .entry((*op, values[*a])) + .or_insert_with(|| rng.gen::() % M), + Node::TresOp(op, a, b, c) => *prfs_tres + .entry((*op, values[*a], values[*b], values[*c])) + .or_insert_with(|| rng.gen::() % M), + }; + values.push(value); + } + values +} + +/// Value numbering +pub fn value_numbering(nodes: &mut [Node], outputs: &mut [usize]) { + assert_valid(nodes); + + // Evaluate the graph in random field elements. + let values = random_eval(nodes); + + // Find all nodes with the same value. + let mut value_map = HashMap::new(); + for (i, &value) in values.iter().enumerate() { + value_map.entry(value).or_insert_with(Vec::new).push(i); + } + + // For nodes that are the same, pick the first index. + let renumber: Vec<_> = values.into_iter().map(|v| value_map[&v][0]).collect(); + + // Renumber references. + for node in nodes.iter_mut() { + if let Node::Op(_, a, b) = node { + *a = renumber[*a]; + *b = renumber[*b]; + } + if let Node::UnoOp(_, a) = node { + *a = renumber[*a]; + } + if let Node::TresOp(_, a, b, c) = node { + *a = renumber[*a]; + *b = renumber[*b]; + *c = renumber[*c]; + } + } + for output in outputs.iter_mut() { + *output = renumber[*output]; + } +} + +/// Probabilistic constant determination +pub fn constants(nodes: &mut [Node]) { + assert_valid(nodes); + + // Evaluate the graph in random field elements. + let values_a = random_eval(nodes); + let values_b = random_eval(nodes); + + // Find all nodes with the same value. + for i in 0..nodes.len() { + if let Node::Constant(_) = nodes[i] { + continue; + } + if values_a[i] == values_b[i] { + nodes[i] = Node::Constant(values_a[i]); + } + } +} + +/// Convert to Montgomery form +pub fn montgomery_form(nodes: &mut [Node]) { + for node in nodes.iter_mut() { + use Node::*; + use Operation::*; + match node { + Constant(c) => *node = MontConstant(Fr::new((*c).into())), + MontConstant(..) => (), + Input(..) => (), + Op( + Mul | Div | Add | Sub | Idiv | Mod | Eq | Neq | Lt | Gt | Leq | Geq | Land | Lor + | Shl | Shr | Bor | Band | Bxor, + .., + ) => (), + Op(op @ Pow, ..) => unimplemented!("Operators Montgomery form: {:?}", op), + UnoOp(UnoOperation::Neg, ..) => (), + UnoOp(op, ..) => unimplemented!("Uno Operators Montgomery form: {:?}", op), + TresOp(TresOperation::TernCond, ..) => (), + } + } +} + +fn shl(a: Fr, b: Fr) -> Fr { + if b.is_zero() { + return a; + } + + if b.cmp(&Fr::from(Fr::MODULUS_BIT_SIZE)).is_ge() { + return Fr::zero(); + } + + let n = b.into_bigint().0[0] as u32; + + let mut a = a.into_bigint(); + a.muln(n); + Fr::from_bigint(a).unwrap() +} + +fn shr(a: Fr, b: Fr) -> Fr { + if b.is_zero() { + return a; + } + + match b.cmp(&Fr::from(254u64)) { + Ordering::Equal => return Fr::zero(), + Ordering::Greater => return Fr::zero(), + _ => (), + }; + + let mut n = b.into_bigint().to_bytes_le()[0]; + let mut result = a.into_bigint(); + let c = result.as_mut(); + while n >= 64 { + for i in 0..3 { + c[i as usize] = c[(i + 1) as usize]; + } + c[3] = 0; + n -= 64; + } + + if n == 0 { + return Fr::from_bigint(result).unwrap(); + } + + let mask: u64 = (1 << n) - 1; + let mut carrier: u64 = c[3] & mask; + c[3] >>= n; + for i in (0..3).rev() { + let new_carrier = c[i] & mask; + c[i] = (c[i] >> n) | (carrier << (64 - n)); + carrier = new_carrier; + } + Fr::from_bigint(result).unwrap() +} + +fn bit_and(a: Fr, b: Fr) -> Fr { + let a = a.into_bigint(); + let b = b.into_bigint(); + let c: [u64; 4] = [ + a.0[0] & b.0[0], + a.0[1] & b.0[1], + a.0[2] & b.0[2], + a.0[3] & b.0[3], + ]; + let mut d: BigInt<4> = BigInt::new(c); + if d > Fr::MODULUS { + d.sub_with_borrow(&Fr::MODULUS); + } + + Fr::from_bigint(d).unwrap() +} + +fn bit_or(a: Fr, b: Fr) -> Fr { + let a = a.into_bigint(); + let b = b.into_bigint(); + let c: [u64; 4] = [ + a.0[0] | b.0[0], + a.0[1] | b.0[1], + a.0[2] | b.0[2], + a.0[3] | b.0[3], + ]; + let mut d: BigInt<4> = BigInt::new(c); + if d > Fr::MODULUS { + d.sub_with_borrow(&Fr::MODULUS); + } + + Fr::from_bigint(d).unwrap() +} + +fn bit_xor(a: Fr, b: Fr) -> Fr { + let a = a.into_bigint(); + let b = b.into_bigint(); + let c: [u64; 4] = [ + a.0[0] ^ b.0[0], + a.0[1] ^ b.0[1], + a.0[2] ^ b.0[2], + a.0[3] ^ b.0[3], + ]; + let mut d: BigInt<4> = BigInt::new(c); + if d > Fr::MODULUS { + d.sub_with_borrow(&Fr::MODULUS); + } + + Fr::from_bigint(d).unwrap() +} + +// M / 2 +const HALF_M: U256 = + uint!(10944121435919637611123202872628637544274182200208017171849102093287904247808_U256); + +fn u_gte(a: &U256, b: &U256) -> U256 { + let a_neg = &HALF_M < a; + let b_neg = &HALF_M < b; + + match (a_neg, b_neg) { + (false, false) => U256::from(a >= b), + (true, false) => uint!(0_U256), + (false, true) => uint!(1_U256), + (true, true) => U256::from(a >= b), + } +} + +fn u_lte(a: &U256, b: &U256) -> U256 { + let a_neg = &HALF_M < a; + let b_neg = &HALF_M < b; + + match (a_neg, b_neg) { + (false, false) => U256::from(a <= b), + (true, false) => uint!(1_U256), + (false, true) => uint!(0_U256), + (true, true) => U256::from(a <= b), + } +} + +fn u_gt(a: &U256, b: &U256) -> U256 { + let a_neg = &HALF_M < a; + let b_neg = &HALF_M < b; + + match (a_neg, b_neg) { + (false, false) => U256::from(a > b), + (true, false) => uint!(0_U256), + (false, true) => uint!(1_U256), + (true, true) => U256::from(a > b), + } +} + +fn u_lt(a: &U256, b: &U256) -> U256 { + let a_neg = &HALF_M < a; + let b_neg = &HALF_M < b; + + match (a_neg, b_neg) { + (false, false) => U256::from(a < b), + (true, false) => uint!(1_U256), + (false, true) => uint!(0_U256), + (true, true) => U256::from(a < b), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ruint::uint; + use std::ops::Div; + use std::str::FromStr; + + #[test] + fn test_ok() { + let a = Fr::from(4u64); + let b = Fr::from(2u64); + let c = shl(a, b); + assert_eq!(c.cmp(&Fr::from(16u64)), Ordering::Equal) + } + + #[test] + fn test_div() { + assert_eq!( + Operation::Div.eval_fr(Fr::from(2u64), Fr::from(3u64)), + Fr::from_str( + "7296080957279758407415468581752425029516121466805344781232734728858602831873" + ) + .unwrap() + ); + + assert_eq!( + Operation::Div.eval_fr(Fr::from(6u64), Fr::from(2u64)), + Fr::from_str("3").unwrap() + ); + + assert_eq!( + Operation::Div.eval_fr(Fr::from(7u64), Fr::from(2u64)), + Fr::from_str( + "10944121435919637611123202872628637544274182200208017171849102093287904247812" + ) + .unwrap() + ); + } + + #[test] + fn test_idiv() { + assert_eq!( + Operation::Idiv.eval_fr(Fr::from(2u64), Fr::from(3u64)), + Fr::from_str("0").unwrap() + ); + + assert_eq!( + Operation::Idiv.eval_fr(Fr::from(6u64), Fr::from(2u64)), + Fr::from_str("3").unwrap() + ); + + assert_eq!( + Operation::Idiv.eval_fr(Fr::from(7u64), Fr::from(2u64)), + Fr::from_str("3").unwrap() + ); + } + + #[test] + fn test_fr_mod() { + assert_eq!( + Operation::Mod.eval_fr(Fr::from(7u64), Fr::from(2u64)), + Fr::from_str("1").unwrap() + ); + + assert_eq!( + Operation::Mod.eval_fr(Fr::from(7u64), Fr::from(9u64)), + Fr::from_str("7").unwrap() + ); + } + + #[test] + fn test_u_gte() { + let result = u_gte(&uint!(10_U256), &uint!(3_U256)); + assert_eq!(result, uint!(1_U256)); + + let result = u_gte(&uint!(3_U256), &uint!(3_U256)); + assert_eq!(result, uint!(1_U256)); + + let result = u_gte(&uint!(2_U256), &uint!(3_U256)); + assert_eq!(result, uint!(0_U256)); + + // -1 >= 3 => 0 + let result = u_gte( + &uint!( + 21888242871839275222246405745257275088548364400416034343698204186575808495616_U256 + ), + &uint!(3_U256), + ); + assert_eq!(result, uint!(0_U256)); + + // -1 >= -2 => 1 + let result = u_gte( + &uint!( + 21888242871839275222246405745257275088548364400416034343698204186575808495616_U256 + ), + &uint!( + 21888242871839275222246405745257275088548364400416034343698204186575808495615_U256 + ), + ); + assert_eq!(result, uint!(1_U256)); + + // -2 >= -1 => 0 + let result = u_gte( + &uint!( + 21888242871839275222246405745257275088548364400416034343698204186575808495615_U256 + ), + &uint!( + 21888242871839275222246405745257275088548364400416034343698204186575808495616_U256 + ), + ); + assert_eq!(result, uint!(0_U256)); + + // -2 == -2 => 1 + let result = u_gte( + &uint!( + 21888242871839275222246405745257275088548364400416034343698204186575808495615_U256 + ), + &uint!( + 21888242871839275222246405745257275088548364400416034343698204186575808495615_U256 + ), + ); + assert_eq!(result, uint!(1_U256)); + } + + #[test] + fn test_x() { + let x = M.div(uint!(2_U256)); + + println!("x: {:?}", x.as_limbs()); + println!("x: {}", M); + } + + #[test] + fn test_2() { + let nodes: Vec = vec![]; + // let node = nodes[0]; + let node = nodes.get(0); + println!("{:?}", node); + } +} diff --git a/rln/src/iden3calc/proto.rs b/rln/src/iden3calc/proto.rs new file mode 100644 index 0000000..99886fb --- /dev/null +++ b/rln/src/iden3calc/proto.rs @@ -0,0 +1,117 @@ +// This file has been generated by prost-build during compilation of the code by iden3 +// and modified manually. The *.proto file used to generate this on can be found here: +// https://github.com/iden3/circom-witnesscalc/blob/5cb365b6e4d9052ecc69d4567fcf5bc061c20e94/protos/messages.proto + +use std::collections::HashMap; + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BigUInt { + #[prost(bytes = "vec", tag = "1")] + pub value_le: Vec, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct InputNode { + #[prost(uint32, tag = "1")] + pub idx: u32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ConstantNode { + #[prost(message, optional, tag = "1")] + pub value: Option, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct UnoOpNode { + #[prost(enumeration = "UnoOp", tag = "1")] + pub op: i32, + #[prost(uint32, tag = "2")] + pub a_idx: u32, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct DuoOpNode { + #[prost(enumeration = "DuoOp", tag = "1")] + pub op: i32, + #[prost(uint32, tag = "2")] + pub a_idx: u32, + #[prost(uint32, tag = "3")] + pub b_idx: u32, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct TresOpNode { + #[prost(enumeration = "TresOp", tag = "1")] + pub op: i32, + #[prost(uint32, tag = "2")] + pub a_idx: u32, + #[prost(uint32, tag = "3")] + pub b_idx: u32, + #[prost(uint32, tag = "4")] + pub c_idx: u32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Node { + #[prost(oneof = "node::Node", tags = "1, 2, 3, 4, 5")] + pub node: Option, +} +/// Nested message and enum types in `Node`. +pub mod node { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Node { + #[prost(message, tag = "1")] + Input(super::InputNode), + #[prost(message, tag = "2")] + Constant(super::ConstantNode), + #[prost(message, tag = "3")] + UnoOp(super::UnoOpNode), + #[prost(message, tag = "4")] + DuoOp(super::DuoOpNode), + #[prost(message, tag = "5")] + TresOp(super::TresOpNode), + } +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct SignalDescription { + #[prost(uint32, tag = "1")] + pub offset: u32, + #[prost(uint32, tag = "2")] + pub len: u32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct GraphMetadata { + #[prost(uint32, repeated, tag = "1")] + pub witness_signals: Vec, + #[prost(map = "string, message", tag = "2")] + pub inputs: HashMap, +} +#[derive(Clone, Copy, Debug, PartialEq, ::prost::Enumeration)] +pub enum DuoOp { + Mul = 0, + Div = 1, + Add = 2, + Sub = 3, + Pow = 4, + Idiv = 5, + Mod = 6, + Eq = 7, + Neq = 8, + Lt = 9, + Gt = 10, + Leq = 11, + Geq = 12, + Land = 13, + Lor = 14, + Shl = 15, + Shr = 16, + Bor = 17, + Band = 18, + Bxor = 19, +} + +#[derive(Clone, Copy, Debug, PartialEq, ::prost::Enumeration)] +pub enum UnoOp { + Neg = 0, + Id = 1, +} + +#[derive(Clone, Copy, Debug, PartialEq, ::prost::Enumeration)] +pub enum TresOp { + TernCond = 0, +} diff --git a/rln/src/iden3calc/storage.rs b/rln/src/iden3calc/storage.rs new file mode 100644 index 0000000..d00c994 --- /dev/null +++ b/rln/src/iden3calc/storage.rs @@ -0,0 +1,496 @@ +// This file is based on the code by iden3. Its preimage can be found here: +// https://github.com/iden3/circom-witnesscalc/blob/5cb365b6e4d9052ecc69d4567fcf5bc061c20e94/src/storage.rs + +use crate::iden3calc::{ + graph, + graph::{Operation, TresOperation, UnoOperation}, + proto, InputSignalsInfo, +}; +use ark_bn254::Fr; +use ark_ff::PrimeField; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use prost::Message; +use std::io::{Read, Write}; + +// format of the wtns.graph file: +// + magic line: wtns.graph.001 +// + 4 bytes unsigned LE 32-bit integer: number of nodes +// + series of protobuf serialized nodes. Each node prefixed by varint length +// + protobuf serialized GraphMetadata +// + 8 bytes unsigned LE 64-bit integer: offset of GraphMetadata message + +const WITNESSCALC_GRAPH_MAGIC: &[u8] = b"wtns.graph.001"; + +const MAX_VARINT_LENGTH: usize = 10; + +impl From for graph::Node { + fn from(value: proto::Node) -> Self { + match value.node.unwrap() { + proto::node::Node::Input(input_node) => graph::Node::Input(input_node.idx as usize), + proto::node::Node::Constant(constant_node) => { + let i = constant_node.value.unwrap(); + graph::Node::MontConstant(Fr::from_le_bytes_mod_order(i.value_le.as_slice())) + } + proto::node::Node::UnoOp(uno_op_node) => { + let op = proto::UnoOp::try_from(uno_op_node.op).unwrap(); + graph::Node::UnoOp(op.into(), uno_op_node.a_idx as usize) + } + proto::node::Node::DuoOp(duo_op_node) => { + let op = proto::DuoOp::try_from(duo_op_node.op).unwrap(); + graph::Node::Op( + op.into(), + duo_op_node.a_idx as usize, + duo_op_node.b_idx as usize, + ) + } + proto::node::Node::TresOp(tres_op_node) => { + let op = proto::TresOp::try_from(tres_op_node.op).unwrap(); + graph::Node::TresOp( + op.into(), + tres_op_node.a_idx as usize, + tres_op_node.b_idx as usize, + tres_op_node.c_idx as usize, + ) + } + } + } +} + +impl From<&graph::Node> for proto::node::Node { + fn from(node: &graph::Node) -> Self { + match node { + graph::Node::Input(i) => proto::node::Node::Input(proto::InputNode { idx: *i as u32 }), + graph::Node::Constant(_) => { + panic!("We are not supposed to write Constant to the witnesscalc graph. All Constant should be converted to MontConstant."); + } + graph::Node::UnoOp(op, a) => { + let op = proto::UnoOp::from(op); + proto::node::Node::UnoOp(proto::UnoOpNode { + op: op as i32, + a_idx: *a as u32, + }) + } + graph::Node::Op(op, a, b) => proto::node::Node::DuoOp(proto::DuoOpNode { + op: proto::DuoOp::from(op) as i32, + a_idx: *a as u32, + b_idx: *b as u32, + }), + graph::Node::TresOp(op, a, b, c) => proto::node::Node::TresOp(proto::TresOpNode { + op: proto::TresOp::from(op) as i32, + a_idx: *a as u32, + b_idx: *b as u32, + c_idx: *c as u32, + }), + graph::Node::MontConstant(c) => { + let bi = Into::::into(*c); + let i = proto::BigUInt { + value_le: bi.to_bytes_le(), + }; + proto::node::Node::Constant(proto::ConstantNode { value: Some(i) }) + } + } + } +} + +impl From for UnoOperation { + fn from(value: proto::UnoOp) -> Self { + match value { + proto::UnoOp::Neg => UnoOperation::Neg, + proto::UnoOp::Id => UnoOperation::Id, + } + } +} + +impl From for Operation { + fn from(value: proto::DuoOp) -> Self { + match value { + proto::DuoOp::Mul => Operation::Mul, + proto::DuoOp::Div => Operation::Div, + proto::DuoOp::Add => Operation::Add, + proto::DuoOp::Sub => Operation::Sub, + proto::DuoOp::Pow => Operation::Pow, + proto::DuoOp::Idiv => Operation::Idiv, + proto::DuoOp::Mod => Operation::Mod, + proto::DuoOp::Eq => Operation::Eq, + proto::DuoOp::Neq => Operation::Neq, + proto::DuoOp::Lt => Operation::Lt, + proto::DuoOp::Gt => Operation::Gt, + proto::DuoOp::Leq => Operation::Leq, + proto::DuoOp::Geq => Operation::Geq, + proto::DuoOp::Land => Operation::Land, + proto::DuoOp::Lor => Operation::Lor, + proto::DuoOp::Shl => Operation::Shl, + proto::DuoOp::Shr => Operation::Shr, + proto::DuoOp::Bor => Operation::Bor, + proto::DuoOp::Band => Operation::Band, + proto::DuoOp::Bxor => Operation::Bxor, + } + } +} + +impl From for graph::TresOperation { + fn from(value: proto::TresOp) -> Self { + match value { + proto::TresOp::TernCond => TresOperation::TernCond, + } + } +} + +pub fn serialize_witnesscalc_graph( + mut w: T, + nodes: &Vec, + witness_signals: &[usize], + input_signals: &InputSignalsInfo, +) -> std::io::Result<()> { + let mut ptr = 0usize; + w.write_all(WITNESSCALC_GRAPH_MAGIC).unwrap(); + ptr += WITNESSCALC_GRAPH_MAGIC.len(); + + w.write_u64::(nodes.len() as u64)?; + ptr += 8; + + let metadata = proto::GraphMetadata { + witness_signals: witness_signals + .iter() + .map(|x| *x as u32) + .collect::>(), + inputs: input_signals + .iter() + .map(|(k, v)| { + let sig = proto::SignalDescription { + offset: v.0 as u32, + len: v.1 as u32, + }; + (k.clone(), sig) + }) + .collect(), + }; + + // capacity of buf should be enough to hold the largest message + 10 bytes + // of varint length + let mut buf = Vec::with_capacity(metadata.encoded_len() + MAX_VARINT_LENGTH); + + for node in nodes { + let node_pb = proto::Node { + node: Some(proto::node::Node::from(node)), + }; + + assert_eq!(buf.len(), 0); + node_pb.encode_length_delimited(&mut buf)?; + ptr += buf.len(); + + w.write_all(&buf)?; + buf.clear(); + } + + metadata.encode_length_delimited(&mut buf)?; + w.write_all(&buf)?; + buf.clear(); + + w.write_u64::(ptr as u64)?; + + Ok(()) +} + +fn read_message_length(rw: &mut WriteBackReader) -> std::io::Result { + let mut buf = [0u8; MAX_VARINT_LENGTH]; + let bytes_read = rw.read(&mut buf)?; + if bytes_read == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Unexpected EOF", + )); + } + + let len_delimiter = prost::decode_length_delimiter(buf.as_ref())?; + + let lnln = prost::length_delimiter_len(len_delimiter); + + if lnln < bytes_read { + rw.write_all(&buf[lnln..bytes_read])?; + } + + Ok(len_delimiter) +} + +fn read_message( + rw: &mut WriteBackReader, +) -> std::io::Result { + let ln = read_message_length(rw)?; + let mut buf = vec![0u8; ln]; + let bytes_read = rw.read(&mut buf)?; + if bytes_read != ln { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Unexpected EOF", + )); + } + + let msg = prost::Message::decode(&buf[..])?; + + Ok(msg) +} + +pub fn deserialize_witnesscalc_graph( + r: impl Read, +) -> std::io::Result<(Vec, Vec, InputSignalsInfo)> { + let mut br = WriteBackReader::new(r); + let mut magic = [0u8; WITNESSCALC_GRAPH_MAGIC.len()]; + + br.read_exact(&mut magic)?; + + if !magic.eq(WITNESSCALC_GRAPH_MAGIC) { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Invalid magic", + )); + } + + let nodes_num = br.read_u64::()?; + let mut nodes = Vec::with_capacity(nodes_num as usize); + for _ in 0..nodes_num { + let n: proto::Node = read_message(&mut br)?; + let n2: graph::Node = n.into(); + nodes.push(n2); + } + + let md: proto::GraphMetadata = read_message(&mut br)?; + + let witness_signals = md + .witness_signals + .iter() + .map(|x| *x as usize) + .collect::>(); + + let input_signals = md + .inputs + .iter() + .map(|(k, v)| (k.clone(), (v.offset as usize, v.len as usize))) + .collect::(); + + Ok((nodes, witness_signals, input_signals)) +} + +struct WriteBackReader { + reader: R, + buffer: Vec, +} + +impl WriteBackReader { + fn new(reader: R) -> Self { + WriteBackReader { + reader, + buffer: Vec::new(), + } + } +} + +impl Read for WriteBackReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + if buf.is_empty() { + return Ok(0); + } + + let mut n = 0usize; + + if !self.buffer.is_empty() { + n = std::cmp::min(buf.len(), self.buffer.len()); + self.buffer[self.buffer.len() - n..] + .iter() + .rev() + .enumerate() + .for_each(|(i, x)| { + buf[i] = *x; + }); + self.buffer.truncate(self.buffer.len() - n); + } + + while n < buf.len() { + let m = self.reader.read(&mut buf[n..])?; + if m == 0 { + break; + } + n += m; + } + + Ok(n) + } +} + +impl Write for WriteBackReader { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.buffer.reserve(buf.len()); + self.buffer.extend(buf.iter().rev()); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use byteorder::ByteOrder; + use core::str::FromStr; + use graph::{Operation, TresOperation, UnoOperation}; + use std::collections::HashMap; + + #[test] + fn test_read_message() { + let mut buf = Vec::new(); + let n1 = proto::Node { + node: Some(proto::node::Node::Input(proto::InputNode { idx: 1 })), + }; + n1.encode_length_delimited(&mut buf).unwrap(); + + let n2 = proto::Node { + node: Some(proto::node::Node::Input(proto::InputNode { idx: 2 })), + }; + n2.encode_length_delimited(&mut buf).unwrap(); + + let mut reader = std::io::Cursor::new(&buf); + + let mut rw = WriteBackReader::new(&mut reader); + + let got_n1: proto::Node = read_message(&mut rw).unwrap(); + assert!(n1.eq(&got_n1)); + + let got_n2: proto::Node = read_message(&mut rw).unwrap(); + assert!(n2.eq(&got_n2)); + + assert_eq!(reader.position(), buf.len() as u64); + } + + #[test] + fn test_read_message_variant() { + let nodes = vec![ + proto::Node { + node: Some(proto::node::Node::from(&graph::Node::Input(0))), + }, + proto::Node { + node: Some(proto::node::Node::from(&graph::Node::MontConstant( + Fr::from_str("1").unwrap(), + ))), + }, + proto::Node { + node: Some(proto::node::Node::from(&graph::Node::UnoOp( + UnoOperation::Id, + 4, + ))), + }, + proto::Node { + node: Some(proto::node::Node::from(&graph::Node::Op( + Operation::Mul, + 5, + 6, + ))), + }, + proto::Node { + node: Some(proto::node::Node::from(&graph::Node::TresOp( + TresOperation::TernCond, + 7, + 8, + 9, + ))), + }, + ]; + + let mut buf = Vec::new(); + for n in &nodes { + n.encode_length_delimited(&mut buf).unwrap(); + } + + let mut nodes_got: Vec = Vec::new(); + let mut reader = std::io::Cursor::new(&buf); + let mut rw = WriteBackReader::new(&mut reader); + for _ in 0..nodes.len() { + nodes_got.push(read_message(&mut rw).unwrap()); + } + + assert_eq!(nodes, nodes_got); + } + + #[test] + fn test_write_back_reader() { + let data = [1u8, 2, 3, 4, 5, 6]; + let mut r = WriteBackReader::new(std::io::Cursor::new(&data)); + + let buf = &mut [0u8; 5]; + r.read(buf).unwrap(); + assert_eq!(buf, &[1, 2, 3, 4, 5]); + + // return [4, 5] to reader + r.write(&buf[3..]).unwrap(); + // return [2, 3] to reader + r.write(&buf[1..3]).unwrap(); + + buf.fill(0); + + // read 3 bytes, expect [2, 3, 4] after returns + let mut n = r.read(&mut buf[..3]).unwrap(); + assert_eq!(n, 3); + assert_eq!(buf, &[2, 3, 4, 0, 0]); + + buf.fill(0); + + // read everything left in reader + n = r.read(buf).unwrap(); + assert_eq!(n, 2); + assert_eq!(buf, &[5, 6, 0, 0, 0]); + } + + #[test] + fn test_deserialize_inputs() { + let nodes = vec![ + graph::Node::Input(0), + graph::Node::MontConstant(Fr::from_str("1").unwrap()), + graph::Node::UnoOp(UnoOperation::Id, 4), + graph::Node::Op(Operation::Mul, 5, 6), + graph::Node::TresOp(TresOperation::TernCond, 7, 8, 9), + ]; + + let witness_signals = vec![4, 1]; + + let mut input_signals: InputSignalsInfo = HashMap::new(); + input_signals.insert("sig1".to_string(), (1, 3)); + input_signals.insert("sig2".to_string(), (5, 1)); + + let mut tmp = Vec::new(); + serialize_witnesscalc_graph(&mut tmp, &nodes, &witness_signals, &input_signals).unwrap(); + + let mut reader = std::io::Cursor::new(&tmp); + + let (nodes_res, witness_signals_res, input_signals_res) = + deserialize_witnesscalc_graph(&mut reader).unwrap(); + + assert_eq!(nodes, nodes_res); + assert_eq!(input_signals, input_signals_res); + assert_eq!(witness_signals, witness_signals_res); + + let metadata_start = LittleEndian::read_u64(&tmp[tmp.len() - 8..]); + + let mt_reader = std::io::Cursor::new(&tmp[metadata_start as usize..]); + let mut rw = WriteBackReader::new(mt_reader); + let metadata: proto::GraphMetadata = read_message(&mut rw).unwrap(); + + let metadata_want = proto::GraphMetadata { + witness_signals: vec![4, 1], + inputs: input_signals + .iter() + .map(|(k, v)| { + ( + k.clone(), + proto::SignalDescription { + offset: v.0 as u32, + len: v.1 as u32, + }, + ) + }) + .collect(), + }; + + assert_eq!(metadata, metadata_want); + } +} diff --git a/rln/src/lib.rs b/rln/src/lib.rs index 292db57..06d8d4a 100644 --- a/rln/src/lib.rs +++ b/rln/src/lib.rs @@ -2,6 +2,7 @@ pub mod circuit; pub mod hashers; +pub mod iden3calc; #[cfg(feature = "pmtree-ft")] pub mod pm_tree_adapter; pub mod poseidon_tree; diff --git a/rln/src/protocol.rs b/rln/src/protocol.rs index 46aa853..7818aed 100644 --- a/rln/src/protocol.rs +++ b/rln/src/protocol.rs @@ -1,6 +1,6 @@ // This crate collects all the underlying primitives used to implement RLN -use ark_circom::{CircomReduction, WitnessCalculator}; +use ark_circom::CircomReduction; use ark_groth16::{prepare_verifying_key, Groth16, Proof as ArkProof, ProvingKey, VerifyingKey}; use ark_relations::r1cs::ConstraintMatrices; use ark_relations::r1cs::SynthesisError; @@ -11,20 +11,17 @@ use num_bigint::BigInt; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; use serde::{Deserialize, Serialize}; -#[cfg(not(target_arch = "wasm32"))] -use std::sync::Mutex; -#[cfg(debug_assertions)] +#[cfg(test)] use std::time::Instant; use thiserror::Error; use tiny_keccak::{Hasher as _, Keccak}; -use crate::circuit::{Curve, Fr}; +use crate::circuit::{calculate_rln_witness, Curve, Fr}; use crate::hashers::hash_to_field; use crate::hashers::poseidon_hash; use crate::poseidon_tree::*; use crate::public::RLN_IDENTIFIER; use crate::utils::*; -use cfg_if::cfg_if; use utils::{ZerokitMerkleProof, ZerokitMerkleTree}; /////////////////////////////////////////////////////// @@ -544,13 +541,13 @@ pub fn generate_proof_with_witness( proving_key: &(ProvingKey, ConstraintMatrices), ) -> Result, ProofError> { // If in debug mode, we measure and later print time take to compute witness - #[cfg(debug_assertions)] + #[cfg(test)] let now = Instant::now(); let full_assignment = calculate_witness_element::(witness).map_err(ProofError::WitnessError)?; - #[cfg(debug_assertions)] + #[cfg(test)] println!("witness generation took: {:.2?}", now.elapsed()); // Random Values @@ -559,7 +556,7 @@ pub fn generate_proof_with_witness( let s = Fr::rand(&mut rng); // If in debug mode, we measure and later print time take to compute proof - #[cfg(debug_assertions)] + #[cfg(test)] let now = Instant::now(); let proof = Groth16::<_, CircomReduction>::create_proof_with_reduction_and_matrices( @@ -572,7 +569,7 @@ pub fn generate_proof_with_witness( full_assignment.as_slice(), )?; - #[cfg(debug_assertions)] + #[cfg(test)] println!("proof generation took: {:.2?}", now.elapsed()); Ok(proof) @@ -628,8 +625,6 @@ pub fn inputs_for_witness_calculation( /// /// Returns a [`ProofError`] if proving fails. pub fn generate_proof( - #[cfg(not(target_arch = "wasm32"))] witness_calculator: &Mutex, - #[cfg(target_arch = "wasm32")] witness_calculator: &mut WitnessCalculator, proving_key: &(ProvingKey, ConstraintMatrices), rln_witness: &RLNWitnessInput, ) -> Result, ProofError> { @@ -638,24 +633,11 @@ pub fn generate_proof( .map(|(name, values)| (name.to_string(), values)); // If in debug mode, we measure and later print time take to compute witness - #[cfg(debug_assertions)] + #[cfg(test)] let now = Instant::now(); + let full_assignment = calculate_rln_witness(inputs); - cfg_if! { - if #[cfg(target_arch = "wasm32")] { - let full_assignment = witness_calculator - .calculate_witness_element::(inputs, false) - .map_err(ProofError::WitnessError)?; - } else { - let full_assignment = witness_calculator - .lock() - .expect("witness_calculator mutex should not get poisoned") - .calculate_witness_element::(inputs, false) - .map_err(ProofError::WitnessError)?; - } - } - - #[cfg(debug_assertions)] + #[cfg(test)] println!("witness generation took: {:.2?}", now.elapsed()); // Random Values @@ -664,7 +646,7 @@ pub fn generate_proof( let s = Fr::rand(&mut rng); // If in debug mode, we measure and later print time take to compute proof - #[cfg(debug_assertions)] + #[cfg(test)] let now = Instant::now(); let proof = Groth16::<_, CircomReduction>::create_proof_with_reduction_and_matrices( &proving_key.0, @@ -676,7 +658,7 @@ pub fn generate_proof( full_assignment.as_slice(), )?; - #[cfg(debug_assertions)] + #[cfg(test)] println!("proof generation took: {:.2?}", now.elapsed()); Ok(proof) @@ -707,12 +689,12 @@ pub fn verify_proof( //let pr: ArkProof = (*proof).into(); // If in debug mode, we measure and later print time take to verify proof - #[cfg(debug_assertions)] + #[cfg(test)] let now = Instant::now(); let verified = Groth16::<_, CircomReduction>::verify_proof(&pvk, proof, &inputs)?; - #[cfg(debug_assertions)] + #[cfg(test)] println!("verify took: {:.2?}", now.elapsed()); Ok(verified) diff --git a/rln/src/public.rs b/rln/src/public.rs index ca575d4..b9bf19a 100644 --- a/rln/src/public.rs +++ b/rln/src/public.rs @@ -16,13 +16,10 @@ use utils::{ZerokitMerkleProof, ZerokitMerkleTree}; cfg_if! { if #[cfg(not(target_arch = "wasm32"))] { use std::default::Default; - use std::sync::Mutex; - use crate::circuit::{circom_from_folder, vk_from_folder, circom_from_raw, zkey_from_folder, TEST_TREE_HEIGHT}; - use ark_circom::WitnessCalculator; + use crate::circuit::{vk_from_folder, zkey_from_folder, TEST_TREE_HEIGHT}; use crate::poseidon_tree::PoseidonTree; use serde_json::{json, Value}; - use utils::{Hasher}; - use std::sync::Arc; + use utils::Hasher; use std::str::FromStr; } else { use std::marker::*; @@ -45,12 +42,6 @@ pub struct RLN { pub(crate) verification_key: VerifyingKey, #[cfg(not(feature = "stateless"))] pub(crate) tree: PoseidonTree, - - // The witness calculator can't be loaded in zerokit. Since this struct - // contains a lifetime, a PhantomData is necessary to avoid a compiler - // error since the lifetime is not being used - #[cfg(not(target_arch = "wasm32"))] - pub(crate) witness_calculator: Arc>, #[cfg(target_arch = "wasm32")] _marker: PhantomData<()>, } @@ -81,7 +72,6 @@ impl RLN { let rln_config: Value = serde_json::from_str(&String::from_utf8(input)?)?; let tree_config = rln_config["tree_config"].to_string(); - let witness_calculator = circom_from_folder(); let proving_key = zkey_from_folder(); let verification_key = vk_from_folder(); @@ -100,7 +90,6 @@ impl RLN { )?; Ok(RLN { - witness_calculator: witness_calculator.to_owned(), proving_key: proving_key.to_owned(), verification_key: verification_key.to_owned(), #[cfg(not(feature = "stateless"))] @@ -122,12 +111,10 @@ impl RLN { #[cfg(all(not(target_arch = "wasm32"), feature = "stateless"))] pub fn new() -> Result { #[cfg(not(target_arch = "wasm32"))] - let witness_calculator = circom_from_folder(); let proving_key = zkey_from_folder(); let verification_key = vk_from_folder(); Ok(RLN { - witness_calculator: witness_calculator.to_owned(), proving_key: proving_key.to_owned(), verification_key: verification_key.to_owned(), #[cfg(target_arch = "wasm32")] @@ -139,7 +126,6 @@ impl RLN { /// /// Input parameters are /// - `tree_height`: the height of the internal Merkle tree - /// - `circom_vec`: a byte vector containing the ZK circuit (`rln.wasm`) as binary file /// - `zkey_vec`: a byte vector containing to the proving key (`rln_final.zkey`) or (`rln_final.arkzkey`) as binary file /// - `vk_vec`: a byte vector containing to the verification key (`verification_key.arkvkey`) as binary file /// - `tree_config_input`: a reader for a string containing a json with the merkle tree configuration @@ -153,36 +139,33 @@ impl RLN { /// let resources_folder = "./resources/tree_height_20/"; /// /// let mut resources: Vec> = Vec::new(); - /// for filename in ["rln.wasm", "rln_final.zkey", "verification_key.arkvkey"] { + /// for filename in ["rln_final.zkey", "verification_key.arkvkey"] { /// let fullpath = format!("{resources_folder}{filename}"); /// let mut file = File::open(&fullpath).expect("no file found"); /// let metadata = std::fs::metadata(&fullpath).expect("unable to read metadata"); /// let mut buffer = vec![0; metadata.len() as usize]; /// file.read_exact(&mut buffer).expect("buffer overflow"); /// resources.push(buffer); - /// let tree_config = "{}".to_string(); - /// let tree_config_input = &Buffer::from(tree_config.as_bytes()); /// } /// + /// let tree_config = "".to_string(); + /// let tree_config_buffer = &Buffer::from(tree_config.as_bytes()); + /// /// let mut rln = RLN::new_with_params( /// tree_height, /// resources[0].clone(), /// resources[1].clone(), - /// resources[2].clone(), - /// tree_config_input, + /// tree_config_buffer, /// ); /// ``` #[cfg(all(not(target_arch = "wasm32"), not(feature = "stateless")))] pub fn new_with_params( tree_height: usize, - circom_vec: Vec, zkey_vec: Vec, vk_vec: Vec, mut tree_config_input: R, ) -> Result { #[cfg(not(target_arch = "wasm32"))] - let witness_calculator = circom_from_raw(&circom_vec)?; - let proving_key = zkey_from_raw(&zkey_vec)?; let verification_key = vk_from_raw(&vk_vec, &zkey_vec)?; @@ -204,7 +187,6 @@ impl RLN { )?; Ok(RLN { - witness_calculator, proving_key, verification_key, #[cfg(not(feature = "stateless"))] @@ -217,7 +199,6 @@ impl RLN { /// Creates a new stateless RLN object by passing circuit resources as byte vectors. /// /// Input parameters are - /// - `circom_vec`: a byte vector containing the ZK circuit (`rln.wasm`) as binary file /// - `zkey_vec`: a byte vector containing to the proving key (`rln_final.zkey`) or (`rln_final.arkzkey`) as binary file /// - `vk_vec`: a byte vector containing to the verification key (`verification_key.arkvkey`) as binary file /// @@ -229,7 +210,7 @@ impl RLN { /// let resources_folder = "./resources/tree_height_20/"; /// /// let mut resources: Vec> = Vec::new(); - /// for filename in ["rln.wasm", "rln_final.zkey", "verification_key.arkvkey"] { + /// for filename in ["rln_final.zkey", "verification_key.arkvkey"] { /// let fullpath = format!("{resources_folder}{filename}"); /// let mut file = File::open(&fullpath).expect("no file found"); /// let metadata = std::fs::metadata(&fullpath).expect("unable to read metadata"); @@ -241,18 +222,14 @@ impl RLN { /// let mut rln = RLN::new_with_params( /// resources[0].clone(), /// resources[1].clone(), - /// resources[2].clone(), /// ); /// ``` #[cfg(all(not(target_arch = "wasm32"), feature = "stateless"))] - pub fn new_with_params(circom_vec: Vec, zkey_vec: Vec, vk_vec: Vec) -> Result { - let witness_calculator = circom_from_raw(&circom_vec)?; - + pub fn new_with_params(zkey_vec: Vec, vk_vec: Vec) -> Result { let proving_key = zkey_from_raw(&zkey_vec)?; let verification_key = vk_from_raw(&vk_vec, &zkey_vec)?; Ok(RLN { - witness_calculator, proving_key, verification_key, }) @@ -784,13 +761,7 @@ impl RLN { input_data.read_to_end(&mut serialized)?; let (rln_witness, _) = deserialize_witness(&serialized)?; - /* - if self.witness_calculator.is_none() { - self.witness_calculator = CIRCOM(&self.resources_folder); - } - */ - - let proof = generate_proof(&self.witness_calculator, &self.proving_key, &rln_witness)?; + let proof = generate_proof(&self.proving_key, &rln_witness)?; // Note: we export a serialization of ark-groth16::Proof not semaphore::Proof proof.serialize_compressed(&mut output_data)?; @@ -913,7 +884,7 @@ impl RLN { let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte)?; let proof_values = proof_values_from_witness(&rln_witness)?; - let proof = generate_proof(&self.witness_calculator, &self.proving_key, &rln_witness)?; + let proof = generate_proof(&self.proving_key, &rln_witness)?; // Note: we export a serialization of ark-groth16::Proof not semaphore::Proof // This proof is compressed, i.e. 128 bytes long @@ -961,7 +932,7 @@ impl RLN { let (rln_witness, _) = deserialize_witness(&witness_byte)?; let proof_values = proof_values_from_witness(&rln_witness)?; - let proof = generate_proof(&self.witness_calculator, &self.proving_key, &rln_witness)?; + let proof = generate_proof(&self.proving_key, &rln_witness)?; // Note: we export a serialization of ark-groth16::Proof not semaphore::Proof // This proof is compressed, i.e. 128 bytes long diff --git a/rln/tests/ffi.rs b/rln/tests/ffi.rs index b116148..dd6ad71 100644 --- a/rln/tests/ffi.rs +++ b/rln/tests/ffi.rs @@ -408,15 +408,6 @@ mod test { // We obtain the root from the RLN instance let root_rln_folder = get_tree_root(rln_pointer); - // Reading the raw data from the files required for instantiating a RLN instance using raw data - let circom_path = "./resources/tree_height_20/rln.wasm"; - let mut circom_file = File::open(&circom_path).expect("no file found"); - let metadata = std::fs::metadata(&circom_path).expect("unable to read metadata"); - let mut circom_buffer = vec![0; metadata.len() as usize]; - circom_file - .read_exact(&mut circom_buffer) - .expect("buffer overflow"); - #[cfg(feature = "arkzkey")] let zkey_path = "./resources/tree_height_20/rln_final.arkzkey"; #[cfg(not(feature = "arkzkey"))] @@ -434,7 +425,6 @@ mod test { let mut vk_buffer = vec![0; metadata.len() as usize]; vk_file.read_exact(&mut vk_buffer).expect("buffer overflow"); - let circom_data = &Buffer::from(&circom_buffer[..]); let zkey_data = &Buffer::from(&zkey_buffer[..]); let vk_data = &Buffer::from(&vk_buffer[..]); @@ -444,7 +434,6 @@ mod test { let tree_config_buffer = &Buffer::from(tree_config.as_bytes()); let success = new_with_params( TEST_TREE_HEIGHT, - circom_data, zkey_data, vk_data, tree_config_buffer, diff --git a/rln/tests/protocol.rs b/rln/tests/protocol.rs index 9a5144a..26cea1a 100644 --- a/rln/tests/protocol.rs +++ b/rln/tests/protocol.rs @@ -2,7 +2,7 @@ mod test { use ark_ff::BigInt; use rln::circuit::zkey_from_folder; - use rln::circuit::{circom_from_folder, vk_from_folder, Fr, TEST_TREE_HEIGHT}; + use rln::circuit::{vk_from_folder, Fr, TEST_TREE_HEIGHT}; use rln::hashers::{hash_to_field, poseidon_hash}; use rln::poseidon_tree::PoseidonTree; use rln::protocol::*; @@ -129,7 +129,6 @@ mod test { // We generate all relevant keys let proving_key = zkey_from_folder(); let verification_key = vk_from_folder(); - let builder = circom_from_folder(); // We compute witness from the json input let rln_witness = get_test_witness(); @@ -138,7 +137,7 @@ mod test { assert_eq!(rln_witness_deser, rln_witness); // Let's generate a zkSNARK proof - let proof = generate_proof(builder, &proving_key, &rln_witness_deser).unwrap(); + let proof = generate_proof(&proving_key, &rln_witness_deser).unwrap(); let proof_values = proof_values_from_witness(&rln_witness_deser).unwrap(); // Let's verify the proof @@ -158,10 +157,9 @@ mod test { // We generate all relevant keys let proving_key = zkey_from_folder(); let verification_key = vk_from_folder(); - let builder = circom_from_folder(); // Let's generate a zkSNARK proof - let proof = generate_proof(builder, &proving_key, &rln_witness_deser).unwrap(); + let proof = generate_proof(&proving_key, &rln_witness_deser).unwrap(); let proof_values = proof_values_from_witness(&rln_witness_deser).unwrap();