From 17ad6adf07fa716bb38c8a1be2ca4888544aabb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roger=20Taul=C3=A9=20Buxadera?= <55488871+RogerTaule@users.noreply.github.com> Date: Tue, 4 Feb 2025 17:53:37 +0100 Subject: [PATCH] Pre develop tod (#157) * barriers and threads managment * Overwritting const tree file whenever needed (#148) * Overwritting const tree file whenever needed * Properly handling setup writing for when distributed * Cargo clippy * Improving print summary with memory setup * Cargo fmt and clippy * Minor fixes recursivef and final (#150) * Minor opts * Minor fix starks_api * Feature/fixed cols bin (#152) * Fixed cols binary * Cargo fmt * Keep working * Cargo fmt and clippy * Adding file * Minor fix * Debug mode fast implemented (#151) * Register_std separated (#154) * Fix pil-helpers (#156) * Fix pil-helpers * Cargo.toml * Cargo lock * Adding fixed to pil-helpers * Allow clippy for traces file * Fixing std_range_check.pil for pil-helpers * Fixing connection traces for fixed * Poseidon 2 (#153) * Working in poseidon2 * Update ci test * Disallowing poseidon2 when AVX512 * Avoid clone (#158) * Minor changes * Minor modification (#159) * add parallel iterators to trace macro (#160) * add parallel iterators to trace macro * Fix test std and cargo fmt --------- Co-authored-by: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> * adding rayon to pil helpers * Adding debug feature (#161) * Adding debug feature * Fix pil-helpers * Update ci --------- Co-authored-by: rickb80 <75077385+rickb80@users.noreply.github.com> Co-authored-by: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> --- Cargo.lock | 106 ++- Cargo.toml | 1 + cli/assets/templates/pil_helpers_trace.rs.tt | 29 +- cli/src/commands/get_constraints.rs | 3 +- cli/src/commands/pil_helpers.rs | 55 +- cli/src/commands/prove.rs | 7 +- cli/src/commands/verify_constraints.rs | 7 +- common/src/air_instances_repository.rs | 21 +- common/src/constraints.rs | 4 +- common/src/custom_commits.rs | 4 +- common/src/distribution_ctx.rs | 30 + common/src/fixed_cols.rs | 54 ++ common/src/lib.rs | 2 + common/src/proof_ctx.rs | 47 +- common/src/prover.rs | 12 +- common/src/setup.rs | 207 +++--- common/src/setup_ctx.rs | 77 ++- common/src/std_mode.rs | 13 +- common/src/utils.rs | 32 +- examples/fibonacci-square/Cargo.toml | 5 + examples/fibonacci-square/src/fibonacci.rs | 9 +- examples/fibonacci-square/src/module.rs | 23 +- .../src/pil_helpers/traces.rs | 51 +- hints/src/global_hints.rs | 18 +- hints/src/hints.rs | 36 +- macros/Cargo.toml | 5 + macros/src/lib.rs | 73 ++- .../lib/std/pil/std_connection.pil | 12 +- .../lib/std/pil/std_range_check.pil | 5 +- pil2-components/lib/std/rs/Cargo.toml | 1 + pil2-components/lib/std/rs/src/common.rs | 31 +- pil2-components/lib/std/rs/src/debug.rs | 148 ++++- .../rs/src/range_check/specified_ranges.rs | 33 +- .../std/rs/src/range_check/std_range_check.rs | 8 +- .../lib/std/rs/src/range_check/u16air.rs | 33 +- .../lib/std/rs/src/range_check/u8air.rs | 33 +- pil2-components/lib/std/rs/src/std.rs | 14 +- pil2-components/lib/std/rs/src/std_prod.rs | 239 +++++-- pil2-components/lib/std/rs/src/std_sum.rs | 237 +++++-- pil2-components/test/simple/rs/Cargo.toml | 1 + .../test/simple/rs/src/pil_helpers/traces.rs | 11 + .../test/std/connection/rs/Cargo.toml | 1 + .../connection/rs/src/pil_helpers/traces.rs | 21 +- .../test/std/diff_buses/rs/Cargo.toml | 1 + .../test/std/direct_update/rs/Cargo.toml | 2 + pil2-components/test/std/lookup/rs/Cargo.toml | 1 + .../std/lookup/rs/src/pil_helpers/traces.rs | 27 + .../test/std/permutation/rs/Cargo.toml | 1 + .../permutation/rs/src/pil_helpers/traces.rs | 27 +- .../test/std/range_check/rs/Cargo.toml | 1 + .../range_check/rs/src/pil_helpers/traces.rs | 51 ++ pil2-stark/lib/include/starks_lib.h | 19 +- pil2-stark/src/api/starks_api.cpp | 68 +- pil2-stark/src/api/starks_api.hpp | 19 +- pil2-stark/src/config/zkglobals.cpp | 2 +- pil2-stark/src/config/zkglobals.hpp | 4 +- pil2-stark/src/goldilocks/benchs/bench.cpp | 168 ++++- .../goldilocks/src/poseidon2_goldilocks.cpp | 582 +++++++++++++++++ .../goldilocks/src/poseidon2_goldilocks.hpp | 199 ++++++ .../src/poseidon2_goldilocks_avx.hpp | 118 ++++ .../src/poseidon2_goldilocks_avx512.hpp | 107 ++++ .../src/poseidon2_goldilocks_constants.hpp | 145 +++++ pil2-stark/src/goldilocks/tests/tests.cpp | 39 +- pil2-stark/src/starkpil/const_pols.hpp | 96 +-- pil2-stark/src/starkpil/fixed_cols.hpp | 50 ++ .../src/starkpil/gen_recursive_proof.hpp | 4 +- .../starkpil/merkleTree/merkleTreeBN128.cpp | 37 +- .../starkpil/merkleTree/merkleTreeBN128.hpp | 4 +- .../src/starkpil/merkleTree/merkleTreeGL.cpp | 47 +- .../src/starkpil/merkleTree/merkleTreeGL.hpp | 9 +- pil2-stark/src/starkpil/proof_stark.hpp | 32 +- pil2-stark/src/starkpil/stark_info.cpp | 53 +- pil2-stark/src/starkpil/starks.cpp | 48 +- pil2-stark/src/starkpil/starks.hpp | 8 +- .../src/starkpil/transcript/transcriptGL.cpp | 2 +- .../src/starkpil/transcript/transcriptGL.hpp | 2 +- pil2-stark/src/utils/utils.cpp | 14 +- pil2-stark/src/utils/utils.hpp | 2 +- proofman/Cargo.toml | 2 + proofman/src/proofman.rs | 603 ++++++++++++------ proofman/src/recursion.rs | 28 +- proofman/src/verify.rs | 2 +- proofman/src/verify_constraints.rs | 6 +- provers/stark/src/stark_prover.rs | 82 +-- provers/starks-lib-c/bindings_starks.rs | 52 +- provers/starks-lib-c/src/ffi_starks.rs | 133 +++- util/src/lib.rs | 3 +- witness/src/witness_component.rs | 10 +- witness/src/witness_library.rs | 2 +- witness/src/witness_manager.rs | 61 +- 90 files changed, 3773 insertions(+), 959 deletions(-) create mode 100644 common/src/fixed_cols.rs create mode 100644 pil2-stark/src/goldilocks/src/poseidon2_goldilocks.cpp create mode 100644 pil2-stark/src/goldilocks/src/poseidon2_goldilocks.hpp create mode 100644 pil2-stark/src/goldilocks/src/poseidon2_goldilocks_avx.hpp create mode 100644 pil2-stark/src/goldilocks/src/poseidon2_goldilocks_avx512.hpp create mode 100644 pil2-stark/src/goldilocks/src/poseidon2_goldilocks_constants.hpp create mode 100644 pil2-stark/src/starkpil/fixed_cols.hpp diff --git a/Cargo.lock b/Cargo.lock index aec6ed199..0e44c5f78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,11 +52,12 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.6" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", + "once_cell", "windows-sys", ] @@ -97,9 +98,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.7.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1be3f42a67d6d345ecd59f675f3f012d6974981560836e938c22b424b85ce1be" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "build-probe-mpi" @@ -125,9 +126,9 @@ checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "cc" -version = "1.2.7" +version = "1.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a012a0df96dd6d06ba9a1b29d6402d1a5d77c6befd2566afdc26e10603dc93d7" +checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229" dependencies = [ "shlex", ] @@ -160,9 +161,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" dependencies = [ "clap_builder", "clap_derive", @@ -170,9 +171,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" dependencies = [ "anstream", "anstyle", @@ -333,6 +334,7 @@ dependencies = [ "proofman-macros", "rayon", "serde", + "serde_arrays", "serde_json", "witness", ] @@ -357,7 +359,19 @@ checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets", ] [[package]] @@ -395,9 +409,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown", @@ -488,9 +502,9 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "memchr" @@ -723,6 +737,7 @@ dependencies = [ name = "pil-std-lib" version = "0.1.0" dependencies = [ + "colored", "log", "num-bigint", "num-traits", @@ -768,9 +783,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.27" +version = "0.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "483f8c21f64f3ea09fe0f30f5d48c3e8eefe5dac9129f0075f76593b4c1da705" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" dependencies = [ "proc-macro2", "syn", @@ -778,9 +793,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] @@ -798,6 +813,7 @@ dependencies = [ "p3-goldilocks", "proofman-common", "proofman-hints", + "proofman-macros", "proofman-starks-lib-c", "proofman-util", "rayon", @@ -864,6 +880,7 @@ version = "0.1.0" dependencies = [ "proc-macro2", "quote", + "rayon", "syn", ] @@ -971,7 +988,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] [[package]] @@ -1031,9 +1048,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.38.43" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags", "errno", @@ -1044,9 +1061,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "serde" @@ -1057,6 +1074,15 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_arrays" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38636132857f68ec3d5f3eb121166d2af33cb55174c4d5ff645db6165cbef0fd" +dependencies = [ + "serde", +] + [[package]] name = "serde_derive" version = "1.0.217" @@ -1070,9 +1096,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.135" +version = "1.0.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" dependencies = [ "itoa", "memchr", @@ -1153,13 +1179,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.15.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" +checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" dependencies = [ "cfg-if", "fastrand", - "getrandom", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys", @@ -1245,9 +1271,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unicode-segmentation" @@ -1267,6 +1293,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "which" version = "4.4.2" @@ -1427,6 +1462,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + [[package]] name = "witness" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 67defe910..3e93cda87 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,5 +51,6 @@ num-traits = "0.2" rayon = "1" serde = { version = "1.0.130", features = ["derive"] } serde_json = "1.0.68" +serde_arrays = "0.1" serde_derive = "1.0.196" colored = "3" diff --git a/cli/assets/templates/pil_helpers_trace.rs.tt b/cli/assets/templates/pil_helpers_trace.rs.tt index ee6e71472..fc459275b 100644 --- a/cli/assets/templates/pil_helpers_trace.rs.tt +++ b/cli/assets/templates/pil_helpers_trace.rs.tt @@ -1,11 +1,16 @@ // WARNING: This file has been autogenerated from the PILOUT file. // Manual modifications are not recommended and may be overwritten. +#![allow(clippy::all)] +#![allow(non_snake_case)] + use proofman_common as common; pub use proofman_macros::trace; pub use proofman_macros::values; use std::fmt; +use rayon::prelude::*; + #[allow(dead_code)] type FieldExtension = [F; 3]; @@ -23,13 +28,29 @@ pub const { constant.0 }_AIR_IDS: &[usize] = &[{ constant.3 }]; //PUBLICS use serde::Deserialize; use serde::Serialize; -#[derive(Default, Debug, Serialize, Deserialize)] +use serde_arrays; + +{{ for column in public_values.values_u64 }}{{ if column.array }} +fn default_array_{column.name}() -> {column.type} \{ + {column.type_default} +} +{{endif}}{{ endfor }} + +#[derive(Debug, Serialize, Deserialize)] pub struct {project_name}Publics \{ - {{ for column in public_values.values_u64 }}#[serde(default)] + {{ for column in public_values.values_u64 }}{{ if column.array }}#[serde(default = "default_array_{column.name}", with = "serde_arrays")]{{ else }}#[serde(default)]{{ endif }} pub {column.name}: {column.type}, {{ endfor }} } +impl Default for {project_name}Publics \{ + fn default() -> Self \{ + Self \{ {{ for column in public_values.values_default }} + {column.name}: {column.type}, {{ endfor }} + } + } +} + values!({ project_name }PublicValues \{ {{ for column in public_values.values }} { column.name }: { column.type },{{ endfor }} }); @@ -38,6 +59,10 @@ values!({ project_name }ProofValues \{ {{ for column in proof_vals.values }} { column.name }: { column.type },{{ endfor }} }); {{ endfor }} {{ for air_group in air_groups }}{{ for air in air_group.airs }} +trace!({ air.name }Fixed \{ +{{ for column in air.fixed }} { column.name }: { column.type },{{ endfor }} +}, { air_group.airgroup_id }, { air.id }, { air.num_rows } ); + trace!({ air.name }Trace \{ {{ for column in air.columns }} { column.name }: { column.type },{{ endfor }} }, { air_group.airgroup_id }, { air.id }, { air.num_rows } ); diff --git a/cli/src/commands/get_constraints.rs b/cli/src/commands/get_constraints.rs index 4b98e8de6..733aee744 100644 --- a/cli/src/commands/get_constraints.rs +++ b/cli/src/commands/get_constraints.rs @@ -1,5 +1,6 @@ // extern crate env_logger; use clap::Parser; +use p3_goldilocks::Goldilocks; use proofman_common::initialize_logger; use std::path::PathBuf; use colored::Colorize; @@ -24,7 +25,7 @@ impl GetConstraintsCmd { println!(); let global_info = GlobalInfo::new(&self.proving_key); - let setups = Arc::new(SetupsVadcop::new(&global_info, false, false)); + let setups = Arc::new(SetupsVadcop::::new(&global_info, false, false, false)); initialize_logger(proofman_common::VerboseMode::Info); diff --git a/cli/src/commands/pil_helpers.rs b/cli/src/commands/pil_helpers.rs index 46b55dfdf..339df36c7 100644 --- a/cli/src/commands/pil_helpers.rs +++ b/cli/src/commands/pil_helpers.rs @@ -52,6 +52,7 @@ struct AirCtx { name: String, num_rows: u32, columns: Vec, + fixed: Vec, stages_columns: Vec, custom_columns: Vec, air_values: Vec, @@ -61,7 +62,8 @@ struct AirCtx { #[derive(Clone, Debug, Serialize)] struct ValuesCtx { values: Vec, - values_u64: Vec, + values_u64: Vec, + values_default: Vec, } #[derive(Clone, Debug, Serialize)] @@ -76,6 +78,14 @@ struct ColumnCtx { r#type: String, } +#[derive(Clone, Debug, Serialize)] +struct Column64Ctx { + name: String, + array: bool, + r#type: String, + r#type_default: String, +} + #[derive(Default, Clone, Debug, Serialize)] struct StageColumnCtx { stage_id: usize, @@ -135,6 +145,7 @@ impl PilHelpersCmd { name: air.name.as_ref().unwrap().to_string(), num_rows: air.num_rows.unwrap(), columns: Vec::new(), + fixed: Vec::new(), stages_columns: vec![StageColumnCtx::default(); pilout.num_challenges.len() - 1], custom_columns: Vec::new(), air_values: Vec::new(), @@ -192,7 +203,11 @@ impl PilHelpersCmd { }; if symbol.r#type == SymbolType::ProofValue as i32 { if proof_values.is_empty() { - proof_values.push(ValuesCtx { values: Vec::new(), values_u64: Vec::new() }); + proof_values.push(ValuesCtx { + values: Vec::new(), + values_u64: Vec::new(), + values_default: Vec::new(), + }); } if symbol.stage == Some(1) { proof_values[0].values.push(ColumnCtx { name: name.to_owned(), r#type }); @@ -201,7 +216,11 @@ impl PilHelpersCmd { } } else { if publics.is_empty() { - publics.push(ValuesCtx { values: Vec::new(), values_u64: Vec::new() }); + publics.push(ValuesCtx { + values: Vec::new(), + values_u64: Vec::new(), + values_default: Vec::new(), + }); } publics[0].values.push(ColumnCtx { name: name.to_owned(), r#type }); let r#type_64 = if symbol.lengths.is_empty() { @@ -214,7 +233,20 @@ impl PilHelpersCmd { .rev() .fold("u64".to_string(), |acc, &length| format!("[{}; {}]", acc, length)) }; - publics[0].values_u64.push(ColumnCtx { name: name.to_owned(), r#type: r#type_64 }); + let default = "0".to_string(); + let r#type_default = if symbol.lengths.is_empty() { + default // Case when lengths.len() == 0 + } else { + // Start with "u64" and apply each length in reverse order + symbol.lengths.iter().rev().fold(default, |acc, &length| format!("[{}; {}]", acc, length)) + }; + publics[0].values_u64.push(Column64Ctx { + name: name.to_owned(), + r#type: r#type_64, + r#type_default: r#type_default.clone(), + array: !symbol.lengths.is_empty(), + }); + publics[0].values_default.push(ColumnCtx { name: name.to_owned(), r#type: r#type_default }); } }); @@ -244,6 +276,7 @@ impl PilHelpersCmd { || symbol.r#type == SymbolType::AirGroupValue as i32) && symbol.stage.is_some() && ((symbol.r#type == SymbolType::WitnessCol as i32) + || (symbol.r#type == SymbolType::FixedCol as i32) || (symbol.r#type == SymbolType::AirValue as i32) || (symbol.r#type == SymbolType::AirGroupValue as i32) || (symbol.r#type == SymbolType::CustomCol as i32 && symbol.stage.unwrap() == 0)) @@ -280,9 +313,15 @@ impl PilHelpersCmd { .columns .push(ColumnCtx { name: name.to_owned(), r#type: ext_type }); } + } else if symbol.r#type == SymbolType::FixedCol as i32 { + air.fixed.push(ColumnCtx { name: name.to_owned(), r#type }); } else if symbol.r#type == SymbolType::AirValue as i32 { if air.air_values.is_empty() { - air.air_values.push(ValuesCtx { values: Vec::new(), values_u64: Vec::new() }); + air.air_values.push(ValuesCtx { + values: Vec::new(), + values_u64: Vec::new(), + values_default: Vec::new(), + }); } if symbol.stage == Some(1) { air.air_values[0].values.push(ColumnCtx { name: name.to_owned(), r#type }); @@ -291,7 +330,11 @@ impl PilHelpersCmd { } } else if symbol.r#type == SymbolType::AirGroupValue as i32 { if air.airgroup_values.is_empty() { - air.airgroup_values.push(ValuesCtx { values: Vec::new(), values_u64: Vec::new() }); + air.airgroup_values.push(ValuesCtx { + values: Vec::new(), + values_u64: Vec::new(), + values_default: Vec::new(), + }); } air.airgroup_values[0].values.push(ColumnCtx { name: name.to_owned(), r#type: ext_type }); } else { diff --git a/cli/src/commands/prove.rs b/cli/src/commands/prove.rs index 64093d6e1..a99e51460 100644 --- a/cli/src/commands/prove.rs +++ b/cli/src/commands/prove.rs @@ -26,8 +26,12 @@ pub struct ProveCmd { #[clap(short, long)] pub rom: Option, - /// Public inputs path + /// Inputs path #[clap(short = 'i', long)] + pub input_data: Option, + + /// Public inputs path + #[clap(short = 'p', long)] pub public_inputs: Option, /// Setup folder path @@ -89,6 +93,7 @@ impl ProveCmd { self.witness_lib.clone(), self.rom.clone(), self.public_inputs.clone(), + self.input_data.clone(), self.proving_key.clone(), self.output_dir.clone(), ProofOptions::new(false, self.verbose.into(), self.aggregation, self.final_snark, debug_info), diff --git a/cli/src/commands/verify_constraints.rs b/cli/src/commands/verify_constraints.rs index 3fec60ee6..3c3a371ab 100644 --- a/cli/src/commands/verify_constraints.rs +++ b/cli/src/commands/verify_constraints.rs @@ -24,8 +24,12 @@ pub struct VerifyConstraintsCmd { #[clap(short, long)] pub rom: Option, - /// Public inputs path + /// Inputs path #[clap(short = 'i', long)] + pub input_data: Option, + + /// Public inputs path + #[clap(short = 'p', long)] pub public_inputs: Option, /// Setup folder path @@ -61,6 +65,7 @@ impl VerifyConstraintsCmd { self.witness_lib.clone(), self.rom.clone(), self.public_inputs.clone(), + self.input_data.clone(), self.proving_key.clone(), PathBuf::new(), ProofOptions::new(true, self.verbose.into(), false, false, debug_info), diff --git a/common/src/air_instances_repository.rs b/common/src/air_instances_repository.rs index 216f7aaee..def6d7423 100644 --- a/common/src/air_instances_repository.rs +++ b/common/src/air_instances_repository.rs @@ -1,4 +1,5 @@ use std::{collections::HashMap, sync::RwLock}; +use rayon::prelude::*; use p3_field::Field; @@ -31,10 +32,10 @@ impl AirInstancesRepository { pub fn free_traces(&self) { let mut air_instances = self.air_instances.write().unwrap(); - for (_, air_instance) in air_instances.iter_mut() { + air_instances.par_iter_mut().for_each(|(_, air_instance)| { air_instance.clear_trace(); air_instance.clear_custom_commits_trace(); - } + }); } pub fn find_airgroup_instances(&self, airgroup_id: usize) -> Vec { @@ -61,20 +62,4 @@ impl AirInstancesRepository { indices } - - pub fn find_instance(&self, airgroup_id: usize, air_id: usize, air_instance_id: usize) -> Option { - let air_instances = self.air_instances.read().unwrap(); - - let mut count = 0; - for (index, air_instance) in air_instances.iter() { - if air_instance.airgroup_id == airgroup_id && air_instance.air_id == air_id { - count += 1; - if count == air_instance_id { - return Some(*index); - } - } - } - - None - } } diff --git a/common/src/constraints.rs b/common/src/constraints.rs index 315208118..de555823e 100644 --- a/common/src/constraints.rs +++ b/common/src/constraints.rs @@ -36,7 +36,7 @@ pub struct GlobalConstraintInfo { pub value: [u64; 3usize], } -pub fn get_constraints_lines_str(sctx: Arc, airgroup_id: usize, air_id: usize) -> Vec { +pub fn get_constraints_lines_str(sctx: Arc>, airgroup_id: usize, air_id: usize) -> Vec { let setup = sctx.get_setup(airgroup_id, air_id); let p_setup = (&setup.p_setup).into(); @@ -64,7 +64,7 @@ pub fn get_constraints_lines_str(sctx: Arc, airgroup_id: usize, air_id constraints_lines_str } -pub fn get_global_constraints_lines_str(sctx: Arc) -> Vec { +pub fn get_global_constraints_lines_str(sctx: Arc>) -> Vec { let n_global_constraints = get_n_global_constraints_c(sctx.get_global_bin()); let mut global_constraints_sizes = vec![0u64; n_global_constraints as usize]; diff --git a/common/src/custom_commits.rs b/common/src/custom_commits.rs index d3e7b42c3..c58c3bdda 100644 --- a/common/src/custom_commits.rs +++ b/common/src/custom_commits.rs @@ -6,7 +6,7 @@ use crate::Setup; pub fn get_custom_commit_trace( commit_id: u64, step: u64, - setup: &Setup, + setup: &Setup, buffer: Vec, buffer_ext: Vec, buffer_str: &str, @@ -17,7 +17,7 @@ pub fn get_custom_commit_trace( step, buffer.as_ptr() as *mut u8, buffer_ext.as_ptr() as *mut u8, - fri_proof_new_c((&setup.p_setup).into(), 0), + fri_proof_new_c((&setup.p_setup).into(), 0, 0, 0), std::ptr::null_mut(), buffer_str, ); diff --git a/common/src/distribution_ctx.rs b/common/src/distribution_ctx.rs index 3a39247e7..caa2dc0eb 100644 --- a/common/src/distribution_ctx.rs +++ b/common/src/distribution_ctx.rs @@ -156,6 +156,36 @@ impl DistributionCtx { } } + #[inline] + pub fn find_instance_id(&self, airgroup_id: usize, air_id: usize, air_instance_id: usize) -> Option { + let mut count = 0; + for (global_idx, instance) in self.instances.iter().enumerate() { + if instance == &(airgroup_id, air_id) { + if count == air_instance_id { + return Some(global_idx); + } + count += 1; + } + } + None + } + + #[inline] + pub fn is_min_rank_owner(&self, airgroup_id: usize, air_id: usize) -> bool { + let mut min_owner = self.n_processes + 1; + for (idx, instance) in self.instances.iter().enumerate() { + if instance == &(airgroup_id, air_id) && self.instances_owner[idx].0 < min_owner { + min_owner = self.instances_owner[idx].0; + } + } + + if min_owner == self.n_processes + 1 { + panic!("No instance found for airgroup_id: {}, air_id: {}", airgroup_id, air_id); + } + + min_owner == self.rank + } + #[inline] pub fn add_instance(&mut self, airgroup_id: usize, air_id: usize, weight: u64) -> (bool, usize) { let mut is_mine = false; diff --git a/common/src/fixed_cols.rs b/common/src/fixed_cols.rs new file mode 100644 index 000000000..290af8e8e --- /dev/null +++ b/common/src/fixed_cols.rs @@ -0,0 +1,54 @@ +use std::os::raw::c_void; + +use p3_field::Field; +use proofman_starks_lib_c::write_fixed_cols_bin_c; + +#[repr(C)] +#[derive(Debug)] +pub struct FixedColsInfoC { + name_size: u64, + name: *mut u8, + n_lengths: u64, + lengths: *mut u64, + values: *mut F, +} + +impl FixedColsInfoC { + pub fn from_fixed_cols_info_vec(fixed_cols: &mut [FixedColsInfo]) -> Vec> { + fixed_cols + .iter_mut() + .map(|info| FixedColsInfoC { + name_size: info.name.len() as u64, + name: info.name.as_mut_ptr(), + n_lengths: info.lengths.len() as u64, + lengths: info.lengths.as_mut_ptr(), + values: info.values.as_mut_ptr(), + }) + .collect() + } +} +#[derive(Clone, Debug)] +#[repr(C)] +pub struct FixedColsInfo { + name: String, // AirName.ColumnName + lengths: Vec, + values: Vec, +} + +impl FixedColsInfo { + pub fn new(name_: &str, lengths: Option>, values: Vec) -> Self { + FixedColsInfo { name: name_.to_string(), lengths: lengths.unwrap_or_default(), values } + } +} + +pub fn write_fixed_cols_bin( + bin_file: &str, + airgroup_name: &str, + air_name: &str, + n: u64, + fixed_cols: &mut [FixedColsInfo], +) { + let mut fixed_cols_info_c = FixedColsInfoC::::from_fixed_cols_info_vec(fixed_cols); + let fixed_cols_info_c_ptr = fixed_cols_info_c.as_mut_ptr() as *mut c_void; + write_fixed_cols_bin_c(bin_file, airgroup_name, air_name, n, fixed_cols.len() as u64, fixed_cols_info_c_ptr); +} diff --git a/common/src/lib.rs b/common/src/lib.rs index 44ca7b344..b7fd251f4 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -16,6 +16,7 @@ pub mod utils; pub mod custom_commits; pub mod constraints; pub mod prover_helpers; +pub mod fixed_cols; pub use air_instance::*; pub use air_instances_repository::*; @@ -34,3 +35,4 @@ pub use distribution_ctx::*; pub use custom_commits::*; pub use constraints::*; pub use prover_helpers::*; +pub use fixed_cols::*; diff --git a/common/src/proof_ctx.rs b/common/src/proof_ctx.rs index 2612ed958..01e3637ed 100644 --- a/common/src/proof_ctx.rs +++ b/common/src/proof_ctx.rs @@ -103,7 +103,29 @@ impl ProofCtx { } } - pub fn set_weights(&mut self, sctx: &SetupCtx) { + pub fn create_ctx_agg( + global_info: &GlobalInfo, + options: ProofOptions, + public_inputs: Vec, + challenges: Vec, + proof_values: Vec, + dctx: DistributionCtx, + weights: HashMap<(usize, usize), u64>, + ) -> Self { + Self { + global_info: global_info.clone(), + public_inputs: Values { values: RwLock::new(public_inputs) }, + proof_values: Values { values: RwLock::new(proof_values) }, + challenges: Values { values: RwLock::new(challenges) }, + buff_helper: Values::default(), + air_instance_repo: AirInstancesRepository::new(), + dctx: RwLock::new(dctx), + weights, + options, + } + } + + pub fn set_weights(&mut self, sctx: &SetupCtx) { for (airgroup_id, air_group) in self.global_info.airs.iter().enumerate() { for (air_id, _) in air_group.iter().enumerate() { let setup = sctx.get_setup(airgroup_id, air_id); @@ -136,6 +158,16 @@ impl ProofCtx { self.air_instance_repo.add_air_instance(air_instance, global_idx); } + pub fn dctx_barrier(&self) { + let dctx = self.dctx.read().unwrap(); + dctx.barrier(); + } + + pub fn dctx_is_min_rank_owner(&self, airgroup_id: usize, air_id: usize) -> bool { + let dctx = self.dctx.read().unwrap(); + dctx.is_min_rank_owner(airgroup_id, air_id) + } + pub fn dctx_get_rank(&self) -> usize { let dctx = self.dctx.read().unwrap(); dctx.rank as usize @@ -310,7 +342,8 @@ impl ProofCtx { } pub fn get_air_instance_trace(&self, airgroup_id: usize, air_id: usize, air_instance_id: usize) -> Vec { - let index = self.air_instance_repo.find_instance(airgroup_id, air_id, air_instance_id); + let dctx = self.dctx.read().unwrap(); + let index = dctx.find_instance_id(airgroup_id, air_id, air_instance_id); if let Some(index) = index { return self.air_instance_repo.air_instances.read().unwrap().get(&index).unwrap().get_trace(); } else { @@ -322,9 +355,10 @@ impl ProofCtx { } pub fn get_air_instance_air_values(&self, airgroup_id: usize, air_id: usize, air_instance_id: usize) -> Vec { - let index = self.air_instance_repo.find_instance(airgroup_id, air_id, air_instance_id); + let dctx = self.dctx.read().unwrap(); + let index = dctx.find_instance_id(airgroup_id, air_id, air_instance_id); if let Some(index) = index { - return self.air_instance_repo.air_instances.read().unwrap().get(&index).unwrap().get_trace(); + return self.air_instance_repo.air_instances.read().unwrap().get(&index).unwrap().get_air_values(); } else { panic!( "Air Instance with id {} for airgroup {} and air {} not found", @@ -339,9 +373,10 @@ impl ProofCtx { air_id: usize, air_instance_id: usize, ) -> Vec { - let index = self.air_instance_repo.find_instance(airgroup_id, air_id, air_instance_id); + let dctx = self.dctx.read().unwrap(); + let index = dctx.find_instance_id(airgroup_id, air_id, air_instance_id); if let Some(index) = index { - return self.air_instance_repo.air_instances.read().unwrap().get(&index).unwrap().get_trace(); + return self.air_instance_repo.air_instances.read().unwrap().get(&index).unwrap().get_airgroup_values(); } else { panic!( "Air Instance with id {} for airgroup {} and air {} not found", diff --git a/common/src/prover.rs b/common/src/prover.rs index 8b0e12799..d7e9b2ab4 100644 --- a/common/src/prover.rs +++ b/common/src/prover.rs @@ -36,13 +36,13 @@ pub trait Prover { fn free(&mut self); fn new_transcript(&self) -> FFITranscript; fn num_stages(&self) -> u32; - fn get_challenges(&self, stage_id: u32, pctx: Arc>, transcript: &FFITranscript); - fn calculate_stage(&mut self, stage_id: u32, sctx: Arc, pctx: Arc>); + fn get_challenges(&self, stage_id: u32, pctx: Arc>, transcript: &FFITranscript) -> Vec>; + fn calculate_stage(&mut self, stage_id: u32, sctx: Arc>, pctx: Arc>); fn commit_stage(&mut self, stage_id: u32, pctx: Arc>) -> ProverStatus; fn commit_custom_commits_stage(&mut self, stage_id: u32, pctx: Arc>) -> Vec; - fn calculate_xdivxsub(&mut self, pctx: Arc>); - fn calculate_lev(&mut self, pctx: Arc>); - fn opening_stage(&mut self, opening_id: u32, sctx: Arc, pctx: Arc>) -> ProverStatus; + fn calculate_xdivxsub(&mut self, pctx: Arc>, challenge: Vec); + fn calculate_lev(&mut self, pctx: Arc>, challenge: Vec); + fn opening_stage(&mut self, opening_id: u32, sctx: Arc>, pctx: Arc>) -> ProverStatus; fn get_buff_helper_size(&self, pctx: Arc>) -> usize; fn get_proof(&self) -> *mut c_void; @@ -53,7 +53,7 @@ pub trait Prover { fn get_transcript_values(&self, stage: u64, pctx: Arc>) -> Vec; fn get_transcript_values_u64(&self, stage: u64, pctx: Arc>) -> Vec; fn calculate_hash(&self, values: Vec) -> Vec; - fn verify_constraints(&self, sctx: Arc, pctx: Arc>) -> Vec; + fn verify_constraints(&self, sctx: Arc>, pctx: Arc>) -> Vec; fn get_proof_challenges(&self, global_steps: Vec, global_challenges: Vec) -> Vec; } diff --git a/common/src/setup.rs b/common/src/setup.rs index f3b9f6be3..9d508cc31 100644 --- a/common/src/setup.rs +++ b/common/src/setup.rs @@ -1,14 +1,14 @@ use std::os::raw::c_void; use std::path::PathBuf; -use std::sync::RwLock; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; -use proofman_starks_lib_c::get_map_totaln_c; use proofman_starks_lib_c::{ - get_const_tree_size_c, get_const_size_c, prover_helpers_new_c, expressions_bin_new_c, stark_info_new_c, - load_const_tree_c, load_const_pols_c, calculate_const_tree_c, stark_info_free_c, expressions_bin_free_c, - prover_helpers_free_c, + get_const_tree_size_c, prover_helpers_new_c, expressions_bin_new_c, stark_info_new_c, load_const_tree_c, + load_const_pols_c, calculate_const_tree_c, stark_info_free_c, expressions_bin_free_c, prover_helpers_free_c, + get_map_totaln_c, write_const_tree_c, }; -use proofman_util::create_buffer_fast_u8; +use proofman_util::create_buffer_fast; use crate::GlobalInfo; use crate::ProofType; @@ -31,34 +31,33 @@ impl From<&SetupC> for *mut c_void { } } -#[derive(Debug)] -pub struct Pols { - pub values: RwLock>, -} - -impl Default for Pols { - fn default() -> Self { - Self { values: RwLock::new(Vec::new()) } - } -} - /// Air instance context for managing air instances (traces) #[derive(Debug)] #[allow(dead_code)] -pub struct Setup { +pub struct Setup { pub airgroup_id: usize, pub air_id: usize, pub p_setup: SetupC, pub stark_info: StarkInfo, - pub const_pols: Pols, - pub const_tree: Pols, + pub const_pols: Vec, + pub const_tree: Vec, pub prover_buffer_size: u64, + pub write_const_tree: AtomicBool, + pub setup_path: PathBuf, + pub setup_type: ProofType, + pub air_name: String, } -impl Setup { +impl Setup { const MY_NAME: &'static str = "Setup"; - pub fn new(global_info: &GlobalInfo, airgroup_id: usize, air_id: usize, setup_type: &ProofType) -> Self { + pub fn new( + global_info: &GlobalInfo, + airgroup_id: usize, + air_id: usize, + setup_type: &ProofType, + verify_constraints: bool, + ) -> Self { let setup_path = match setup_type { ProofType::VadcopFinal => global_info.get_setup_path("vadcop_final"), ProofType::RecursiveF => global_info.get_setup_path("recursivef"), @@ -68,32 +67,63 @@ impl Setup { let stark_info_path = setup_path.display().to_string() + ".starkinfo.json"; let expressions_bin_path = setup_path.display().to_string() + ".bin"; - let (stark_info, p_stark_info, p_expressions_bin, p_prover_helpers, prover_buffer_size) = - if setup_type == &ProofType::Compressor && !global_info.get_air_has_compressor(airgroup_id, air_id) { - // If the condition is met, use None for each pointer - (StarkInfo::default(), std::ptr::null_mut(), std::ptr::null_mut(), std::ptr::null_mut(), 0) + let ( + stark_info, + p_stark_info, + p_expressions_bin, + p_prover_helpers, + prover_buffer_size, + const_pols_size, + const_tree_size, + ) = if setup_type == &ProofType::Compressor && !global_info.get_air_has_compressor(airgroup_id, air_id) { + // If the condition is met, use None for each pointer + (StarkInfo::default(), std::ptr::null_mut(), std::ptr::null_mut(), std::ptr::null_mut(), 0, 0, 0) + } else { + // Otherwise, initialize the pointers with their respective values + let stark_info_json = std::fs::read_to_string(&stark_info_path) + .unwrap_or_else(|_| panic!("Failed to read file {}", &stark_info_path)); + let stark_info = StarkInfo::from_json(&stark_info_json); + let p_stark_info = stark_info_new_c(stark_info_path.as_str(), false); + let recursive = &ProofType::Basic != setup_type; + let prover_buffer_size = if verify_constraints { + let mut mem_instance = 0; + for stage in 0..stark_info.n_stages + 1 { + let n_cols = stark_info.map_sections_n[&format!("cm{}", stage + 1)]; + mem_instance += n_cols * (1 << (stark_info.stark_struct.n_bits)); + } + mem_instance += (stark_info.custom_commits_map.len() * (1 << (stark_info.stark_struct.n_bits))) as u64; + mem_instance } else { - // Otherwise, initialize the pointers with their respective values - let stark_info_json = std::fs::read_to_string(&stark_info_path) - .unwrap_or_else(|_| panic!("Failed to read file {}", &stark_info_path)); - let stark_info = StarkInfo::from_json(&stark_info_json); - let p_stark_info = stark_info_new_c(stark_info_path.as_str(), false); - let recursive = &ProofType::Basic != setup_type; - let prover_buffer_size = get_map_totaln_c(p_stark_info, recursive); - let expressions_bin = expressions_bin_new_c(expressions_bin_path.as_str(), false, false); - let prover_helpers = prover_helpers_new_c(p_stark_info, recursive); - - (stark_info, p_stark_info, expressions_bin, prover_helpers, prover_buffer_size) + get_map_totaln_c(p_stark_info, recursive) }; + let expressions_bin = expressions_bin_new_c(expressions_bin_path.as_str(), false, false); + let prover_helpers = prover_helpers_new_c(p_stark_info, recursive); + let const_pols_size = (stark_info.n_constants * (1 << stark_info.stark_struct.n_bits)) as usize; + let const_pols_tree_size = get_const_tree_size_c(p_stark_info) as usize; + + ( + stark_info, + p_stark_info, + expressions_bin, + prover_helpers, + prover_buffer_size, + const_pols_size, + const_pols_tree_size, + ) + }; Self { air_id, airgroup_id, stark_info, p_setup: SetupC { p_stark_info, p_expressions_bin, p_prover_helpers }, - const_pols: Pols::default(), - const_tree: Pols::default(), + const_pols: create_buffer_fast(const_pols_size), + const_tree: create_buffer_fast(const_tree_size), prover_buffer_size, + write_const_tree: AtomicBool::new(false), + setup_path: setup_path.clone(), + setup_type: setup_type.clone(), + air_name: global_info.airs[airgroup_id][air_id].name.clone(), } } @@ -103,67 +133,80 @@ impl Setup { prover_helpers_free_c(self.p_setup.p_prover_helpers); } - pub fn load_const_pols(&self, global_info: &GlobalInfo, setup_type: &ProofType) { - let setup_path = match setup_type { - ProofType::VadcopFinal => global_info.get_setup_path("vadcop_final"), - ProofType::RecursiveF => global_info.get_setup_path("recursivef"), - _ => global_info.get_air_setup_path(self.airgroup_id, self.air_id, setup_type), - }; + pub fn load_const_pols(&self) { + log::debug!( + "{} : ··· Loading const pols for AIR {} of type {:?}", + Self::MY_NAME, + self.air_name, + self.setup_type + ); + + let const_pols_path = self.setup_path.display().to_string() + ".const"; + + load_const_pols_c( + self.const_pols.as_ptr() as *mut u8, + const_pols_path.as_str(), + self.const_pols.len() as u64 * 8, + ); + } - let air_name = &global_info.airs[self.airgroup_id][self.air_id].name; - log::debug!("{} : ··· Loading const pols for AIR {} of type {:?}", Self::MY_NAME, air_name, setup_type); + pub fn load_const_pols_tree(&self) { + log::debug!( + "{} : ··· Loading const tree for AIR {} of type {:?}", + Self::MY_NAME, + self.air_name, + self.setup_type + ); - let const_pols_path = setup_path.display().to_string() + ".const"; + let const_pols_tree_path = self.setup_path.display().to_string() + ".consttree"; + + let verkey_path = self.setup_path.display().to_string() + ".verkey.json"; let p_stark_info = self.p_setup.p_stark_info; - let const_size = get_const_size_c(p_stark_info) as usize; - let const_pols = create_buffer_fast_u8(const_size); + let valid_root = if PathBuf::from(&const_pols_tree_path).exists() { + load_const_tree_c( + p_stark_info, + self.const_tree.as_ptr() as *mut u8, + const_pols_tree_path.as_str(), + (self.const_tree.len() * 8) as u64, + verkey_path.as_str(), + ) + } else { + false + }; - load_const_pols_c(const_pols.as_ptr() as *mut u8, const_pols_path.as_str(), const_size as u64); - *self.const_pols.values.write().unwrap() = const_pols; + if !valid_root { + calculate_const_tree_c( + p_stark_info, + self.const_pols.as_ptr() as *mut u8, + self.const_tree.as_ptr() as *mut u8, + ); + self.write_const_tree.store(true, Ordering::SeqCst) + }; } - pub fn load_const_pols_tree(&self, global_info: &GlobalInfo, setup_type: &ProofType, save_file: bool) { - let setup_path = match setup_type { - ProofType::VadcopFinal => global_info.get_setup_path("vadcop_final"), - ProofType::RecursiveF => global_info.get_setup_path("recursivef"), - _ => global_info.get_air_setup_path(self.airgroup_id, self.air_id, setup_type), - }; + pub fn to_write_tree(&self) -> bool { + self.write_const_tree.load(Ordering::SeqCst) + } - let air_name = &global_info.airs[self.airgroup_id][self.air_id].name; - log::debug!("{} : ··· Loading const tree for AIR {} of type {:?}", Self::MY_NAME, air_name, setup_type); + pub fn set_write_const_tree(&self, write: bool) { + self.write_const_tree.store(write, Ordering::SeqCst) + } - let const_pols_tree_path = setup_path.display().to_string() + ".consttree"; + pub fn write_const_tree(&self) { + let const_pols_tree_path = self.setup_path.display().to_string() + ".consttree"; let p_stark_info = self.p_setup.p_stark_info; - let const_tree_size = get_const_tree_size_c(p_stark_info) as usize; - - let const_tree = create_buffer_fast_u8(const_tree_size); - - if PathBuf::from(&const_pols_tree_path).exists() { - load_const_tree_c(const_tree.as_ptr() as *mut u8, const_pols_tree_path.as_str(), const_tree_size as u64); - } else { - let const_pols = self.const_pols.values.read().unwrap(); - let tree_filename = if save_file { const_pols_tree_path.as_str() } else { "" }; - calculate_const_tree_c( - p_stark_info, - (*const_pols).as_ptr() as *mut u8, - const_tree.as_ptr() as *mut u8, - tree_filename, - ); - }; - *self.const_tree.values.write().unwrap() = const_tree; + write_const_tree_c(p_stark_info, self.const_tree.as_ptr() as *mut u8, const_pols_tree_path.as_str()); } pub fn get_const_ptr(&self) -> *mut u8 { - let guard = &self.const_pols.values.read().unwrap(); - guard.as_ptr() as *mut u8 + self.const_pols.as_ptr() as *mut u8 } pub fn get_const_tree_ptr(&self) -> *mut u8 { - let guard = &self.const_tree.values.read().unwrap(); - guard.as_ptr() as *mut u8 + self.const_tree.as_ptr() as *mut u8 } } diff --git a/common/src/setup_ctx.rs b/common/src/setup_ctx.rs index 49dc859d6..786b26b60 100644 --- a/common/src/setup_ctx.rs +++ b/common/src/setup_ctx.rs @@ -10,20 +10,20 @@ use crate::GlobalInfo; use crate::Setup; use crate::ProofType; -pub struct SetupsVadcop { - pub sctx: Arc, - pub sctx_compressor: Option>, - pub sctx_recursive1: Option>, - pub sctx_recursive2: Option>, - pub setup_vadcop_final: Option>, - pub setup_recursivef: Option>, +pub struct SetupsVadcop { + pub sctx: Arc>, + pub sctx_compressor: Option>>, + pub sctx_recursive1: Option>>, + pub sctx_recursive2: Option>>, + pub setup_vadcop_final: Option>>, + pub setup_recursivef: Option>>, } -impl SetupsVadcop { - pub fn new(global_info: &GlobalInfo, aggregation: bool, final_snark: bool) -> Self { +impl SetupsVadcop { + pub fn new(global_info: &GlobalInfo, verify_constraints: bool, aggregation: bool, final_snark: bool) -> Self { info!("Initializing setups"); timer_start_info!(INITIALIZING_BASIC_SETUP); - let sctx: SetupCtx = SetupCtx::new(global_info, &ProofType::Basic); + let sctx = SetupCtx::::new(global_info, &ProofType::Basic, verify_constraints); timer_stop_and_log_info!(INITIALIZING_BASIC_SETUP); if aggregation { timer_start_info!(INITIALIZING_AGGREGATION_SETUP); @@ -31,22 +31,22 @@ impl SetupsVadcop { timer_start_debug!(INITIALIZING_SETUP_COMPRESSOR); info!(" ··· Initializing setups compressor"); - let sctx_compressor: SetupCtx = SetupCtx::new(global_info, &ProofType::Compressor); + let sctx_compressor = SetupCtx::::new(global_info, &ProofType::Compressor, false); timer_stop_and_log_debug!(INITIALIZING_SETUP_COMPRESSOR); timer_start_debug!(INITIALIZING_SETUP_RECURSIVE1); info!(" ··· Initializing setups recursive1"); - let sctx_recursive1: SetupCtx = SetupCtx::new(global_info, &ProofType::Recursive1); + let sctx_recursive1 = SetupCtx::::new(global_info, &ProofType::Recursive1, false); timer_stop_and_log_debug!(INITIALIZING_SETUP_RECURSIVE1); timer_start_debug!(INITIALIZING_SETUP_RECURSIVE2); info!(" ··· Initializing setups recursive2"); - let sctx_recursive2: SetupCtx = SetupCtx::new(global_info, &ProofType::Recursive2); + let sctx_recursive2 = SetupCtx::::new(global_info, &ProofType::Recursive2, false); timer_stop_and_log_debug!(INITIALIZING_SETUP_RECURSIVE2); timer_start_debug!(INITIALIZING_SETUP_VADCOP_FINAL); info!(" ··· Initializing setups vadcop final"); - let setup_vadcop_final: Setup = Setup::new(global_info, 0, 0, &ProofType::VadcopFinal); + let setup_vadcop_final = Setup::::new(global_info, 0, 0, &ProofType::VadcopFinal, verify_constraints); timer_stop_and_log_debug!(INITIALIZING_SETUP_VADCOP_FINAL); timer_stop_and_log_info!(INITIALIZING_AGGREGATION_SETUP); @@ -55,7 +55,8 @@ impl SetupsVadcop { timer_start_debug!(INITIALIZING_SETUP_RECURSION); timer_start_debug!(INITIALIZING_SETUP_RECURSIVEF); info!(" ··· Initializing setups recursivef"); - setup_recursivef = Some(Arc::new(Setup::new(global_info, 0, 0, &ProofType::RecursiveF))); + setup_recursivef = + Some(Arc::new(Setup::::new(global_info, 0, 0, &ProofType::RecursiveF, verify_constraints))); timer_stop_and_log_debug!(INITIALIZING_SETUP_RECURSIVEF); timer_stop_and_log_debug!(INITIALIZING_SETUP_RECURSION); } @@ -82,17 +83,17 @@ impl SetupsVadcop { } #[derive(Debug)] -pub struct SetupRepository { - setups: HashMap<(usize, usize), Setup>, +pub struct SetupRepository { + setups: HashMap<(usize, usize), Setup>, global_bin: Option<*mut c_void>, global_info_file: String, } -unsafe impl Send for SetupRepository {} -unsafe impl Sync for SetupRepository {} +unsafe impl Send for SetupRepository {} +unsafe impl Sync for SetupRepository {} -impl SetupRepository { - pub fn new(global_info: &GlobalInfo, setup_type: &ProofType) -> Self { +impl SetupRepository { + pub fn new(global_info: &GlobalInfo, setup_type: &ProofType, verify_constraints: bool) -> Self { let mut setups = HashMap::new(); let global_bin = match setup_type == &ProofType::Basic { @@ -111,11 +112,14 @@ impl SetupRepository { if setup_type != &ProofType::VadcopFinal { for (airgroup_id, air_group) in global_info.airs.iter().enumerate() { for (air_id, _) in air_group.iter().enumerate() { - setups.insert((airgroup_id, air_id), Setup::new(global_info, airgroup_id, air_id, setup_type)); + setups.insert( + (airgroup_id, air_id), + Setup::new(global_info, airgroup_id, air_id, setup_type, verify_constraints), + ); } } } else { - setups.insert((0, 0), Setup::new(global_info, 0, 0, setup_type)); + setups.insert((0, 0), Setup::new(global_info, 0, 0, setup_type, verify_constraints)); } Self { setups, global_bin, global_info_file } @@ -127,17 +131,20 @@ impl SetupRepository { } /// Air instance context for managing air instances (traces) #[allow(dead_code)] -pub struct SetupCtx { - setup_repository: SetupRepository, +pub struct SetupCtx { + setup_repository: SetupRepository, setup_type: ProofType, } -impl SetupCtx { - pub fn new(global_info: &GlobalInfo, setup_type: &ProofType) -> Self { - SetupCtx { setup_repository: SetupRepository::new(global_info, setup_type), setup_type: setup_type.clone() } +impl SetupCtx { + pub fn new(global_info: &GlobalInfo, setup_type: &ProofType, verify_constraints: bool) -> Self { + SetupCtx { + setup_repository: SetupRepository::::new(global_info, setup_type, verify_constraints), + setup_type: setup_type.clone(), + } } - pub fn get_setup(&self, airgroup_id: usize, air_id: usize) -> &Setup { + pub fn get_setup(&self, airgroup_id: usize, air_id: usize) -> &Setup { match self.setup_repository.setups.get(&(airgroup_id, air_id)) { Some(setup) => setup, None => { @@ -149,6 +156,18 @@ impl SetupCtx { } } + pub fn get_fixed_slice(&self, airgroup_id: usize, air_id: usize) -> &[F] { + match self.setup_repository.setups.get(&(airgroup_id, air_id)) { + Some(setup) => setup.const_pols.as_slice(), + None => { + // Handle the error case as needed + log::error!("Setup not found for airgroup_id: {}, air_id: {}", airgroup_id, air_id); + // You might want to return a default value or panic + panic!("Setup not found"); // or return a default value if applicable + } + } + } + pub fn get_setups_list(&self) -> Vec<(usize, usize)> { self.setup_repository.setups.keys().cloned().collect() } diff --git a/common/src/std_mode.rs b/common/src/std_mode.rs index e4858083a..ca6316074 100644 --- a/common/src/std_mode.rs +++ b/common/src/std_mode.rs @@ -10,27 +10,28 @@ pub struct StdMode { pub opids: Vec, pub n_vals: usize, pub print_to_file: bool, + pub fast_mode: bool, } impl StdMode { - pub const fn new(name: ModeName, opids: Vec, n_vals: usize, print_to_file: bool) -> Self { + pub const fn new(name: ModeName, opids: Vec, n_vals: usize, print_to_file: bool, fast_mode: bool) -> Self { if name.as_usize() != ModeName::Standard.as_usize() && n_vals == 0 { panic!("n_vals must be greater than 0"); } - Self { name, opids, n_vals, print_to_file } + Self { name, opids, n_vals, print_to_file, fast_mode } } pub fn new_debug() -> Self { - Self::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false) + Self::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false, true) } } impl From for StdMode { fn from(v: u8) -> Self { match v { - 0 => StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false), - 1 => StdMode::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false), + 0 => StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false, false), + 1 => StdMode::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false, true), _ => panic!("Invalid mode"), } } @@ -38,7 +39,7 @@ impl From for StdMode { impl Default for StdMode { fn default() -> Self { - StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false) + StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false, false) } } diff --git a/common/src/utils.rs b/common/src/utils.rs index 569331c4b..c5f74f7b1 100644 --- a/common/src/utils.rs +++ b/common/src/utils.rs @@ -1,6 +1,6 @@ use crate::{ - AirGroupMap, AirIdMap, AirInstance, DebugInfo, GlobalInfo, InstanceMap, ModeName, ProofCtx, ProofOptions, StdMode, - VerboseMode, DEFAULT_PRINT_VALS, + AirGroupMap, AirIdMap, AirInstance, DebugInfo, GlobalInfo, InstanceMap, ModeName, ProofCtx, StdMode, VerboseMode, + DEFAULT_PRINT_VALS, }; use proofman_starks_lib_c::set_log_level_c; use std::path::PathBuf; @@ -47,15 +47,16 @@ pub fn format_bytes(mut num_bytes: f64) -> String { format!("{:.2} {}", num_bytes, units[unit_index]) } -pub fn skip_prover_instance( - options: ProofOptions, - airgroup_id: usize, - air_id: usize, - air_instance_id: usize, -) -> (bool, Vec) { - if options.debug_info.debug_instances.is_empty() { +pub fn skip_prover_instance(pctx: &ProofCtx, global_idx: usize) -> (bool, Vec) { + if pctx.options.debug_info.debug_instances.is_empty() { return (false, Vec::new()); - } else if let Some(airgroup_id_map) = options.debug_info.debug_instances.get(&airgroup_id) { + } + + let instances = pctx.dctx_get_instances(); + let (airgroup_id, air_id) = instances[global_idx]; + let air_instance_id = pctx.dctx_find_air_instance_id(global_idx); + + if let Some(airgroup_id_map) = pctx.options.debug_info.debug_instances.get(&airgroup_id) { if airgroup_id_map.is_empty() { return (false, Vec::new()); } else if let Some(air_id_map) = airgroup_id_map.get(&air_id) { @@ -70,6 +71,9 @@ pub fn skip_prover_instance( (true, Vec::new()) } +fn default_fast_mode() -> bool { + true +} #[derive(Debug, Default, Deserialize)] struct StdDebugMode { #[serde(default)] @@ -78,6 +82,8 @@ struct StdDebugMode { n_print: Option, #[serde(default)] print_to_file: bool, + #[serde(default = "default_fast_mode")] + fast_mode: bool, } #[derive(Debug, Deserialize)] @@ -200,14 +206,18 @@ pub fn json_to_debug_instances_map(proving_key_path: PathBuf, json_path: String) let global_constraints = json.global_constraints.unwrap_or_default(); let std_mode = if !airgroup_map.is_empty() { - StdMode::new(ModeName::Standard, Vec::new(), 0, false) + StdMode::new(ModeName::Standard, Vec::new(), 0, false, false) } else { let mode = json.std_mode.unwrap_or_default(); + let fast_mode = + if mode.opids.is_some() && !mode.opids.as_ref().unwrap().is_empty() { false } else { mode.fast_mode }; + StdMode::new( ModeName::Debug, mode.opids.unwrap_or_default(), mode.n_print.unwrap_or(DEFAULT_PRINT_VALS), mode.print_to_file, + fast_mode, ) }; diff --git a/examples/fibonacci-square/Cargo.toml b/examples/fibonacci-square/Cargo.toml index b7576f1fc..a3ca0b6a1 100644 --- a/examples/fibonacci-square/Cargo.toml +++ b/examples/fibonacci-square/Cargo.toml @@ -19,4 +19,9 @@ rayon = "1" serde.workspace = true serde_json.workspace = true +serde_arrays.workspace = true num-bigint = "0.4" + +[features] +default = [] +debug = [] diff --git a/examples/fibonacci-square/src/fibonacci.rs b/examples/fibonacci-square/src/fibonacci.rs index 2563bdec0..b6d7d2733 100644 --- a/examples/fibonacci-square/src/fibonacci.rs +++ b/examples/fibonacci-square/src/fibonacci.rs @@ -1,12 +1,12 @@ use std::sync::Arc; -use proofman_common::{add_air_instance, AirInstance, FromTrace, ProofCtx}; +use proofman_common::{add_air_instance, AirInstance, FromTrace, ProofCtx, SetupCtx}; use witness::WitnessComponent; use p3_field::PrimeField64; use crate::{ - FibonacciSquareRomTrace, BuildPublicValues, BuildProofValues, FibonacciSquareAirValues, FibonacciSquareTrace, + BuildProofValues, BuildPublicValues, FibonacciSquareAirValues, FibonacciSquareRomTrace, FibonacciSquareTrace, Module, }; @@ -70,8 +70,9 @@ impl WitnessComponent for FibonacciSquare { add_air_instance::(air_instance, pctx.clone()); } - fn debug(&self, _pctx: Arc>) { - // let trace = FibonacciSquareTrace::from_vec(pctx.get_air_instance_trace(0, 0, 0)); + fn debug(&self, _pctx: Arc>, _sctx: Arc>) { + // let trace = FibonacciSquareTrace::from_vec(_pctx.get_air_instance_trace(0, 0, 0)); + // let fixed = FibonacciSquareFixed::from_vec(_sctx.get_fixed(0, 0)); // let air_values = FibonacciSquareAirValues::from_vec(pctx.get_air_instance_air_values(0, 0, 0)); // let airgroup_values = FibonacciSquareAirGroupValues::from_vec(pctx.get_air_instance_airgroup_values(0, 0, 0)); diff --git a/examples/fibonacci-square/src/module.rs b/examples/fibonacci-square/src/module.rs index a5d65eded..2c8f7c14d 100644 --- a/examples/fibonacci-square/src/module.rs +++ b/examples/fibonacci-square/src/module.rs @@ -47,6 +47,8 @@ impl WitnessComponent for Module { let num_instances = inputs.len().div_ceil(num_rows); for j in 0..num_instances { + let mut x_mods = Vec::new(); + let mut trace = ModuleTrace::new_zeroes(); let inputs_slice = if j < num_instances - 1 { @@ -63,21 +65,24 @@ impl WitnessComponent for Module { trace[i].x = F::from_canonical_u64(x); trace[i].q = F::from_canonical_u64(q); trace[i].x_mod = F::from_canonical_u64(x_mod); - - // Check if x_mod is in the range - self.std_lib.range_check(F::from_canonical_u64(module - x_mod), F::one(), range); - } - - // Trivial range check for the remaining rows - for _ in inputs_slice.len()..trace.num_rows() { - self.std_lib.range_check(F::from_canonical_u64(module), F::one(), range); + x_mods.push(x_mod); } let mut air_values = ModuleAirValues::::new(); air_values.last_segment = F::from_bool(j == num_instances - 1); let air_instance = AirInstance::new_from_trace(FromTrace::new(&mut trace).with_air_values(&mut air_values)); - add_air_instance::(air_instance, pctx.clone()); + let is_mine = add_air_instance::(air_instance, pctx.clone()); + if is_mine { + for x_mod in x_mods.iter() { + self.std_lib.range_check(F::from_canonical_u64(module - x_mod), F::one(), range); + } + + // Trivial range check for the remaining rows + for _ in inputs_slice.len()..trace.num_rows() { + self.std_lib.range_check(F::from_canonical_u64(module), F::one(), range); + } + } } } } diff --git a/examples/fibonacci-square/src/pil_helpers/traces.rs b/examples/fibonacci-square/src/pil_helpers/traces.rs index 83a6d2bc3..75288f261 100644 --- a/examples/fibonacci-square/src/pil_helpers/traces.rs +++ b/examples/fibonacci-square/src/pil_helpers/traces.rs @@ -1,11 +1,16 @@ // WARNING: This file has been autogenerated from the PILOUT file. // Manual modifications are not recommended and may be overwritten. +#![allow(clippy::all)] +#![allow(non_snake_case)] + use proofman_common as common; pub use proofman_macros::trace; pub use proofman_macros::values; use std::fmt; +use rayon::prelude::*; + #[allow(dead_code)] type FieldExtension = [F; 3]; @@ -27,21 +32,41 @@ pub const U_8_AIR_AIR_IDS: &[usize] = &[2]; //PUBLICS use serde::Deserialize; use serde::Serialize; -#[derive(Default, Debug, Serialize, Deserialize)] +use serde_arrays; + + +fn default_array_rom_root() -> [u64; 4] { + [0; 4] +} + + +#[derive(Debug, Serialize, Deserialize)] pub struct BuildPublics { - #[serde(default)] + #[serde(default)] pub module: u64, - #[serde(default)] + #[serde(default)] pub in1: u64, - #[serde(default)] + #[serde(default)] pub in2: u64, - #[serde(default)] + #[serde(default)] pub out: u64, - #[serde(default)] + #[serde(default = "default_array_rom_root", with = "serde_arrays")] pub rom_root: [u64; 4], } +impl Default for BuildPublics { + fn default() -> Self { + Self { + module: 0, + in1: 0, + in2: 0, + out: 0, + rom_root: [0; 4], + } + } +} + values!(BuildPublicValues { module: F, in1: F, in2: F, out: F, rom_root: [F; 4], }); @@ -50,13 +75,25 @@ values!(BuildProofValues { value1: F, value2: F, }); +trace!(FibonacciSquareFixed { + L1: F, __L1__: F, +}, 0, 0, 262144 ); + trace!(FibonacciSquareTrace { a: F, b: F, }, 0, 0, 262144 ); +trace!(ModuleFixed { + SEGMENT_LN: F, __L1__: F, +}, 0, 1, 65536 ); + trace!(ModuleTrace { x: F, q: F, x_mod: F, -}, 0, 1, 16384 ); +}, 0, 1, 65536 ); + +trace!(U8AirFixed { + U8: F, __L1__: F, +}, 0, 2, 256 ); trace!(U8AirTrace { mul: F, diff --git a/hints/src/global_hints.rs b/hints/src/global_hints.rs index 38c4978f2..1c2488b3f 100644 --- a/hints/src/global_hints.rs +++ b/hints/src/global_hints.rs @@ -24,7 +24,7 @@ pub fn aggregate_airgroupvals(pctx: Arc>) -> Vec> airgroupvalues.push(values); } - for (_, air_instance) in pctx.air_instance_repo.air_instances.write().unwrap().iter() { + for (_, air_instance) in pctx.air_instance_repo.air_instances.read().unwrap().iter() { for (idx, agg_type) in pctx.global_info.agg_types[air_instance.airgroup_id].iter().enumerate() { let mut acc = ExtensionField { value: [ @@ -72,7 +72,7 @@ pub fn aggregate_airgroupvals(pctx: Arc>) -> Vec> fn get_global_hint_f( pctx: Option>>, - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, print_expression: bool, @@ -132,7 +132,7 @@ fn get_global_hint_f( hint_field_values } pub fn get_hint_field_constant_gc( - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, print_expression: bool, @@ -151,7 +151,7 @@ pub fn get_hint_field_constant_gc( } pub fn get_hint_field_gc_constant_a( - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, print_expression: bool, @@ -174,7 +174,7 @@ pub fn get_hint_field_gc_constant_a( } pub fn get_hint_field_constant_gc_m( - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, print_expression: bool, @@ -203,7 +203,7 @@ pub fn get_hint_field_constant_gc_m( pub fn get_hint_field_gc( pctx: Arc>, - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, print_expression: bool, @@ -223,7 +223,7 @@ pub fn get_hint_field_gc( pub fn get_hint_field_gc_a( pctx: Arc>, - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, print_expression: bool, @@ -247,7 +247,7 @@ pub fn get_hint_field_gc_a( pub fn get_hint_field_gc_m( pctx: Arc>, - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, print_expression: bool, @@ -276,7 +276,7 @@ pub fn get_hint_field_gc_m( pub fn set_hint_field_gc( pctx: Arc>, - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, value: HintFieldOutput, diff --git a/hints/src/hints.rs b/hints/src/hints.rs index 396ee6441..6d1701e30 100644 --- a/hints/src/hints.rs +++ b/hints/src/hints.rs @@ -750,7 +750,7 @@ pub fn get_hint_ids_by_name(p_expressions_bin: *mut std::os::raw::c_void, name: #[allow(clippy::too_many_arguments)] pub fn mul_hint_fields( - sctx: &SetupCtx, + sctx: &SetupCtx, pctx: &ProofCtx, air_instance: &mut AirInstance, hint_id: usize, @@ -792,7 +792,7 @@ pub fn mul_hint_fields( #[allow(clippy::too_many_arguments)] pub fn acc_hint_field( - sctx: &SetupCtx, + sctx: &SetupCtx, pctx: &ProofCtx, air_instance: &mut AirInstance, hint_id: usize, @@ -837,7 +837,7 @@ pub fn acc_hint_field( #[allow(clippy::too_many_arguments)] pub fn acc_mul_hint_fields( - sctx: &SetupCtx, + sctx: &SetupCtx, pctx: &ProofCtx, air_instance: &mut AirInstance, hint_id: usize, @@ -888,7 +888,7 @@ pub fn acc_mul_hint_fields( #[allow(clippy::too_many_arguments)] pub fn update_airgroupvalue( - sctx: &SetupCtx, + sctx: &SetupCtx, pctx: &ProofCtx, air_instance: &mut AirInstance, hint_id: usize, @@ -932,11 +932,11 @@ pub fn update_airgroupvalue( #[allow(clippy::too_many_arguments)] fn get_hint_f( - sctx: &SetupCtx, + sctx: &SetupCtx, pctx: Option<&ProofCtx>, airgroup_id: usize, air_id: usize, - air_instance: Option<&mut AirInstance>, + air_instance: Option<&AirInstance>, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, @@ -999,9 +999,9 @@ fn get_hint_f( hint_field_values } pub fn get_hint_field( - sctx: &SetupCtx, + sctx: &SetupCtx, pctx: &ProofCtx, - air_instance: &mut AirInstance, + air_instance: &AirInstance, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, @@ -1029,9 +1029,9 @@ pub fn get_hint_field( } pub fn get_hint_field_a( - sctx: &SetupCtx, + sctx: &SetupCtx, pctx: &ProofCtx, - air_instance: &mut AirInstance, + air_instance: &AirInstance, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, @@ -1063,9 +1063,9 @@ pub fn get_hint_field_a( } pub fn get_hint_field_m( - sctx: &SetupCtx, + sctx: &SetupCtx, pctx: &ProofCtx, - air_instance: &mut AirInstance, + air_instance: &AirInstance, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, @@ -1102,7 +1102,7 @@ pub fn get_hint_field_m( } pub fn get_hint_field_constant( - sctx: &SetupCtx, + sctx: &SetupCtx, airgroup_id: usize, air_id: usize, hint_id: usize, @@ -1126,7 +1126,7 @@ pub fn get_hint_field_constant( } pub fn get_hint_field_constant_a( - sctx: &SetupCtx, + sctx: &SetupCtx, airgroup_id: usize, air_id: usize, hint_id: usize, @@ -1154,7 +1154,7 @@ pub fn get_hint_field_constant_a( } pub fn get_hint_field_constant_m( - sctx: &SetupCtx, + sctx: &SetupCtx, airgroup_id: usize, air_id: usize, hint_id: usize, @@ -1187,7 +1187,7 @@ pub fn get_hint_field_constant_m( } pub fn set_hint_field( - sctx: &SetupCtx, + sctx: &SetupCtx, air_instance: &mut AirInstance, hint_id: u64, hint_field_name: &str, @@ -1221,7 +1221,7 @@ pub fn set_hint_field( } pub fn set_hint_field_val( - sctx: &SetupCtx, + sctx: &SetupCtx, air_instance: &mut AirInstance, hint_id: u64, hint_field_name: &str, @@ -1263,7 +1263,7 @@ pub fn set_hint_field_val( set_hint_field_c((&setup.p_setup).into(), (&steps_params).into(), values_ptr, hint_id, hint_field_name); } -pub fn print_row(sctx: &SetupCtx, air_instance: &AirInstance, stage: u64, row: u64) { +pub fn print_row(sctx: &SetupCtx, air_instance: &AirInstance, stage: u64, row: u64) { let setup = sctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let buffer = match stage == 1 { diff --git a/macros/Cargo.toml b/macros/Cargo.toml index dbc57525f..6d7f3f047 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -10,3 +10,8 @@ proc-macro = true syn = { version = "2", features = ["full"] } quote = "1" proc-macro2 = "1" +rayon = "1.10" + +[features] +default = [] +debug = [] diff --git a/macros/src/lib.rs b/macros/src/lib.rs index e8e9c4013..eea75c942 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -55,6 +55,8 @@ fn trace_impl(input: TokenStream2) -> Result { // Generate trace struct let trace_struct = quote! { + use rayon::prelude::*; + pub struct #trace_struct_name<#generics> { pub buffer: Vec<#row_struct_name<#generics>>, pub num_rows: usize, @@ -64,7 +66,7 @@ fn trace_impl(input: TokenStream2) -> Result { pub commit_id: Option, } - impl<#generics: Default + Clone + Copy> #trace_struct_name<#generics> { + impl<#generics: Default + Clone + Copy + Send> #trace_struct_name<#generics> { pub const NUM_ROWS: usize = #num_rows; pub const AIRGROUP_ID: usize = #airgroup_id; pub const AIR_ID: usize = #air_id; @@ -77,12 +79,23 @@ fn trace_impl(input: TokenStream2) -> Result { assert!(num_rows >= 2); assert!(num_rows & (num_rows - 1) == 0); - let mut buff_uninit: Vec>> = Vec::with_capacity(num_rows); - - unsafe { - buff_uninit.set_len(num_rows); - } - let buffer: Vec<#row_struct_name::<#generics>> = unsafe { std::mem::transmute(buff_uninit) }; + let buffer: Vec<#row_struct_name::<#generics>> = if cfg!(feature = "debug") { + let mut buffer_u64 = vec![u64::MAX - 1; num_rows * #row_struct_name::<#generics>::ROW_SIZE]; + + // Convert safely by properly managing size & alignment + let ptr = buffer_u64.as_mut_ptr(); + let len = buffer_u64.len() / #row_struct_name::<#generics>::ROW_SIZE; + let cap = buffer_u64.capacity() / #row_struct_name::<#generics>::ROW_SIZE; + std::mem::forget(buffer_u64); + + unsafe { Vec::from_raw_parts(ptr as *mut #row_struct_name<#generics>, len, cap) } + } else { + let mut buff_uninit: Vec>> = Vec::with_capacity(num_rows); + unsafe { + buff_uninit.set_len(num_rows); + } + unsafe { std::mem::transmute(buff_uninit) } + }; #trace_struct_name { buffer, @@ -99,7 +112,8 @@ fn trace_impl(input: TokenStream2) -> Result { assert!(num_rows >= 2); assert!(num_rows & (num_rows - 1) == 0); - let buffer = vec![#row_struct_name::<#generics>::default(); num_rows]; + let buffer: Vec<#row_struct_name<#generics>> = vec![#row_struct_name::<#generics>::default(); num_rows]; + #trace_struct_name { buffer, @@ -126,6 +140,47 @@ fn trace_impl(input: TokenStream2) -> Result { } } + pub fn from_slice( + slice: &[#generics], + ) -> Self { + let num_rows = Self::NUM_ROWS; + + unsafe { + // Create a mutable slice from the raw pointer + let buffer: &mut [#row_struct_name<#generics>] = std::slice::from_raw_parts_mut( + slice.as_ptr() as *mut #row_struct_name<#generics>, + num_rows + ); + + // Convert the slice into a Vec without taking ownership (caller still owns the memory) + let buffer_vec = buffer.to_vec(); // This creates a new Vec, without modifying the original memory + + Self { + buffer: buffer_vec, + num_rows, + row_size: #row_struct_name::<#generics>::ROW_SIZE, + airgroup_id: Self::AIRGROUP_ID, + air_id: Self::AIR_ID, + commit_id: #commit_id, + } + } + } + + /// Returns parallel mutable iterators to access the buffer. + /// + /// # Arguments + /// * `n` - The number of segments to divide the buffer into. Must be a power of two and <= `NUM_ROWS`. + /// + /// # Panics + /// Panics if `n` is not a power of two or if `n > NUM_ROWS`. + pub fn par_iter_mut_chunks(&mut self, n: usize) -> impl rayon::iter::IndexedParallelIterator]> { + assert!(n > 0 && (n & (n - 1)) == 0, "n must be a power of two"); + assert!(n <= self.num_rows, "n must be less than or equal to NUM_ROWS"); + let chunk_size = self.num_rows / n; + assert!(chunk_size > 0, "Chunk size must be greater than zero"); + self.buffer.par_chunks_mut(chunk_size) + } + pub fn num_rows(&self) -> usize { self.num_rows } @@ -318,7 +373,7 @@ fn values_impl(input: TokenStream2) -> Result { let row_struct = quote! { #[repr(C)] - #[derive(Debug, Clone, Copy, Default)] + #[derive(Debug, Clone, Copy)] pub struct #row_struct_name<#generics> { #(#field_definitions)* } diff --git a/pil2-components/lib/std/pil/std_connection.pil b/pil2-components/lib/std/pil/std_connection.pil index bfa8dc5b2..2238a1177 100644 --- a/pil2-components/lib/std/pil/std_connection.pil +++ b/pil2-components/lib/std/pil/std_connection.pil @@ -108,7 +108,7 @@ function connection_init(const int opid, const expr cols[], int default_frame_si int cols_num = 0; expr map_cols[cols_count]; - col fixed CONN[cols_count]; + col fixed `CONN_${opid}`[cols_count]; // 1st dimension: The number of connections can be as large as one wants. Until we have dynamic arrays, we fix the size to cols·rows // 2nd dimension: @@ -158,7 +158,7 @@ function connection_init(const int opid, const expr cols[], int default_frame_si // Initialize polynomial CONN[0](X) as X and CONN[i](X) as k_i·X, for i >= 0 for (int j = 0; j < N; j++) { - connid.CONN[i][j] = k * conn.ID[j]; + connid.`CONN_${opid}`[i][j] = k * conn.ID[j]; } } @@ -447,15 +447,15 @@ function connection_connect(const int opid) { const int col2_id_clust = conn_cluster[j+1][0]; const int row2_id_clust = conn_cluster[j+1][1]; - const int tmp = CONN[col1_id_clust][row1_id_clust]; - CONN[col1_id_clust][row1_id_clust] = CONN[col2_id_clust][row2_id_clust]; - CONN[col2_id_clust][row2_id_clust] = tmp; + const int tmp = air.std.connect.`id${opid}`.`CONN_${opid}`[col1_id_clust][row1_id_clust]; + air.std.connect.`id${opid}`.`CONN_${opid}`[col1_id_clust][row1_id_clust] = air.std.connect.`id${opid}`.`CONN_${opid}`[col2_id_clust][row2_id_clust]; + air.std.connect.`id${opid}`.`CONN_${opid}`[col2_id_clust][row2_id_clust] = tmp; } } // Send to the bus the permuted columns for (int i = 0; i < cols_num; i++) { - permutation_proves(opid, [map_cols[i], CONN[i]], bus_type: bus_type, name: PIOP_NAME_CONNECTION); + permutation_proves(opid, [map_cols[i], air.std.connect.`id${opid}`.`CONN_${opid}`[i]], bus_type: bus_type, name: PIOP_NAME_CONNECTION); } // Mark the connection as closed diff --git a/pil2-components/lib/std/pil/std_range_check.pil b/pil2-components/lib/std/pil/std_range_check.pil index 45c5b4520..f1ae472b5 100644 --- a/pil2-components/lib/std/pil/std_range_check.pil +++ b/pil2-components/lib/std/pil/std_range_check.pil @@ -111,6 +111,7 @@ airtemplate SpecifiedRanges(const int N, const int opids[], const int opids_coun @specified_ranges{num_rows: N}; col witness mul[opids_count]; + col fixed RANGE[opids_count]; for (int j = 0; j < opids_count; j++) { int opid = opids[j]; @@ -124,8 +125,8 @@ airtemplate SpecifiedRanges(const int N, const int opids[], const int opids_coun error(`The range [min,max]=[${min},${max}] is too big, the maximum range length is ${N}`); } - col fixed RANGE = [min..max-1,max...]; - lookup_proves(opid, [RANGE], mul[j], PIOP_NAME_RANGE_CHECK); + RANGE[j] = [min..max-1,max...]; + lookup_proves(opid, [RANGE[j]], mul[j], PIOP_NAME_RANGE_CHECK); } } diff --git a/pil2-components/lib/std/rs/Cargo.toml b/pil2-components/lib/std/rs/Cargo.toml index aa58e2cd4..07a7a824b 100644 --- a/pil2-components/lib/std/rs/Cargo.toml +++ b/pil2-components/lib/std/rs/Cargo.toml @@ -14,3 +14,4 @@ p3-goldilocks.workspace = true p3-field.workspace = true rayon.workspace = true witness.workspace = true +colored.workspace = true diff --git a/pil2-components/lib/std/rs/src/common.rs b/pil2-components/lib/std/rs/src/common.rs index 7d7884fc0..c475ee0c4 100644 --- a/pil2-components/lib/std/rs/src/common.rs +++ b/pil2-components/lib/std/rs/src/common.rs @@ -2,32 +2,25 @@ use std::sync::Arc; use p3_field::PrimeField; use num_traits::ToPrimitive; -use proofman_common::{AirInstance, ProofCtx, SetupCtx}; +use proofman_common::{ProofCtx, SetupCtx}; use proofman_hints::{ get_hint_field_constant_gc, get_hint_field_constant, get_hint_field_constant_a, HintFieldOptions, HintFieldOutput, HintFieldValue, }; -pub trait AirComponent { +pub trait AirComponent { const MY_NAME: &'static str; - fn new(pctx: Arc>, sctx: Arc, airgroup_id: Option, air_id: Option) - -> Arc; - - fn debug_mode( - &self, - _pctx: &ProofCtx, - _sctx: &SetupCtx, - _air_instance: &mut AirInstance, - _air_instance_id: usize, - _num_rows: usize, - _debug_data_hints: Vec, - ) { - } + fn new( + pctx: Arc>, + sctx: Arc>, + airgroup_id: Option, + air_id: Option, + ) -> Arc; } // Helper to extract hint fields -pub fn get_global_hint_field_constant_as(sctx: Arc, hint_id: u64, field_name: &str) -> T +pub fn get_global_hint_field_constant_as(sctx: Arc>, hint_id: u64, field_name: &str) -> T where T: TryFrom, T::Error: std::fmt::Debug, @@ -48,7 +41,7 @@ where } pub fn get_hint_field_constant_as_field( - sctx: &SetupCtx, + sctx: &SetupCtx, airgroup_id: usize, air_id: usize, hint_id: usize, @@ -62,7 +55,7 @@ pub fn get_hint_field_constant_as_field( } pub fn get_hint_field_constant_a_as_string( - sctx: &SetupCtx, + sctx: &SetupCtx, airgroup_id: usize, air_id: usize, hint_id: usize, @@ -84,7 +77,7 @@ pub fn get_hint_field_constant_a_as_string( } pub fn get_hint_field_constant_as_string( - sctx: &SetupCtx, + sctx: &SetupCtx, airgroup_id: usize, air_id: usize, hint_id: usize, diff --git a/pil2-components/lib/std/rs/src/debug.rs b/pil2-components/lib/std/rs/src/debug.rs index ad13d6114..9a3aff695 100644 --- a/pil2-components/lib/std/rs/src/debug.rs +++ b/pil2-components/lib/std/rs/src/debug.rs @@ -1,16 +1,23 @@ use std::{ collections::HashMap, fs::{self, File}, + hash::{DefaultHasher, Hasher, Hash}, io::{self, Write}, path::{Path, PathBuf}, - sync::Mutex, }; use p3_field::PrimeField; use proofman_common::ProofCtx; use proofman_hints::{format_vec, HintFieldOutput}; -pub type DebugData = Mutex>, BusValue>>>; // opid -> val -> BusValue +use num_bigint::BigUint; +use num_traits::Zero; + +use colored::*; + +pub type DebugData = HashMap>, BusValue>>; // opid -> val -> BusValue + +pub type DebugDataFast = HashMap; // opid -> sharedDataFast #[derive(Debug)] pub struct BusValue { @@ -25,6 +32,14 @@ struct SharedData { num_assumes: F, } +#[derive(Clone, Debug)] +pub struct SharedDataFast { + pub num_proves: BigUint, + pub num_assumes: BigUint, + pub num_proves_global: Vec, + pub num_assumes_global: Vec, +} + type AirGroupMap = HashMap; type AirMap = HashMap; @@ -43,9 +58,63 @@ struct InstanceData { row_assumes: Vec, } +#[allow(clippy::too_many_arguments)] +pub fn update_debug_data_fast( + debug_data_fast: &mut DebugDataFast, + opid: F, + val: Vec>, + proves: bool, + times: F, + is_global: bool, +) { + let bus_opid_times = debug_data_fast.entry(opid).or_insert_with(|| SharedDataFast { + num_assumes_global: Vec::new(), + num_proves_global: Vec::new(), + num_proves: BigUint::zero(), + num_assumes: BigUint::zero(), + }); + + let mut values = Vec::new(); + for value in val.iter() { + match value { + HintFieldOutput::Field(f) => values.push(*f), + HintFieldOutput::FieldExtended(ef) => { + values.push(ef.value[0]); + values.push(ef.value[1]); + values.push(ef.value[2]); + } + } + } + + let mut hasher = DefaultHasher::new(); + values.hash(&mut hasher); + + let hash_value = BigUint::from(hasher.finish()); + + if is_global { + if proves { + // Check if bus op id times num proves global contains value + if bus_opid_times.num_proves_global.contains(&hash_value) { + return; + } + bus_opid_times.num_proves_global.push(hash_value * times.as_canonical_biguint()); + } else { + if bus_opid_times.num_assumes_global.contains(&hash_value) { + return; + } + bus_opid_times.num_assumes_global.push(hash_value); + } + } else if proves { + bus_opid_times.num_proves += hash_value * times.as_canonical_biguint(); + } else { + assert!(times.is_one(), "The selector value is invalid: expected 1, but received {:?}.", times); + bus_opid_times.num_assumes += hash_value; + } +} + #[allow(clippy::too_many_arguments)] pub fn update_debug_data( - debug_data: &DebugData, + debug_data: &mut DebugData, name_piop: &str, name_expr: &[String], opid: F, @@ -58,9 +127,7 @@ pub fn update_debug_data( times: F, is_global: bool, ) { - let mut bus = debug_data.lock().expect("Bus values missing"); - - let bus_opid = bus.entry(opid).or_default(); + let bus_opid = debug_data.entry(opid).or_default(); let bus_val = bus_opid.entry(val).or_insert_with(|| BusValue { shared_data: SharedData { direct_was_called: false, num_proves: F::zero(), num_assumes: F::zero() }, @@ -99,18 +166,77 @@ pub fn update_debug_data( } } +pub fn check_invalid_opids( + _pctx: &ProofCtx, + name: &str, + debugs_data_fasts: &mut [DebugDataFast], +) -> Vec { + let mut debug_data_fast = HashMap::new(); + + let mut global_assumes = Vec::new(); + let mut global_proves = Vec::new(); + for map in debugs_data_fasts { + for (opid, bus) in map.iter() { + if debug_data_fast.contains_key(opid) { + let bus_fast: &mut SharedDataFast = debug_data_fast.get_mut(opid).unwrap(); + for assume_global in bus.num_assumes_global.iter() { + if global_assumes.contains(assume_global) { + continue; + } + global_assumes.push(assume_global.clone()); + bus_fast.num_assumes += assume_global; + } + for prove_global in bus.num_proves_global.iter() { + if global_proves.contains(prove_global) { + continue; + } + global_proves.push(prove_global.clone()); + bus_fast.num_proves += prove_global; + } + + bus_fast.num_proves += bus.num_proves.clone(); + bus_fast.num_assumes += bus.num_assumes.clone(); + } else { + debug_data_fast.insert(*opid, bus.clone()); + } + } + } + + // TODO: SINCRONIZATION IN DISTRIBUTED MODE + + let mut invalid_opids = Vec::new(); + + // Check if there are any invalid opids + + for (opid, bus) in debug_data_fast.iter_mut() { + if bus.num_proves != bus.num_assumes { + invalid_opids.push(*opid); + } + } + + if !invalid_opids.is_empty() { + log::error!( + "{}: ··· {}", + name, + format!("\u{2717} The following opids does not match {:?}", invalid_opids).bright_red().bold() + ); + } else { + log::info!("{}: ··· {}", name, "\u{2713} All bus values match.".bright_green().bold()); + } + + invalid_opids +} pub fn print_debug_info( pctx: &ProofCtx, name: &str, max_values_to_print: usize, print_to_file: bool, - debug_data: &DebugData, + debug_data: &mut DebugData, ) { let mut file_path = PathBuf::new(); let mut output: Box = Box::new(io::stdout()); let mut there_are_errors = false; - let mut bus_vals = debug_data.lock().expect("Bus values missing"); - for (opid, bus) in bus_vals.iter_mut() { + for (opid, bus) in debug_data.iter_mut() { if bus.iter().any(|(_, v)| v.shared_data.num_proves != v.shared_data.num_assumes) { if !there_are_errors { // Print to a file if requested @@ -204,6 +330,10 @@ pub fn print_debug_info( } } + if !there_are_errors { + log::info!("{}: ··· {}", name, "\u{2713} All bus values match.".bright_green().bold()); + } + fn print_diffs( pctx: &ProofCtx, val: &[HintFieldOutput], diff --git a/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs b/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs index f16979a4b..8a24700e2 100644 --- a/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs +++ b/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs @@ -33,7 +33,7 @@ impl AirComponent for SpecifiedRanges { fn new( _pctx: Arc>, - _sctx: Arc, + _sctx: Arc>, airgroup_id: Option, air_id: Option, ) -> Arc { @@ -65,7 +65,7 @@ impl SpecifiedRanges { } } - pub fn drain_inputs(&self, pctx: Arc>, sctx: Arc) { + pub fn drain_inputs(&self, pctx: Arc>, sctx: Arc>) { let mut inputs = self.inputs.lock().unwrap(); let drained_inputs = inputs.drain(..).collect(); @@ -85,13 +85,8 @@ impl SpecifiedRanges { }) .collect::>>(); - let (instance_found, global_idx) = pctx.dctx_find_instance(self.airgroup_id, self.air_id); - - let (is_mine, global_idx) = if instance_found { - (pctx.dctx_is_my_instance(global_idx), global_idx) - } else { - pctx.dctx_add_instance(self.airgroup_id, self.air_id, pctx.get_weight(self.airgroup_id, self.air_id)) - }; + let (_, global_idx) = pctx.dctx_find_instance(self.airgroup_id, self.air_id); + let is_mine = pctx.dctx_is_my_instance(global_idx); pctx.dctx_distribute_multiplicities(&mut multiplicities, global_idx); @@ -159,7 +154,7 @@ impl SpecifiedRanges { } impl WitnessComponent for SpecifiedRanges { - fn start_proof(&self, pctx: Arc>, sctx: Arc) { + fn start_proof(&self, pctx: Arc>, sctx: Arc>) { // Obtain info from the mul hints let setup = sctx.get_setup(self.airgroup_id, self.air_id); let specified_hints = get_hint_ids_by_name(setup.p_setup.p_expressions_bin, "specified_ranges"); @@ -273,13 +268,13 @@ impl WitnessComponent for SpecifiedRanges { let buffer = create_buffer_fast(buffer_size as usize); // Add a new air instance. Since Specified Ranges is a table, only this air instance is needed - let mut air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); + let air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); let mut mul_columns_guard = self.mul_columns.lock().unwrap(); for hint in hints_guard[1..].iter() { mul_columns_guard.push(get_hint_field::( &sctx, &pctx, - &mut air_instance, + &air_instance, hint.to_usize().unwrap(), "reference", HintFieldOptions::dest_with_zeros(), @@ -306,7 +301,19 @@ impl WitnessComponent for SpecifiedRanges { *self.num_rows.lock().unwrap() = num_rows.as_canonical_biguint().to_usize().unwrap(); } - fn calculate_witness(&self, stage: u32, pctx: Arc>, sctx: Arc) { + fn execute(&self, pctx: Arc>) { + let (instance_found, _global_idx) = pctx.dctx_find_instance(self.airgroup_id, self.air_id); + + if !instance_found { + pctx.dctx_add_instance_no_assign( + self.airgroup_id, + self.air_id, + pctx.get_weight(self.airgroup_id, self.air_id), + ); + } + } + + fn calculate_witness(&self, stage: u32, pctx: Arc>, sctx: Arc>) { if stage == 1 { Self::drain_inputs(self, pctx, sctx); } diff --git a/pil2-components/lib/std/rs/src/range_check/std_range_check.rs b/pil2-components/lib/std/rs/src/range_check/std_range_check.rs index 4b3cfaa6a..19f963d0f 100644 --- a/pil2-components/lib/std/rs/src/range_check/std_range_check.rs +++ b/pil2-components/lib/std/rs/src/range_check/std_range_check.rs @@ -49,7 +49,7 @@ pub struct StdRangeCheck { impl StdRangeCheck { const _MY_NAME: &'static str = "STD Range Check"; - pub fn new(pctx: Arc>, sctx: Arc) -> Arc { + pub fn new(pctx: Arc>, sctx: Arc>) -> Arc { // Find which range check related AIRs need to be instantiated let u8air_hint = get_hint_ids_by_name(sctx.get_global_bin(), "u8air"); let u16air_hint = get_hint_ids_by_name(sctx.get_global_bin(), "u16air"); @@ -81,7 +81,7 @@ impl StdRangeCheck { return std_range_check; // Helper function to instantiate AIRs - fn create_air(pctx: Arc>, sctx: Arc, hints: &[u64]) -> Option> + fn create_air(pctx: Arc>, sctx: Arc>, hints: &[u64]) -> Option> where T: AirComponent, { @@ -94,7 +94,7 @@ impl StdRangeCheck { } } - fn register_ranges(&self, sctx: &SetupCtx, airgroup_id: usize, air_id: usize) { + fn register_ranges(&self, sctx: &SetupCtx, airgroup_id: usize, air_id: usize) { let setup = sctx.get_setup(airgroup_id, air_id); // Obtain info from the range hints @@ -252,7 +252,7 @@ impl StdRangeCheck { } } - pub fn drain_inputs(&self, pctx: Arc>, sctx: Arc) { + pub fn drain_inputs(&self, pctx: Arc>, sctx: Arc>) { if let Some(u8air) = self.u8air.as_ref() { u8air.drain_inputs(pctx.clone(), sctx.clone()); } diff --git a/pil2-components/lib/std/rs/src/range_check/u16air.rs b/pil2-components/lib/std/rs/src/range_check/u16air.rs index 8ab33689e..53e7efbc3 100644 --- a/pil2-components/lib/std/rs/src/range_check/u16air.rs +++ b/pil2-components/lib/std/rs/src/range_check/u16air.rs @@ -29,7 +29,7 @@ impl AirComponent for U16Air { fn new( _pctx: Arc>, - _sctx: Arc, + _sctx: Arc>, airgroup_id: Option, air_id: Option, ) -> Arc { @@ -59,7 +59,7 @@ impl U16Air { } } - pub fn drain_inputs(&self, pctx: Arc>, sctx: Arc) { + pub fn drain_inputs(&self, pctx: Arc>, sctx: Arc>) { let mut inputs = self.inputs.lock().unwrap(); let drained_inputs = inputs.drain(..).collect(); @@ -73,13 +73,8 @@ impl U16Air { _ => panic!("Multiplicities must be a column"), }; - let (instance_found, global_idx) = pctx.dctx_find_instance(self.airgroup_id, self.air_id); - - let (is_mine, global_idx) = if instance_found { - (pctx.dctx_is_my_instance(global_idx), global_idx) - } else { - pctx.dctx_add_instance(self.airgroup_id, self.air_id, pctx.get_weight(self.airgroup_id, self.air_id)) - }; + let (_, global_idx) = pctx.dctx_find_instance(self.airgroup_id, self.air_id); + let is_mine = pctx.dctx_is_my_instance(global_idx); pctx.dctx_distribute_multiplicity(&mut multiplicity, global_idx); @@ -128,7 +123,7 @@ impl U16Air { } impl WitnessComponent for U16Air { - fn start_proof(&self, pctx: Arc>, sctx: Arc) { + fn start_proof(&self, pctx: Arc>, sctx: Arc>) { // Obtain info from the mul hints let setup = sctx.get_setup(self.airgroup_id, self.air_id); let u16air_hints = get_hint_ids_by_name(setup.p_setup.p_expressions_bin, "u16air"); @@ -143,19 +138,31 @@ impl WitnessComponent for U16Air { let buffer = create_buffer_fast(buffer_size); // Add a new air instance. Since U16Air is a table, only this air instance is needed - let mut air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); + let air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); *self.mul_column.lock().unwrap() = get_hint_field::( &sctx, &pctx, - &mut air_instance, + &air_instance, u16air_hints[0] as usize, "reference", HintFieldOptions::dest_with_zeros(), ); } - fn calculate_witness(&self, stage: u32, pctx: Arc>, sctx: Arc) { + fn execute(&self, pctx: Arc>) { + let (instance_found, _global_idx) = pctx.dctx_find_instance(self.airgroup_id, self.air_id); + + if !instance_found { + pctx.dctx_add_instance_no_assign( + self.airgroup_id, + self.air_id, + pctx.get_weight(self.airgroup_id, self.air_id), + ); + } + } + + fn calculate_witness(&self, stage: u32, pctx: Arc>, sctx: Arc>) { if stage == 1 { Self::drain_inputs(self, pctx, sctx); } diff --git a/pil2-components/lib/std/rs/src/range_check/u8air.rs b/pil2-components/lib/std/rs/src/range_check/u8air.rs index 8d5a7089d..0d9df4bd4 100644 --- a/pil2-components/lib/std/rs/src/range_check/u8air.rs +++ b/pil2-components/lib/std/rs/src/range_check/u8air.rs @@ -29,7 +29,7 @@ impl AirComponent for U8Air { fn new( _pctx: Arc>, - _sctx: Arc, + _sctx: Arc>, airgroup_id: Option, air_id: Option, ) -> Arc { @@ -59,7 +59,7 @@ impl U8Air { } } - pub fn drain_inputs(&self, pctx: Arc>, sctx: Arc) { + pub fn drain_inputs(&self, pctx: Arc>, sctx: Arc>) { let mut inputs = self.inputs.lock().unwrap(); let drained_inputs = inputs.drain(..).collect(); @@ -73,13 +73,8 @@ impl U8Air { _ => panic!("Multiplicities must be a column"), }; - let (instance_found, global_idx) = pctx.dctx_find_instance(self.airgroup_id, self.air_id); - - let (is_mine, global_idx) = if instance_found { - (pctx.dctx_is_my_instance(global_idx), global_idx) - } else { - pctx.dctx_add_instance(self.airgroup_id, self.air_id, pctx.get_weight(self.airgroup_id, self.air_id)) - }; + let (_, global_idx) = pctx.dctx_find_instance(self.airgroup_id, self.air_id); + let is_mine = pctx.dctx_is_my_instance(global_idx); pctx.dctx_distribute_multiplicity(&mut multiplicity, global_idx); @@ -127,7 +122,7 @@ impl U8Air { } impl WitnessComponent for U8Air { - fn start_proof(&self, pctx: Arc>, sctx: Arc) { + fn start_proof(&self, pctx: Arc>, sctx: Arc>) { // Obtain info from the mul hints let setup = sctx.get_setup(self.airgroup_id, self.air_id); let u8air_hints = get_hint_ids_by_name(setup.p_setup.p_expressions_bin, "u8air"); @@ -142,19 +137,31 @@ impl WitnessComponent for U8Air { let buffer = create_buffer_fast(buffer_size); // Add a new air instance. Since U8Air is a table, only this air instance is needed - let mut air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); + let air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); *self.mul_column.lock().unwrap() = get_hint_field::( &sctx, &pctx, - &mut air_instance, + &air_instance, u8air_hints[0] as usize, "reference", HintFieldOptions::dest_with_zeros(), ); } - fn calculate_witness(&self, stage: u32, pctx: Arc>, sctx: Arc) { + fn execute(&self, pctx: Arc>) { + let (instance_found, _global_idx) = pctx.dctx_find_instance(self.airgroup_id, self.air_id); + + if !instance_found { + pctx.dctx_add_instance_no_assign( + self.airgroup_id, + self.air_id, + pctx.get_weight(self.airgroup_id, self.air_id), + ); + } + } + + fn calculate_witness(&self, stage: u32, pctx: Arc>, sctx: Arc>) { if stage == 1 { Self::drain_inputs(self, pctx, sctx); } diff --git a/pil2-components/lib/std/rs/src/std.rs b/pil2-components/lib/std/rs/src/std.rs index 99b849310..8a4c4674f 100644 --- a/pil2-components/lib/std/rs/src/std.rs +++ b/pil2-components/lib/std/rs/src/std.rs @@ -10,7 +10,7 @@ use crate::{AirComponent, StdProd, StdRangeCheck, RangeCheckAir, StdSum}; pub struct Std { pub pctx: Arc>, - pub sctx: Arc, + pub sctx: Arc>, pub range_check: Arc>, pub std_prod: Arc>, pub std_sum: Arc>, @@ -39,22 +39,22 @@ impl Std { std_sum: Arc>, range_check: Arc>, ) { - wcm.register_component(std_prod.clone()); - wcm.register_component(std_sum.clone()); + wcm.register_component_std(std_prod.clone()); + wcm.register_component_std(std_sum.clone()); if range_check.u8air.is_some() { - wcm.register_component(range_check.u8air.clone().unwrap()); + wcm.register_component_std(range_check.u8air.clone().unwrap()); } if range_check.u16air.is_some() { - wcm.register_component(range_check.u16air.clone().unwrap()); + wcm.register_component_std(range_check.u16air.clone().unwrap()); } if range_check.specified_ranges.is_some() { - wcm.register_component(range_check.specified_ranges.clone().unwrap()); + wcm.register_component_std(range_check.specified_ranges.clone().unwrap()); } - wcm.register_component(range_check.clone()); + wcm.register_component_std(range_check.clone()); } // Gets the range for the range check. diff --git a/pil2-components/lib/std/rs/src/std_prod.rs b/pil2-components/lib/std/rs/src/std_prod.rs index 88c5d19b7..3c27e7c37 100644 --- a/pil2-components/lib/std/rs/src/std_prod.rs +++ b/pil2-components/lib/std/rs/src/std_prod.rs @@ -3,34 +3,37 @@ use std::{ sync::{Arc, Mutex}, }; +use rayon::prelude::*; + use num_traits::ToPrimitive; use p3_field::PrimeField; +use proofman_util::{timer_start_info, timer_stop_and_log_info}; use witness::WitnessComponent; -use proofman_common::{AirInstance, ModeName, ProofCtx, SetupCtx}; +use proofman_common::{AirInstance, ProofCtx, SetupCtx}; use proofman_hints::{ get_hint_field_gc_constant_a, get_hint_field, get_hint_field_a, acc_mul_hint_fields, update_airgroupvalue, get_hint_ids_by_name, HintFieldOptions, HintFieldValue, HintFieldValuesVec, }; use crate::{ - extract_field_element_as_usize, get_global_hint_field_constant_as, get_hint_field_constant_a_as_string, - get_hint_field_constant_as_field, get_hint_field_constant_as_string, get_row_field_value, print_debug_info, - update_debug_data, AirComponent, DebugData, + check_invalid_opids, extract_field_element_as_usize, get_global_hint_field_constant_as, + get_hint_field_constant_a_as_string, get_hint_field_constant_as_field, get_hint_field_constant_as_string, + get_row_field_value, print_debug_info, update_debug_data, update_debug_data_fast, AirComponent, DebugData, + DebugDataFast, SharedDataFast, }; pub struct StdProd { - pctx: Arc>, stage_wc: Option>, - debug_data: Option>, + _phantom: std::marker::PhantomData, } impl AirComponent for StdProd { const MY_NAME: &'static str = "STD Prod"; fn new( - pctx: Arc>, - sctx: Arc, + _pctx: Arc>, + sctx: Arc>, _airgroup_id: Option, _air_id: Option, ) -> Arc { @@ -39,7 +42,6 @@ impl AirComponent for StdProd { // Initialize std_prod with the extracted data Arc::new(Self { - pctx: pctx.clone(), stage_wc: match std_prod_users_id.is_empty() { true => None, false => { @@ -49,24 +51,26 @@ impl AirComponent for StdProd { Some(Mutex::new(stage_wc)) } }, - debug_data: if pctx.options.debug_info.std_mode.name == ModeName::Debug { - Some(Mutex::new(HashMap::new())) - } else { - None - }, + _phantom: std::marker::PhantomData, }) } +} +impl StdProd { + const MY_NAME: &'static str = "STD Prod"; + #[allow(clippy::too_many_arguments)] fn debug_mode( &self, pctx: &ProofCtx, - sctx: &SetupCtx, - air_instance: &mut AirInstance, + sctx: &SetupCtx, + air_instance: &AirInstance, air_instance_id: usize, num_rows: usize, debug_data_hints: Vec, + debug_data: &mut DebugData, + debug_data_fast: &mut DebugDataFast, + fast_mode: bool, ) { - let debug_data = self.debug_data.as_ref().expect("Debug data missing"); let airgroup_id = air_instance.airgroup_id; let air_id = air_instance.air_id; @@ -168,7 +172,9 @@ impl AirComponent for StdProd { &expressions, 0, debug_data, + debug_data_fast, is_global.is_one(), + fast_mode, ); } // Otherwise, update the bus for each row @@ -186,7 +192,9 @@ impl AirComponent for StdProd { &expressions, j, debug_data, + debug_data_fast, false, + fast_mode, ); } } @@ -203,8 +211,10 @@ impl AirComponent for StdProd { sel: &HintFieldValue, expressions: &HintFieldValuesVec, row: usize, - debug_data: &DebugData, + debug_data: &mut DebugData, + debug_data_fast: &mut DebugDataFast, is_global: bool, + fast_mode: bool, ) { let mut sel = get_row_field_value(sel, row, "sel"); if sel.is_zero() { @@ -223,27 +233,31 @@ impl AirComponent for StdProd { _ => panic!("Proves hint must be either 0, 1, or -1"), }; - update_debug_data( - debug_data, - name_piop, - name_expr, - opid, - expressions.get(row), - airgroup_id, - air_id, - air_instance_id, - row, - proves, - sel, - is_global, - ); + if fast_mode { + update_debug_data_fast(debug_data_fast, opid, expressions.get(row), proves, sel, is_global); + } else { + update_debug_data( + debug_data, + name_piop, + name_expr, + opid, + expressions.get(row), + airgroup_id, + air_id, + air_instance_id, + row, + proves, + sel, + is_global, + ); + } } } } } impl WitnessComponent for StdProd { - fn calculate_witness(&self, stage: u32, pctx: Arc>, sctx: Arc) { + fn calculate_witness(&self, stage: u32, pctx: Arc>, sctx: Arc>) { let stage_wc = self.stage_wc.as_ref(); if stage_wc.is_none() { return; @@ -283,23 +297,7 @@ impl WitnessComponent for StdProd { log::debug!("{}: ··· Computing witness for AIR '{}' at stage {}", Self::MY_NAME, air_name, stage); - let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; - let gprod_hints = get_hint_ids_by_name(p_expressions_bin, "gprod_col"); - let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gprod_debug_data"); - - // Debugging, if enabled - if pctx.options.debug_info.std_mode.name == ModeName::Debug { - let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); - self.debug_mode( - &pctx, - &sctx, - air_instance, - air_instance_id, - num_rows, - debug_data_hints.clone(), - ); - } // We know that at most one product hint exists let gprod_hint = if gprod_hints.len() > 1 { @@ -344,15 +342,140 @@ impl WitnessComponent for StdProd { } } - fn end_proof(&self) { - // Print debug info if in debug mode - if self.pctx.options.debug_info.std_mode.name == ModeName::Debug { - let pctx = &self.pctx; - let name = Self::MY_NAME; - let max_values_to_print = pctx.options.debug_info.std_mode.n_vals; - let print_to_file = pctx.options.debug_info.std_mode.print_to_file; - let debug_data = self.debug_data.as_ref().expect("Debug data missing"); - print_debug_info(pctx, name, max_values_to_print, print_to_file, debug_data); + fn debug(&self, pctx: Arc>, sctx: Arc>) { + timer_start_info!(DEBUG_MODE_PROD); + + let std_prod_users_vec = get_hint_ids_by_name(sctx.get_global_bin(), "std_prod_users"); + + if !std_prod_users_vec.is_empty() { + let std_prod_users = std_prod_users_vec[0]; + + let num_users = get_global_hint_field_constant_as::(sctx.clone(), std_prod_users, "num_users"); + let airgroup_ids = get_hint_field_gc_constant_a::(sctx.clone(), std_prod_users, "airgroup_ids", false); + let air_ids = get_hint_field_gc_constant_a::(sctx.clone(), std_prod_users, "air_ids", false); + + let fast_mode = pctx.options.debug_info.std_mode.fast_mode; + + let mut debug_data = HashMap::new(); + + let mut debugs_data_fasts: Vec> = Vec::new(); + + let mut global_instance_ids = Vec::new(); + + for i in 0..num_users { + let airgroup_id = extract_field_element_as_usize(&airgroup_ids.values[i], "airgroup_id"); + let air_id = extract_field_element_as_usize(&air_ids.values[i], "air_id"); + + // Get all air instances ids for this airgroup and air_id + let global_ids = pctx.air_instance_repo.find_air_instances(airgroup_id, air_id); + + for global_instance_id in global_ids { + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&global_instance_id).unwrap(); + + if air_instance.prover_initialized { + global_instance_ids.push(global_instance_id); + } + } + } + + if fast_mode { + // Process each sum check user + debugs_data_fasts = global_instance_ids + .par_iter() + .map(|&global_instance_id| { + let mut local_debug_data_fast = HashMap::new(); + + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + + let air_instance = air_instances.get(&global_instance_id).unwrap(); + let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); + let air_name = &pctx.global_info.airs[air_instance.airgroup_id][air_instance.air_id].name; + + log::debug!( + "{}: ··· Checking debug mode fast for instance_id {} of {}", + Self::MY_NAME, + air_instance_id, + air_name + ); + + // Get the air associated with the air_instance + let airgroup_id = air_instance.airgroup_id; + let air_id = air_instance.air_id; + + let setup = sctx.get_setup(airgroup_id, air_id); + let p_expressions_bin = setup.p_setup.p_expressions_bin; + + let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; + + let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gprod_debug_data"); + + self.debug_mode( + &pctx, + &sctx, + air_instance, + air_instance_id, + num_rows, + debug_data_hints.clone(), + &mut HashMap::new(), + &mut local_debug_data_fast, + true, + ); + + local_debug_data_fast + }) + .collect(); + } else { + // Process each sum check user + for global_instance_id in global_instance_ids { + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&global_instance_id).unwrap(); + let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); + let air_name = &pctx.global_info.airs[air_instance.airgroup_id][air_instance.air_id].name; + + log::debug!( + "{}: ··· Checking debug mode for instance_id {} of {}", + Self::MY_NAME, + air_instance_id, + air_name + ); + + // Get the air associated with the air_instance + let airgroup_id = air_instance.airgroup_id; + let air_id = air_instance.air_id; + + let setup = sctx.get_setup(airgroup_id, air_id); + let p_expressions_bin = setup.p_setup.p_expressions_bin; + + let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; + + let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gprod_debug_data"); + + self.debug_mode( + &pctx, + &sctx, + air_instance, + air_instance_id, + num_rows, + debug_data_hints.clone(), + &mut debug_data, + &mut HashMap::new(), + false, + ); + } + } + + if fast_mode { + check_invalid_opids(&pctx, Self::MY_NAME, &mut debugs_data_fasts); + } else { + let max_values_to_print = pctx.options.debug_info.std_mode.n_vals; + let print_to_file = pctx.options.debug_info.std_mode.print_to_file; + print_debug_info(&pctx, Self::MY_NAME, max_values_to_print, print_to_file, &mut debug_data); + } } + timer_stop_and_log_info!(DEBUG_MODE_PROD); } } diff --git a/pil2-components/lib/std/rs/src/std_sum.rs b/pil2-components/lib/std/rs/src/std_sum.rs index ef7c5344e..50a09e96d 100644 --- a/pil2-components/lib/std/rs/src/std_sum.rs +++ b/pil2-components/lib/std/rs/src/std_sum.rs @@ -3,34 +3,37 @@ use std::{ sync::{Arc, Mutex}, }; +use rayon::prelude::*; + use num_traits::ToPrimitive; use p3_field::PrimeField; +use proofman_util::{timer_start_info, timer_stop_and_log_info}; use witness::WitnessComponent; -use proofman_common::{AirInstance, ProofCtx, SetupCtx, ModeName}; +use proofman_common::{AirInstance, ProofCtx, SetupCtx}; use proofman_hints::{ get_hint_field_gc_constant_a, get_hint_field, get_hint_field_a, acc_mul_hint_fields, update_airgroupvalue, get_hint_ids_by_name, mul_hint_fields, HintFieldOptions, HintFieldOutput, HintFieldValue, HintFieldValuesVec, }; use crate::{ - extract_field_element_as_usize, get_global_hint_field_constant_as, get_hint_field_constant_a_as_string, - get_hint_field_constant_as_field, get_hint_field_constant_as_string, get_row_field_value, print_debug_info, - update_debug_data, AirComponent, DebugData, + check_invalid_opids, extract_field_element_as_usize, get_global_hint_field_constant_as, + get_hint_field_constant_a_as_string, get_hint_field_constant_as_field, get_hint_field_constant_as_string, + get_row_field_value, print_debug_info, update_debug_data, update_debug_data_fast, AirComponent, DebugData, + DebugDataFast, SharedDataFast, }; pub struct StdSum { - pctx: Arc>, stage_wc: Option>, - debug_data: Option>, + _phantom: std::marker::PhantomData, } impl AirComponent for StdSum { const MY_NAME: &'static str = "STD Sum "; fn new( - pctx: Arc>, - sctx: Arc, + _pctx: Arc>, + sctx: Arc>, _airgroup_id: Option, _air_id: Option, ) -> Arc { @@ -39,7 +42,6 @@ impl AirComponent for StdSum { // Initialize std_sum with the extracted data Arc::new(Self { - pctx: pctx.clone(), stage_wc: match std_sum_users_id.is_empty() { true => None, false => { @@ -49,24 +51,25 @@ impl AirComponent for StdSum { Some(Mutex::new(stage_wc)) } }, - debug_data: if pctx.options.debug_info.std_mode.name == ModeName::Debug { - Some(Mutex::new(HashMap::new())) - } else { - None - }, + _phantom: std::marker::PhantomData, }) } +} +impl StdSum { + #[allow(clippy::too_many_arguments)] fn debug_mode( &self, pctx: &ProofCtx, - sctx: &SetupCtx, - air_instance: &mut AirInstance, + sctx: &SetupCtx, + air_instance: &AirInstance, air_instance_id: usize, num_rows: usize, debug_data_hints: Vec, + debug_data: &mut DebugData, + debug_data_fast: &mut DebugDataFast, + fast_mode: bool, ) { - let debug_data = self.debug_data.as_ref().expect("Debug data missing"); let airgroup_id = air_instance.airgroup_id; let air_id = air_instance.air_id; @@ -165,7 +168,9 @@ impl AirComponent for StdSum { &expressions, 0, debug_data, + debug_data_fast, is_global.is_one(), + fast_mode, ); } // Otherwise, update the bus for each row @@ -200,7 +205,9 @@ impl AirComponent for StdSum { &expressions, j, debug_data, + debug_data_fast, false, + fast_mode, ); } } @@ -218,8 +225,10 @@ impl AirComponent for StdSum { mul: &HintFieldValue, expressions: &HintFieldValuesVec, row: usize, - debug_data: &DebugData, + debug_data: &mut DebugData, + debug_data_fast: &mut DebugDataFast, is_global: bool, + fast_mode: bool, ) { let mut mul = get_row_field_value(mul, row, "mul"); if mul.is_zero() { @@ -238,26 +247,30 @@ impl AirComponent for StdSum { _ => panic!("Proves hint must be either 0, 1, or -1"), }; - update_debug_data( - debug_data, - name_piop, - name_expr, - opid, - expressions.get(row), - airgroup_id, - air_id, - instance_id, - row, - proves, - mul, - is_global, - ); + if fast_mode { + update_debug_data_fast(debug_data_fast, opid, expressions.get(row), proves, mul, is_global); + } else { + update_debug_data( + debug_data, + name_piop, + name_expr, + opid, + expressions.get(row), + airgroup_id, + air_id, + instance_id, + row, + proves, + mul, + is_global, + ); + } } } } impl WitnessComponent for StdSum { - fn calculate_witness(&self, stage: u32, pctx: Arc>, sctx: Arc) { + fn calculate_witness(&self, stage: u32, pctx: Arc>, sctx: Arc>) { let stage_wc = self.stage_wc.as_ref(); if stage_wc.is_none() { return; @@ -297,24 +310,8 @@ impl WitnessComponent for StdSum { log::debug!("{}: ··· Computing witness for AIR '{}' at stage {}", Self::MY_NAME, air_name, stage); - let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; - let im_hints = get_hint_ids_by_name(p_expressions_bin, "im_col"); let gsum_hints = get_hint_ids_by_name(p_expressions_bin, "gsum_col"); - let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gsum_debug_data"); - - // Debugging, if enabled - if pctx.options.debug_info.std_mode.name == ModeName::Debug { - let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); - self.debug_mode( - &pctx, - &sctx, - air_instance, - air_instance_id, - num_rows, - debug_data_hints.clone(), - ); - } // Populate the im columns for hint in im_hints { @@ -374,15 +371,139 @@ impl WitnessComponent for StdSum { } } - fn end_proof(&self) { - // Print debug info if in debug mode - if self.pctx.options.debug_info.std_mode.name == ModeName::Debug { - let pctx = &self.pctx; - let name = Self::MY_NAME; - let max_values_to_print = pctx.options.debug_info.std_mode.n_vals; - let print_to_file = pctx.options.debug_info.std_mode.print_to_file; - let debug_data = self.debug_data.as_ref().expect("Debug data missing"); - print_debug_info(pctx, name, max_values_to_print, print_to_file, debug_data); + fn debug(&self, pctx: Arc>, sctx: Arc>) { + timer_start_info!(DEBUG_MODE_SUM); + let std_sum_users_vec = get_hint_ids_by_name(sctx.get_global_bin(), "std_sum_users"); + + if !std_sum_users_vec.is_empty() { + let std_sum_users = std_sum_users_vec[0]; + + let num_users = get_global_hint_field_constant_as::(sctx.clone(), std_sum_users, "num_users"); + let airgroup_ids = get_hint_field_gc_constant_a::(sctx.clone(), std_sum_users, "airgroup_ids", false); + let air_ids = get_hint_field_gc_constant_a::(sctx.clone(), std_sum_users, "air_ids", false); + + let fast_mode = pctx.options.debug_info.std_mode.fast_mode; + + let mut debug_data = HashMap::new(); + + let mut debugs_data_fasts: Vec> = Vec::new(); + + let mut global_instance_ids = Vec::new(); + + for i in 0..num_users { + let airgroup_id = extract_field_element_as_usize(&airgroup_ids.values[i], "airgroup_id"); + let air_id = extract_field_element_as_usize(&air_ids.values[i], "air_id"); + + // Get all air instances ids for this airgroup and air_id + let global_ids = pctx.air_instance_repo.find_air_instances(airgroup_id, air_id); + + for global_instance_id in global_ids { + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&global_instance_id).unwrap(); + + if air_instance.prover_initialized { + global_instance_ids.push(global_instance_id); + } + } + } + + if fast_mode { + // Process each sum check user + debugs_data_fasts = global_instance_ids + .par_iter() + .map(|&global_instance_id| { + let mut local_debug_data_fast = HashMap::new(); + + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + + let air_instance = air_instances.get(&global_instance_id).unwrap(); + let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); + let air_name = &pctx.global_info.airs[air_instance.airgroup_id][air_instance.air_id].name; + + log::debug!( + "{}: ··· Checking debug mode fast for instance_id {} of {}", + Self::MY_NAME, + air_instance_id, + air_name + ); + + // Get the air associated with the air_instance + let airgroup_id = air_instance.airgroup_id; + let air_id = air_instance.air_id; + + let setup = sctx.get_setup(airgroup_id, air_id); + let p_expressions_bin = setup.p_setup.p_expressions_bin; + + let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; + + let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gsum_debug_data"); + + self.debug_mode( + &pctx, + &sctx, + air_instance, + air_instance_id, + num_rows, + debug_data_hints.clone(), + &mut HashMap::new(), + &mut local_debug_data_fast, + true, + ); + + local_debug_data_fast + }) + .collect(); + } else { + // Process each sum check user + for global_instance_id in global_instance_ids { + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&global_instance_id).unwrap(); + let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); + let air_name = &pctx.global_info.airs[air_instance.airgroup_id][air_instance.air_id].name; + + log::debug!( + "{}: ··· Checking debug mode for instance_id {} of {}", + Self::MY_NAME, + air_instance_id, + air_name + ); + + // Get the air associated with the air_instance + let airgroup_id = air_instance.airgroup_id; + let air_id = air_instance.air_id; + + let setup = sctx.get_setup(airgroup_id, air_id); + let p_expressions_bin = setup.p_setup.p_expressions_bin; + + let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; + + let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gsum_debug_data"); + + self.debug_mode( + &pctx, + &sctx, + air_instance, + air_instance_id, + num_rows, + debug_data_hints.clone(), + &mut debug_data, + &mut HashMap::new(), + false, + ); + } + } + + if fast_mode { + check_invalid_opids(&pctx, Self::MY_NAME, &mut debugs_data_fasts); + } else { + let max_values_to_print = pctx.options.debug_info.std_mode.n_vals; + let print_to_file = pctx.options.debug_info.std_mode.print_to_file; + print_debug_info(&pctx, Self::MY_NAME, max_values_to_print, print_to_file, &mut debug_data); + } } + timer_stop_and_log_info!(DEBUG_MODE_SUM); } } diff --git a/pil2-components/test/simple/rs/Cargo.toml b/pil2-components/test/simple/rs/Cargo.toml index a202ebeba..10512bf58 100644 --- a/pil2-components/test/simple/rs/Cargo.toml +++ b/pil2-components/test/simple/rs/Cargo.toml @@ -21,6 +21,7 @@ rand.workspace = true num-bigint.workspace = true p3-goldilocks.workspace = true p3-field.workspace = true +rayon.workspace = true [features] default = [] diff --git a/pil2-components/test/simple/rs/src/pil_helpers/traces.rs b/pil2-components/test/simple/rs/src/pil_helpers/traces.rs index 012c51c22..7d4caa93d 100644 --- a/pil2-components/test/simple/rs/src/pil_helpers/traces.rs +++ b/pil2-components/test/simple/rs/src/pil_helpers/traces.rs @@ -1,5 +1,8 @@ // WARNING: This file has been autogenerated from the PILOUT file. // Manual modifications are not recommended and may be overwritten. +#![allow(clippy::all)] +#![allow(non_snake_case)] + use proofman_common as common; pub use proofman_macros::trace; pub use proofman_macros::values; @@ -22,10 +25,18 @@ pub const SIMPLE_LEFT_AIR_IDS: &[usize] = &[0]; pub const SIMPLE_RIGHT_AIR_IDS: &[usize] = &[1]; +trace!(SimpleLeftFixed { + __L1: F, +}, 0, 0, 4 ); + trace!(SimpleLeftTrace { a: F, b: F, c: F, d: F, e: F, f: F, g: F, h: F, }, 0, 0, 4 ); +trace!(SimpleRightFixed { + __L1: F, +}, 0, 1, 4 ); + trace!(SimpleRightTrace { a: F, b: F, c: F, d: F, mul: F, }, 0, 1, 4 ); diff --git a/pil2-components/test/std/connection/rs/Cargo.toml b/pil2-components/test/std/connection/rs/Cargo.toml index 35c238876..63b30577a 100644 --- a/pil2-components/test/std/connection/rs/Cargo.toml +++ b/pil2-components/test/std/connection/rs/Cargo.toml @@ -20,6 +20,7 @@ rand.workspace = true num-bigint.workspace = true p3-goldilocks.workspace = true p3-field.workspace = true +rayon.workspace = true [build-dependencies] proofman-cli.workspace = true diff --git a/pil2-components/test/std/connection/rs/src/pil_helpers/traces.rs b/pil2-components/test/std/connection/rs/src/pil_helpers/traces.rs index e452be9b7..b8a5f6d4d 100644 --- a/pil2-components/test/std/connection/rs/src/pil_helpers/traces.rs +++ b/pil2-components/test/std/connection/rs/src/pil_helpers/traces.rs @@ -1,5 +1,8 @@ // WARNING: This file has been autogenerated from the PILOUT file. // Manual modifications are not recommended and may be overwritten. +#![allow(clippy::all)] +#![allow(non_snake_case)] + use proofman_common as common; pub use proofman_macros::trace; pub use proofman_macros::values; @@ -24,26 +27,38 @@ pub const CONNECTION_2_AIR_IDS: &[usize] = &[1]; pub const CONNECTION_NEW_AIR_IDS: &[usize] = &[2]; +trace!(Connection1Fixed { + S1: F, S2: F, S3: F, ID: F, __L1__: F, +}, 0, 0, 8 ); + trace!(Connection1Trace { a: F, b: F, c: F, }, 0, 0, 8 ); +trace!(Connection2Fixed { + S1: F, S2: F, S3: F, ID: F, __L1__: F, +}, 0, 1, 16 ); + trace!(Connection2Trace { a: F, b: F, c: F, }, 0, 1, 16 ); +trace!(ConnectionNewFixed { + ID: F, CONN_2: [F; 3], CONN_3: [F; 3], CONN_4: [F; 3], CONN_5: [F; 3], CONN_6: [F; 4], CONN_7: [F; 4], __L1__: F, +}, 0, 2, 16 ); + trace!(ConnectionNewTrace { a: [F; 6], b: [F; 6], c: [F; 6], d: [F; 6], }, 0, 2, 16 ); values!(Connection1AirGroupValues { - gsum_result: FieldExtension, + gprod_result: FieldExtension, }); values!(Connection2AirGroupValues { - gsum_result: FieldExtension, + gprod_result: FieldExtension, }); values!(ConnectionNewAirGroupValues { - gsum_result: FieldExtension, + gprod_result: FieldExtension, }); diff --git a/pil2-components/test/std/diff_buses/rs/Cargo.toml b/pil2-components/test/std/diff_buses/rs/Cargo.toml index ab5ceb2f4..879185a7d 100644 --- a/pil2-components/test/std/diff_buses/rs/Cargo.toml +++ b/pil2-components/test/std/diff_buses/rs/Cargo.toml @@ -21,6 +21,7 @@ rand.workspace = true num-bigint.workspace = true p3-goldilocks.workspace = true p3-field.workspace = true +rayon.workspace = true [build-dependencies] proofman-cli.workspace = true diff --git a/pil2-components/test/std/direct_update/rs/Cargo.toml b/pil2-components/test/std/direct_update/rs/Cargo.toml index 1130e7863..a7f85e47c 100644 --- a/pil2-components/test/std/direct_update/rs/Cargo.toml +++ b/pil2-components/test/std/direct_update/rs/Cargo.toml @@ -22,7 +22,9 @@ num-bigint.workspace = true p3-goldilocks.workspace = true p3-field.workspace = true serde.workspace = true +serde_arrays.workspace = true serde_json.workspace = true +rayon.workspace = true [build-dependencies] proofman-cli.workspace = true diff --git a/pil2-components/test/std/lookup/rs/Cargo.toml b/pil2-components/test/std/lookup/rs/Cargo.toml index ee7be23ae..268ff6bfb 100644 --- a/pil2-components/test/std/lookup/rs/Cargo.toml +++ b/pil2-components/test/std/lookup/rs/Cargo.toml @@ -20,6 +20,7 @@ rand.workspace = true num-bigint.workspace = true p3-goldilocks.workspace = true p3-field.workspace = true +rayon.workspace = true [build-dependencies] proofman-cli.workspace = true diff --git a/pil2-components/test/std/lookup/rs/src/pil_helpers/traces.rs b/pil2-components/test/std/lookup/rs/src/pil_helpers/traces.rs index e51298777..904db913e 100644 --- a/pil2-components/test/std/lookup/rs/src/pil_helpers/traces.rs +++ b/pil2-components/test/std/lookup/rs/src/pil_helpers/traces.rs @@ -1,5 +1,8 @@ // WARNING: This file has been autogenerated from the PILOUT file. // Manual modifications are not recommended and may be overwritten. +#![allow(clippy::all)] +#![allow(non_snake_case)] + use proofman_common as common; pub use proofman_macros::trace; pub use proofman_macros::values; @@ -30,26 +33,50 @@ pub const LOOKUP_2_15_AIR_IDS: &[usize] = &[4]; pub const LOOKUP_3_AIR_IDS: &[usize] = &[5]; +trace!(Lookup0Fixed { + __L1__: F, +}, 0, 0, 1024 ); + trace!(Lookup0Trace { f: [F; 4], t: [F; 4], sel: [F; 2], mul: [F; 2], }, 0, 0, 1024 ); +trace!(Lookup1Fixed { + __L1__: F, +}, 0, 1, 1024 ); + trace!(Lookup1Trace { f: [F; 2], t: F, sel: [F; 2], mul: F, }, 0, 1, 1024 ); +trace!(Lookup2_12Fixed { + __L1__: F, +}, 0, 2, 4096 ); + trace!(Lookup2_12Trace { a1: F, b1: F, a2: F, b2: F, a3: F, b3: F, a4: F, b4: F, c1: F, d1: F, c2: F, d2: F, sel1: F, sel2: F, mul: F, }, 0, 2, 4096 ); +trace!(Lookup2_13Fixed { + __L1__: F, +}, 0, 3, 8192 ); + trace!(Lookup2_13Trace { a1: F, b1: F, a2: F, b2: F, a3: F, b3: F, a4: F, b4: F, c1: F, d1: F, c2: F, d2: F, sel1: F, sel2: F, mul: F, }, 0, 3, 8192 ); +trace!(Lookup2_15Fixed { + __L1__: F, +}, 0, 4, 32768 ); + trace!(Lookup2_15Trace { a1: F, b1: F, a2: F, b2: F, a3: F, b3: F, a4: F, b4: F, c1: F, d1: F, c2: F, d2: F, sel1: F, sel2: F, mul: F, }, 0, 4, 32768 ); +trace!(Lookup3Fixed { + __L1__: F, +}, 0, 5, 16384 ); + trace!(Lookup3Trace { c1: F, d1: F, c2: F, d2: F, mul1: F, mul2: F, }, 0, 5, 16384 ); diff --git a/pil2-components/test/std/permutation/rs/Cargo.toml b/pil2-components/test/std/permutation/rs/Cargo.toml index 14065161f..59d8f9f53 100644 --- a/pil2-components/test/std/permutation/rs/Cargo.toml +++ b/pil2-components/test/std/permutation/rs/Cargo.toml @@ -20,6 +20,7 @@ rand.workspace = true num-bigint.workspace = true p3-goldilocks.workspace = true p3-field.workspace = true +rayon.workspace = true [build-dependencies] proofman-cli.workspace = true diff --git a/pil2-components/test/std/permutation/rs/src/pil_helpers/traces.rs b/pil2-components/test/std/permutation/rs/src/pil_helpers/traces.rs index 2cd6ebc53..f710d62d1 100644 --- a/pil2-components/test/std/permutation/rs/src/pil_helpers/traces.rs +++ b/pil2-components/test/std/permutation/rs/src/pil_helpers/traces.rs @@ -1,5 +1,8 @@ // WARNING: This file has been autogenerated from the PILOUT file. // Manual modifications are not recommended and may be overwritten. +#![allow(clippy::all)] +#![allow(non_snake_case)] + use proofman_common as common; pub use proofman_macros::trace; pub use proofman_macros::values; @@ -26,34 +29,50 @@ pub const PERMUTATION_1_8_AIR_IDS: &[usize] = &[2]; pub const PERMUTATION_2_6_AIR_IDS: &[usize] = &[3]; +trace!(Permutation1_6Fixed { + __L1__: F, +}, 0, 0, 64 ); + trace!(Permutation1_6Trace { a1: F, b1: F, a2: F, b2: F, a3: F, b3: F, a4: F, b4: F, c1: F, d1: F, c2: F, d2: F, sel1: F, sel2: F, sel3: F, }, 0, 0, 64 ); +trace!(Permutation1_7Fixed { + __L1__: F, +}, 0, 1, 128 ); + trace!(Permutation1_7Trace { a1: F, b1: F, a2: F, b2: F, a3: F, b3: F, a4: F, b4: F, c1: F, d1: F, c2: F, d2: F, sel1: F, sel2: F, sel3: F, }, 0, 1, 128 ); +trace!(Permutation1_8Fixed { + __L1__: F, +}, 0, 2, 256 ); + trace!(Permutation1_8Trace { a1: F, b1: F, a2: F, b2: F, a3: F, b3: F, a4: F, b4: F, c1: F, d1: F, c2: F, d2: F, sel1: F, sel2: F, sel3: F, }, 0, 2, 256 ); +trace!(Permutation2_6Fixed { + __L1__: F, +}, 0, 3, 512 ); + trace!(Permutation2_6Trace { c1: F, d1: F, c2: F, d2: F, sel: F, }, 0, 3, 512 ); values!(Permutation1_6AirGroupValues { - gsum_result: FieldExtension, + gsum_result: FieldExtension, gprod_result: FieldExtension, }); values!(Permutation1_7AirGroupValues { - gsum_result: FieldExtension, + gsum_result: FieldExtension, gprod_result: FieldExtension, }); values!(Permutation1_8AirGroupValues { - gsum_result: FieldExtension, + gsum_result: FieldExtension, gprod_result: FieldExtension, }); values!(Permutation2_6AirGroupValues { - gsum_result: FieldExtension, + gsum_result: FieldExtension, gprod_result: FieldExtension, }); diff --git a/pil2-components/test/std/range_check/rs/Cargo.toml b/pil2-components/test/std/range_check/rs/Cargo.toml index 7e64129e2..095891cef 100644 --- a/pil2-components/test/std/range_check/rs/Cargo.toml +++ b/pil2-components/test/std/range_check/rs/Cargo.toml @@ -21,6 +21,7 @@ rand.workspace = true num-bigint.workspace = true p3-goldilocks.workspace = true p3-field.workspace = true +rayon.workspace = true [build-dependencies] proofman-cli.workspace = true diff --git a/pil2-components/test/std/range_check/rs/src/pil_helpers/traces.rs b/pil2-components/test/std/range_check/rs/src/pil_helpers/traces.rs index 565713a64..5a9060ff4 100644 --- a/pil2-components/test/std/range_check/rs/src/pil_helpers/traces.rs +++ b/pil2-components/test/std/range_check/rs/src/pil_helpers/traces.rs @@ -1,5 +1,8 @@ // WARNING: This file has been autogenerated from the PILOUT file. // Manual modifications are not recommended and may be overwritten. +#![allow(clippy::all)] +#![allow(non_snake_case)] + use proofman_common as common; pub use proofman_macros::trace; pub use proofman_macros::values; @@ -62,50 +65,98 @@ pub const U_8_AIR_AIR_IDS: &[usize] = &[0]; pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[0]; +trace!(RangeCheck3Fixed { + __L1__: F, +}, 0, 0, 32 ); + trace!(RangeCheck3Trace { c1: F, c2: F, }, 0, 0, 32 ); +trace!(RangeCheck2Fixed { + __L1__: F, +}, 1, 0, 16 ); + trace!(RangeCheck2Trace { b1: F, b2: F, b3: F, }, 1, 0, 16 ); +trace!(RangeCheck1Fixed { + __L1__: F, +}, 2, 0, 8 ); + trace!(RangeCheck1Trace { a1: F, a2: F, a3: F, a4: F, a5: F, sel1: F, sel2: F, sel3: F, }, 2, 0, 8 ); +trace!(RangeCheck4Fixed { + __L1__: F, +}, 3, 0, 64 ); + trace!(RangeCheck4Trace { a1: F, a2: F, a3: F, a4: F, a5: F, a6: F, a7: F, a8: F, sel1: F, sel2: F, }, 3, 0, 64 ); +trace!(U16AirFixed { + U16: F, __L1__: F, +}, 3, 1, 65536 ); + trace!(U16AirTrace { mul: F, }, 3, 1, 65536 ); +trace!(MultiRangeCheck1Fixed { + __L1__: F, +}, 4, 0, 8 ); + trace!(MultiRangeCheck1Trace { a: [F; 3], sel: [F; 3], range_sel: [F; 3], }, 4, 0, 8 ); +trace!(MultiRangeCheck2Fixed { + __L1__: F, +}, 5, 0, 16 ); + trace!(MultiRangeCheck2Trace { a: [F; 2], sel: [F; 2], range_sel: [F; 2], }, 5, 0, 16 ); +trace!(RangeCheckDynamic1Fixed { + __L1__: F, +}, 6, 0, 256 ); + trace!(RangeCheckDynamic1Trace { colu: F, sel_7: F, sel_8: F, sel_16: F, sel_17: F, }, 6, 0, 256 ); +trace!(RangeCheckDynamic2Fixed { + __L1__: F, +}, 7, 0, 64 ); + trace!(RangeCheckDynamic2Trace { colu: F, sel_1: F, sel_2: F, sel_3: F, sel_4: F, sel_5: F, }, 7, 0, 64 ); +trace!(RangeCheckMixFixed { + __L1__: F, +}, 8, 0, 64 ); + trace!(RangeCheckMixTrace { a: [F; 4], b: [F; 2], c: [F; 1], range_sel: [F; 5], }, 8, 0, 64 ); +trace!(U8AirFixed { + U8: F, __L1__: F, +}, 9, 0, 256 ); + trace!(U8AirTrace { mul: F, }, 9, 0, 256 ); +trace!(SpecifiedRangesFixed { + RANGE: [F; 20], __L1__: F, +}, 10, 0, 131072 ); + trace!(SpecifiedRangesTrace { mul: [F; 20], }, 10, 0, 131072 ); diff --git a/pil2-stark/lib/include/starks_lib.h b/pil2-stark/lib/include/starks_lib.h index 71d5e41fc..456031fa6 100644 --- a/pil2-stark/lib/include/starks_lib.h +++ b/pil2-stark/lib/include/starks_lib.h @@ -10,7 +10,7 @@ // FRIProof // ======================================================================================== - void *fri_proof_new(void *pSetupCtx, uint64_t instanceId); + void *fri_proof_new(void *pSetupCtx, uint64_t airgroupId, uint64_t airId, uint64_t instanceId); void fri_proof_get_tree_root(void *pFriProof, void* root, uint64_t tree_index); void fri_proof_set_airgroupvalues(void *pFriProof, void *airgroupValues); void fri_proof_set_airvalues(void *pFriProof, void *airValues); @@ -39,11 +39,12 @@ // Const Pols // ======================================================================================== - void load_const_tree(void *pConstTree, char *treeFilename, uint64_t constTreeSize); + bool load_const_tree(void *pStarkInfo, void *pConstTree, char *treeFilename, uint64_t constTreeSize, char* verkeyFilename); void load_const_pols(void *pConstPols, char *constFilename, uint64_t constSize); uint64_t get_const_tree_size(void *pStarkInfo); uint64_t get_const_size(void *pStarkInfo); - void calculate_const_tree(void *pStarkInfo, void *pConstPolsAddress, void *pConstTree, char *treeFilename); + void calculate_const_tree(void *pStarkInfo, void *pConstPolsAddress, void *pConstTree); + void write_const_tree(void *pStarkInfo, void *pConstTreeAddress, char *treeFilename); // Expressions Bin // ======================================================================================== @@ -68,7 +69,6 @@ void starks_free(void *pStarks); void treesGL_get_root(void *pStarks, uint64_t index, void *root); - void treesGL_set_root(void *pStarks, uint64_t index, void *pProof); void calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub); void *get_fri_pol(void *pStarkInfo, void *buffer); @@ -129,7 +129,7 @@ // Recursive proof // ================================================================================= - void *gen_recursive_proof(void *pSetupCtx, char* globalInfoFile, uint64_t airgroupId, void* witness, void* aux_trace, void *pConstPols, void *pConstTree, void* pPublicInputs, char *proof_file, bool vadcop); + void *gen_recursive_proof(void *pSetupCtx, char* globalInfoFile, uint64_t airgroupId, uint64_t airId, uint64_t instanceId, void* witness, void* aux_trace, void *pConstPols, void *pConstTree, void* pPublicInputs, char *proof_file, bool vadcop); void *get_zkin_ptr(char *zkin_file); void *add_recursive2_verkey(void *pZkin, char* recursive2VerKeyFilename); void *join_zkin_recursive2(char* globalInfoFile, uint64_t airgroupId, void* pPublics, void* pChallenges, void *zkin1, void *zkin2, void *starkInfoRecursive2); @@ -160,5 +160,14 @@ void *create_buffer(uint64_t size); void free_buffer(void *buffer); + + // Fixed cols + // ================================================================================= + void write_fixed_cols_bin(char* binFile, char* airgroupName, char* airName, uint64_t N, uint64_t nFixedPols, void* fixedPolsInfo); + + // OMP + // ================================================================================= + uint64_t get_omp_max_threads(); + void set_omp_num_threads(uint64_t num_threads); #endif \ No newline at end of file diff --git a/pil2-stark/src/api/starks_api.cpp b/pil2-stark/src/api/starks_api.cpp index 60aa1fb70..03ff54bc2 100644 --- a/pil2-stark/src/api/starks_api.cpp +++ b/pil2-stark/src/api/starks_api.cpp @@ -10,6 +10,7 @@ #include "setup_ctx.hpp" #include "stark_verify.hpp" #include "exec_file.hpp" +#include "fixed_cols.hpp" #include "final_snark_proof.hpp" #include @@ -72,10 +73,10 @@ void save_proof_values(void *pProofValues, char* globalInfoFile, char *fileDir) -void *fri_proof_new(void *pSetupCtx, uint64_t instanceId) +void *fri_proof_new(void *pSetupCtx, uint64_t airgroupId, uint64_t airId, uint64_t instanceId) { SetupCtx setupCtx = *(SetupCtx *)pSetupCtx; - FRIProof *friProof = new FRIProof(setupCtx.starkInfo, instanceId); + FRIProof *friProof = new FRIProof(setupCtx.starkInfo, airgroupId, airId, instanceId); return friProof; } @@ -147,12 +148,10 @@ void fri_proof_get_zkinproofs(uint64_t nProofs, void **proofs, void **pFriProofs zkin["proofvalues"] = j["proofvalues"]; zkin["challenges"] = j["challenges"]["challenges"]; zkin["challengesFRISteps"] = j["challenges"]["challengesFRISteps"]; - - std::string airName = globalInfo["airs"][friProof->airgroupId][friProof->airId]["name"]; - std::string proofName = airName + "_" + std::to_string(friProof->instanceId); - if(!string(fileDir).empty()) { - json2file(zkin, string(fileDir) + "/zkin/proof_" + proofName + "_zkin.json"); + std::string airName = globalInfo["airs"][friProof->airgroupId][friProof->airId]["name"]; + std::string proofName = airName + "_" + std::to_string(friProof->instanceId); + json2file(zkin, string(fileDir) + "/proofs/proof_" + proofName + "_zkin.json"); } proofs[i] = (void *) new nlohmann::json(zkin); @@ -285,9 +284,10 @@ void prover_helpers_free(void *pProverHelpers) { // Const Pols // ======================================================================================== -void load_const_tree(void *pConstTree, char *treeFilename, uint64_t constTreeSize) { +bool load_const_tree(void *pStarkInfo, void *pConstTree, char *treeFilename, uint64_t constTreeSize, char* verkeyFilename) { ConstTree constTree; - constTree.loadConstTree(pConstTree, treeFilename, constTreeSize); + auto starkInfo = *(StarkInfo *)pStarkInfo; + return constTree.loadConstTree(starkInfo, pConstTree, treeFilename, constTreeSize, verkeyFilename); }; void load_const_pols(void *pConstPols, char *constFilename, uint64_t constSize) { @@ -299,9 +299,9 @@ uint64_t get_const_tree_size(void *pStarkInfo) { ConstTree constTree; auto starkInfo = *(StarkInfo *)pStarkInfo; if(starkInfo.starkStruct.verificationHashType == "GL") { - return constTree.getConstTreeSizeBytesGL(starkInfo); + return constTree.getConstTreeSizeGL(starkInfo); } else { - return constTree.getConstTreeSizeBytesBN128(starkInfo); + return constTree.getConstTreeSizeBN128(starkInfo); } }; @@ -309,17 +309,27 @@ uint64_t get_const_tree_size(void *pStarkInfo) { uint64_t get_const_size(void *pStarkInfo) { auto starkInfo = *(StarkInfo *)pStarkInfo; uint64_t N = 1 << starkInfo.starkStruct.nBits; - return N * starkInfo.nConstants * sizeof(Goldilocks::Element); + return N * starkInfo.nConstants; } -void calculate_const_tree(void *pStarkInfo, void *pConstPolsAddress, void *pConstTreeAddress, char *treeFilename) { +void calculate_const_tree(void *pStarkInfo, void *pConstPolsAddress, void *pConstTreeAddress) { + ConstTree constTree; + auto starkInfo = *(StarkInfo *)pStarkInfo; + if(starkInfo.starkStruct.verificationHashType == "GL") { + constTree.calculateConstTreeGL(*(StarkInfo *)pStarkInfo, (Goldilocks::Element *)pConstPolsAddress, pConstTreeAddress); + } else { + constTree.calculateConstTreeBN128(*(StarkInfo *)pStarkInfo, (Goldilocks::Element *)pConstPolsAddress, pConstTreeAddress); + } +}; + +void write_const_tree(void *pStarkInfo, void *pConstTreeAddress, char *treeFilename) { ConstTree constTree; auto starkInfo = *(StarkInfo *)pStarkInfo; if(starkInfo.starkStruct.verificationHashType == "GL") { - constTree.calculateConstTreeGL(*(StarkInfo *)pStarkInfo, (Goldilocks::Element *)pConstPolsAddress, pConstTreeAddress, treeFilename); + constTree.writeConstTreeFileGL(*(StarkInfo *)pStarkInfo, pConstTreeAddress, treeFilename); } else { - constTree.calculateConstTreeBN128(*(StarkInfo *)pStarkInfo, (Goldilocks::Element *)pConstPolsAddress, pConstTreeAddress, treeFilename); + constTree.writeConstTreeFileBN128(*(StarkInfo *)pStarkInfo, pConstTreeAddress, treeFilename); } }; @@ -400,14 +410,6 @@ void treesGL_get_root(void *pStarks, uint64_t index, void *dst) starks->ffi_treesGL_get_root(index, (Goldilocks::Element *)dst); } -void treesGL_set_root(void *pStarks, uint64_t index, void *pProof) -{ - Starks *starks = (Starks *)pStarks; - - starks->ffi_treesGL_set_root(index, *(FRIProof *)pProof); -} - - void calculate_fri_polynomial(void *pStarks, void* stepsParams) { Starks *starks = (Starks *)pStarks; @@ -666,15 +668,15 @@ void print_row(void *pSetupCtx, void *buffer, uint64_t stage, uint64_t row) { // Recursive proof // ================================================================================= -void *gen_recursive_proof(void *pSetupCtx, char* globalInfoFile, uint64_t airgroupId, void* witness, void* aux_trace, void *pConstPols, void *pConstTree, void* pPublicInputs, char* proof_file, bool vadcop) { +void *gen_recursive_proof(void *pSetupCtx, char* globalInfoFile, uint64_t airgroupId, uint64_t airId, uint64_t instanceId, void* witness, void* aux_trace, void *pConstPols, void *pConstTree, void* pPublicInputs, char* proof_file, bool vadcop) { json globalInfo; file2json(globalInfoFile, globalInfo); auto setup = *(SetupCtx *)pSetupCtx; if(setup.starkInfo.starkStruct.verificationHashType == "GL") { - return genRecursiveProof(*(SetupCtx *)pSetupCtx, globalInfo, airgroupId, (Goldilocks::Element *)witness, (Goldilocks::Element *)aux_trace, (Goldilocks::Element *)pConstPols, (Goldilocks::Element *)pConstTree, (Goldilocks::Element *)pPublicInputs, string(proof_file), vadcop); + return genRecursiveProof(*(SetupCtx *)pSetupCtx, globalInfo, airgroupId, airId, instanceId, (Goldilocks::Element *)witness, (Goldilocks::Element *)aux_trace, (Goldilocks::Element *)pConstPols, (Goldilocks::Element *)pConstTree, (Goldilocks::Element *)pPublicInputs, string(proof_file), vadcop); } else { - return genRecursiveProof(*(SetupCtx *)pSetupCtx, globalInfo, airgroupId, (Goldilocks::Element *)witness, (Goldilocks::Element *)aux_trace, (Goldilocks::Element *)pConstPols, (Goldilocks::Element *)pConstTree, (Goldilocks::Element *)pPublicInputs, string(proof_file), false); + return genRecursiveProof(*(SetupCtx *)pSetupCtx, globalInfo, airgroupId, airId, instanceId, (Goldilocks::Element *)witness, (Goldilocks::Element *)aux_trace, (Goldilocks::Element *)pConstPols, (Goldilocks::Element *)pConstTree, (Goldilocks::Element *)pPublicInputs, string(proof_file), false); } } @@ -847,4 +849,18 @@ void *create_buffer(uint64_t size) { void free_buffer(void *buffer) { cout << (Goldilocks::Element *)buffer << endl; delete[] (Goldilocks::Element *)buffer; +} + +// Fixed cols +// ================================================================================= +void write_fixed_cols_bin(char* binFile, char* airgroupName, char* airName, uint64_t N, uint64_t nFixedPols, void* fixedPolsInfo) { + writeFixedColsBin(string(binFile), string(airgroupName), string(airName), N, nFixedPols, (FixedPolsInfo *)fixedPolsInfo); +} + +uint64_t get_omp_max_threads(){ + return omp_get_max_threads(); +} + +void set_omp_num_threads(uint64_t num_threads){ + omp_set_num_threads(num_threads); } \ No newline at end of file diff --git a/pil2-stark/src/api/starks_api.hpp b/pil2-stark/src/api/starks_api.hpp index 71d5e41fc..456031fa6 100644 --- a/pil2-stark/src/api/starks_api.hpp +++ b/pil2-stark/src/api/starks_api.hpp @@ -10,7 +10,7 @@ // FRIProof // ======================================================================================== - void *fri_proof_new(void *pSetupCtx, uint64_t instanceId); + void *fri_proof_new(void *pSetupCtx, uint64_t airgroupId, uint64_t airId, uint64_t instanceId); void fri_proof_get_tree_root(void *pFriProof, void* root, uint64_t tree_index); void fri_proof_set_airgroupvalues(void *pFriProof, void *airgroupValues); void fri_proof_set_airvalues(void *pFriProof, void *airValues); @@ -39,11 +39,12 @@ // Const Pols // ======================================================================================== - void load_const_tree(void *pConstTree, char *treeFilename, uint64_t constTreeSize); + bool load_const_tree(void *pStarkInfo, void *pConstTree, char *treeFilename, uint64_t constTreeSize, char* verkeyFilename); void load_const_pols(void *pConstPols, char *constFilename, uint64_t constSize); uint64_t get_const_tree_size(void *pStarkInfo); uint64_t get_const_size(void *pStarkInfo); - void calculate_const_tree(void *pStarkInfo, void *pConstPolsAddress, void *pConstTree, char *treeFilename); + void calculate_const_tree(void *pStarkInfo, void *pConstPolsAddress, void *pConstTree); + void write_const_tree(void *pStarkInfo, void *pConstTreeAddress, char *treeFilename); // Expressions Bin // ======================================================================================== @@ -68,7 +69,6 @@ void starks_free(void *pStarks); void treesGL_get_root(void *pStarks, uint64_t index, void *root); - void treesGL_set_root(void *pStarks, uint64_t index, void *pProof); void calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub); void *get_fri_pol(void *pStarkInfo, void *buffer); @@ -129,7 +129,7 @@ // Recursive proof // ================================================================================= - void *gen_recursive_proof(void *pSetupCtx, char* globalInfoFile, uint64_t airgroupId, void* witness, void* aux_trace, void *pConstPols, void *pConstTree, void* pPublicInputs, char *proof_file, bool vadcop); + void *gen_recursive_proof(void *pSetupCtx, char* globalInfoFile, uint64_t airgroupId, uint64_t airId, uint64_t instanceId, void* witness, void* aux_trace, void *pConstPols, void *pConstTree, void* pPublicInputs, char *proof_file, bool vadcop); void *get_zkin_ptr(char *zkin_file); void *add_recursive2_verkey(void *pZkin, char* recursive2VerKeyFilename); void *join_zkin_recursive2(char* globalInfoFile, uint64_t airgroupId, void* pPublics, void* pChallenges, void *zkin1, void *zkin2, void *starkInfoRecursive2); @@ -160,5 +160,14 @@ void *create_buffer(uint64_t size); void free_buffer(void *buffer); + + // Fixed cols + // ================================================================================= + void write_fixed_cols_bin(char* binFile, char* airgroupName, char* airName, uint64_t N, uint64_t nFixedPols, void* fixedPolsInfo); + + // OMP + // ================================================================================= + uint64_t get_omp_max_threads(); + void set_omp_num_threads(uint64_t num_threads); #endif \ No newline at end of file diff --git a/pil2-stark/src/config/zkglobals.cpp b/pil2-stark/src/config/zkglobals.cpp index 389aa9ba0..4b743e68a 100644 --- a/pil2-stark/src/config/zkglobals.cpp +++ b/pil2-stark/src/config/zkglobals.cpp @@ -1,7 +1,7 @@ #include "zkglobals.hpp" Goldilocks fr; -PoseidonGoldilocks poseidon; +Poseidon2Goldilocks poseidon; RawFec fec; RawFnec fnec; RawFr bn128; diff --git a/pil2-stark/src/config/zkglobals.hpp b/pil2-stark/src/config/zkglobals.hpp index 396629a98..b62f38020 100644 --- a/pil2-stark/src/config/zkglobals.hpp +++ b/pil2-stark/src/config/zkglobals.hpp @@ -2,14 +2,14 @@ #define ZKGLOBALS_HPP #include "goldilocks_base_field.hpp" -#include "poseidon_goldilocks.hpp" +#include "poseidon2_goldilocks.hpp" #include "ffiasm/fec.hpp" #include "ffiasm/fnec.hpp" #include "ffiasm/fr.hpp" #include "ffiasm/fq.hpp" extern Goldilocks fr; -extern PoseidonGoldilocks poseidon; +extern Poseidon2Goldilocks poseidon; extern RawFec fec; extern RawFnec fnec; extern RawFr bn128; diff --git a/pil2-stark/src/goldilocks/benchs/bench.cpp b/pil2-stark/src/goldilocks/benchs/bench.cpp index 278c8b1d2..4a922139a 100644 --- a/pil2-stark/src/goldilocks/benchs/bench.cpp +++ b/pil2-stark/src/goldilocks/benchs/bench.cpp @@ -3,6 +3,8 @@ #include "../src/goldilocks_base_field.hpp" #include "../src/poseidon_goldilocks.hpp" +#include "../src/poseidon2_goldilocks.hpp" +#include "../src/poseidon2_goldilocks_avx.hpp" #include "../src/poseidon_goldilocks_avx.hpp" #include "../src/ntt_goldilocks.hpp" #include "../src/merklehash_goldilocks.hpp" @@ -164,9 +166,79 @@ static void MUL_OP_AVX_BENCH(benchmark::State &state) Goldilocks::Element res[4]; Goldilocks::store_avx(res, term2_); assert(Goldilocks::toU64(res[0]) == 1922281271747280077ULL); +} + + +static void POSEIDON2_BENCH_FULL(benchmark::State &state) +{ + uint64_t input_size = (uint64_t)NUM_HASHES * (uint64_t)SPONGE_WIDTH; + Goldilocks::Element *x = new Goldilocks::Element[input_size]; + Goldilocks::Element *result = new Goldilocks::Element[input_size]; + + for (uint64_t i = 0; i < input_size; i++) + { + x[i] = Goldilocks::fromU64(i); + } + + // Benchmark + for (auto _ : state) + { +#pragma omp parallel for num_threads(state.range(0)) schedule(static) + for (uint64_t i = 0; i < NUM_HASHES; i++) + { + Poseidon2Goldilocks::hash_full_result_seq((Goldilocks::Element(&)[SPONGE_WIDTH])result[i * SPONGE_WIDTH], (Goldilocks::Element(&)[SPONGE_WIDTH])x[i * SPONGE_WIDTH]); + } + } + // Check poseidon results poseidon ( 0 1 2 3 4 5 6 7 8 9 10 11 ) + assert(Goldilocks::toU64(result[0]) == 0X1EAEF96BDF1C0C1 ); + assert(Goldilocks::toU64(result[1]) == 0X1F0D2CC525B2540C); + assert(Goldilocks::toU64(result[2]) == 0X6282C1DFE1E0358D); + assert(Goldilocks::toU64(result[3]) == 0XE780D721F698E1E6); + delete[] x; + delete[] result; + // Rate = time to process 1 posseidon per core + // BytesProcessed = total bytes processed per second on every iteration + int threads_core = 2 * state.range(0) / omp_get_max_threads(); // we assume hyperthreading + state.counters["Rate"] = benchmark::Counter(threads_core * (double)NUM_HASHES / (double)state.range(0), benchmark::Counter::kIsIterationInvariantRate | benchmark::Counter::kInvert); + state.counters["BytesProcessed"] = benchmark::Counter(input_size * sizeof(uint64_t), benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1024); +} + +static void POSEIDON2_BENCH_FULL_AVX(benchmark::State &state) +{ + uint64_t input_size = (uint64_t)NUM_HASHES * (uint64_t)SPONGE_WIDTH; + Goldilocks::Element *x = new Goldilocks::Element[input_size]; + Goldilocks::Element *result = new Goldilocks::Element[input_size]; + for (uint64_t i = 0; i < input_size; i++) + { + x[i] = Goldilocks::fromU64(i); + } + + // Benchmark + for (auto _ : state) + { +#pragma omp parallel for num_threads(state.range(0)) schedule(static) + for (uint64_t i = 0; i < NUM_HASHES; i++) + { + Poseidon2Goldilocks::hash_full_result((Goldilocks::Element(&)[SPONGE_WIDTH])result[i * SPONGE_WIDTH], (Goldilocks::Element(&)[SPONGE_WIDTH])x[i * SPONGE_WIDTH]); + } + } + // Check poseidon results poseidon ( 0 1 2 3 4 5 6 7 8 9 10 11 ) + // assert(Goldilocks::toU64(result[0]) == 0X1EAEF96BDF1C0C1 ); + // assert(Goldilocks::toU64(result[1]) == 0X1F0D2CC525B2540C); + // assert(Goldilocks::toU64(result[2]) == 0X6282C1DFE1E0358D); + // assert(Goldilocks::toU64(result[3]) == 0XE780D721F698E1E6); + delete[] x; + delete[] result; + // Rate = time to process 1 posseidon per core + // BytesProcessed = total bytes processed per second on every iteration + int threads_core = 2 * state.range(0) / omp_get_max_threads(); // we assume hyperthreading + state.counters["Rate"] = benchmark::Counter(threads_core * (double)NUM_HASHES / (double)state.range(0), benchmark::Counter::kIsIterationInvariantRate | benchmark::Counter::kInvert); + state.counters["BytesProcessed"] = benchmark::Counter(input_size * sizeof(uint64_t), benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1024); } + + static void POSEIDON_BENCH_FULL(benchmark::State &state) { uint64_t input_size = (uint64_t)NUM_HASHES * (uint64_t)SPONGE_WIDTH; @@ -290,6 +362,80 @@ static void POSEIDON_BENCH_FULL_AVX512(benchmark::State &state) } #endif + +static void POSEIDON2_BENCH(benchmark::State &state) +{ + uint64_t input_size = (uint64_t)NUM_HASHES * (uint64_t)SPONGE_WIDTH; + uint64_t output_size = (uint64_t)NUM_HASHES * (uint64_t)CAPACITY; + Goldilocks::Element *x = new Goldilocks::Element[input_size]; + Goldilocks::Element *result = new Goldilocks::Element[output_size]; + + for (uint64_t i = 0; i < input_size; i++) + { + x[i] = Goldilocks::fromU64(i); + } + + // Benchmark + for (auto _ : state) + { +#pragma omp parallel for num_threads(state.range(0)) schedule(static) + for (uint64_t i = 0; i < NUM_HASHES; i++) + { + Poseidon2Goldilocks::hash_seq((Goldilocks::Element(&)[CAPACITY])result[i * CAPACITY], (Goldilocks::Element(&)[SPONGE_WIDTH])x[i * SPONGE_WIDTH]); + } + } + // Check poseidon results poseidon ( 0 1 2 3 4 5 6 7 8 9 10 11 ) + assert(Goldilocks::toU64(result[0]) == 0X1EAEF96BDF1C0C1 ); + assert(Goldilocks::toU64(result[1]) == 0X1F0D2CC525B2540C); + assert(Goldilocks::toU64(result[2]) == 0X6282C1DFE1E0358D); + assert(Goldilocks::toU64(result[3]) == 0XE780D721F698E1E6); + + delete[] x; + delete[] result; + // Rate = time to process 1 posseidon per core + // BytesProcessed = total bytes processed per second on every iteration + int threads_core = 2 * state.range(0) / omp_get_max_threads(); // we assume hyperthreading + state.counters["Rate"] = benchmark::Counter(threads_core * (double)NUM_HASHES / (double)state.range(0), benchmark::Counter::kIsIterationInvariantRate | benchmark::Counter::kInvert); + state.counters["BytesProcessed"] = benchmark::Counter(input_size * sizeof(uint64_t), benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1024); +} + + +static void POSEIDON2_BENCH_AVX(benchmark::State &state) +{ + uint64_t input_size = (uint64_t)NUM_HASHES * (uint64_t)SPONGE_WIDTH; + uint64_t output_size = (uint64_t)NUM_HASHES * (uint64_t)CAPACITY; + Goldilocks::Element *x = new Goldilocks::Element[input_size]; + Goldilocks::Element *result = new Goldilocks::Element[output_size]; + + for (uint64_t i = 0; i < input_size; i++) + { + x[i] = Goldilocks::fromU64(i); + } + + // Benchmark + for (auto _ : state) + { +#pragma omp parallel for num_threads(state.range(0)) schedule(static) + for (uint64_t i = 0; i < NUM_HASHES; i++) + { + Poseidon2Goldilocks::hash((Goldilocks::Element(&)[CAPACITY])result[i * CAPACITY], (Goldilocks::Element(&)[SPONGE_WIDTH])x[i * SPONGE_WIDTH]); + } + } + // Check poseidon results poseidon ( 0 1 2 3 4 5 6 7 8 9 10 11 ) + // assert(Goldilocks::toU64(result[0]) == 0X1EAEF96BDF1C0C1 ); + // assert(Goldilocks::toU64(result[1]) == 0X1F0D2CC525B2540C); + // assert(Goldilocks::toU64(result[2]) == 0X6282C1DFE1E0358D); + // assert(Goldilocks::toU64(result[3]) == 0XE780D721F698E1E6); + + delete[] x; + delete[] result; + // Rate = time to process 1 posseidon per core + // BytesProcessed = total bytes processed per second on every iteration + int threads_core = 2 * state.range(0) / omp_get_max_threads(); // we assume hyperthreading + state.counters["Rate"] = benchmark::Counter(threads_core * (double)NUM_HASHES / (double)state.range(0), benchmark::Counter::kIsIterationInvariantRate | benchmark::Counter::kInvert); + state.counters["BytesProcessed"] = benchmark::Counter(input_size * sizeof(uint64_t), benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1024); +} + static void POSEIDON_BENCH(benchmark::State &state) { uint64_t input_size = (uint64_t)NUM_HASHES * (uint64_t)SPONGE_WIDTH; @@ -1133,6 +1279,16 @@ BENCHMARK(INV_OP_BENCH) ->Unit(benchmark::kMicrosecond) ->UseRealTime(); +BENCHMARK(POSEIDON2_BENCH_FULL) + ->Unit(benchmark::kMicrosecond) + ->DenseRange(omp_get_max_threads() / 2, omp_get_max_threads(), omp_get_max_threads() / 2) + ->UseRealTime(); + +BENCHMARK(POSEIDON2_BENCH_FULL_AVX) + ->Unit(benchmark::kMicrosecond) + ->DenseRange(omp_get_max_threads() / 2, omp_get_max_threads(), omp_get_max_threads() / 2) + ->UseRealTime(); + BENCHMARK(POSEIDON_BENCH_FULL) ->Unit(benchmark::kMicrosecond) ->DenseRange(omp_get_max_threads() / 2, omp_get_max_threads(), omp_get_max_threads() / 2) @@ -1150,6 +1306,16 @@ BENCHMARK(POSEIDON_BENCH_FULL_AVX512) ->UseRealTime(); #endif +BENCHMARK(POSEIDON2_BENCH) + ->Unit(benchmark::kMicrosecond) + ->DenseRange(omp_get_max_threads() / 2, omp_get_max_threads(), omp_get_max_threads() / 2) + ->UseRealTime(); + +BENCHMARK(POSEIDON2_BENCH_AVX) + ->Unit(benchmark::kMicrosecond) + ->DenseRange(omp_get_max_threads() / 2, omp_get_max_threads(), omp_get_max_threads() / 2) + ->UseRealTime(); + BENCHMARK(POSEIDON_BENCH) ->Unit(benchmark::kMicrosecond) ->DenseRange(omp_get_max_threads() / 2, omp_get_max_threads(), omp_get_max_threads() / 2) @@ -1272,4 +1438,4 @@ BENCHMARK_MAIN(); // icpx -std=c++17 -Wall -march=native -O3 -qopenmp -qopenmp-simd -mavx512f -mavx2 -axCORE-AVX512,CORE-AVX2 -ipo -qopt-zmm-usage=high benchs/bench.cpp src/*.cpp -lbenchmark -lgmp -o bench -D__AVX512__ // RUN: -// ./bench --benchmark_filter=POSEIDON +// ./bench --benchmark_filter=POSEIDON \ No newline at end of file diff --git a/pil2-stark/src/goldilocks/src/poseidon2_goldilocks.cpp b/pil2-stark/src/goldilocks/src/poseidon2_goldilocks.cpp new file mode 100644 index 000000000..285135320 --- /dev/null +++ b/pil2-stark/src/goldilocks/src/poseidon2_goldilocks.cpp @@ -0,0 +1,582 @@ +#include "poseidon2_goldilocks.hpp" +#include /* floor */ +#include "merklehash_goldilocks.hpp" + +void Poseidon2Goldilocks::hash_full_result_seq(Goldilocks::Element *state, const Goldilocks::Element *input) +{ + const int length = SPONGE_WIDTH * sizeof(Goldilocks::Element); + std::memcpy(state, input, length); + + matmul_external_(state); + + for (int r = 0; r < HALF_N_FULL_ROUNDS; r++) + { + pow7add_(state, &(Poseidon2GoldilocksConstants::C[r * SPONGE_WIDTH])); + matmul_external_(state); + } + + for (int r = 0; r < N_PARTIAL_ROUNDS; r++) + { + state[0] = state[0] + Poseidon2GoldilocksConstants::C[HALF_N_FULL_ROUNDS * SPONGE_WIDTH + r]; + pow7(state[0]); + Goldilocks::Element sum_ = Goldilocks::zero(); + add_(sum_, state); + prodadd_(state, Poseidon2GoldilocksConstants::D, sum_); + } + + for (int r = 0; r < HALF_N_FULL_ROUNDS; r++) + { + pow7add_(state, &(Poseidon2GoldilocksConstants::C[HALF_N_FULL_ROUNDS * SPONGE_WIDTH + N_PARTIAL_ROUNDS + r * SPONGE_WIDTH])); + matmul_external_(state); + } +} +void Poseidon2Goldilocks::linear_hash_seq(Goldilocks::Element *output, Goldilocks::Element *input, uint64_t size) +{ + uint64_t remaining = size; + Goldilocks::Element state[SPONGE_WIDTH]; + + if (size <= CAPACITY) + { + std::memcpy(output, input, size * sizeof(Goldilocks::Element)); + std::memset(&output[size], 0, (CAPACITY - size) * sizeof(Goldilocks::Element)); + return; // no need to hash + } + while (remaining) + { + if (remaining == size) + { + memset(state + RATE, 0, CAPACITY * sizeof(Goldilocks::Element)); + } + else + { + std::memcpy(state + RATE, state, CAPACITY * sizeof(Goldilocks::Element)); + } + + uint64_t n = (remaining < RATE) ? remaining : RATE; + memset(&state[n], 0, (RATE - n) * sizeof(Goldilocks::Element)); + std::memcpy(state, input + (size - remaining), n * sizeof(Goldilocks::Element)); + hash_full_result_seq(state, state); + remaining -= n; + } + if (size > 0) + { + std::memcpy(output, state, CAPACITY * sizeof(Goldilocks::Element)); + } + else + { + memset(output, 0, CAPACITY * sizeof(Goldilocks::Element)); + } +} +void Poseidon2Goldilocks::merkletree_seq(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, int nThreads, uint64_t dim) +{ + if (num_rows == 0) + { + return; + } + + Goldilocks::Element *cursor = tree; + // memset(cursor, 0, num_rows * CAPACITY * sizeof(Goldilocks::Element)); + if (nThreads == 0) + nThreads = omp_get_max_threads(); + +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < num_rows; i++) + { + linear_hash_seq(&cursor[i * CAPACITY], &input[i * num_cols * dim], num_cols * dim); + } + + // Build the merkle tree + uint64_t pending = num_rows; + uint64_t nextN = floor((pending - 1) / 2) + 1; + uint64_t nextIndex = 0; + + while (pending > 1) + { +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < nextN; i++) + { + Goldilocks::Element pol_input[SPONGE_WIDTH]; + memset(pol_input, 0, SPONGE_WIDTH * sizeof(Goldilocks::Element)); + std::memcpy(pol_input, &cursor[nextIndex + i * RATE], RATE * sizeof(Goldilocks::Element)); + hash_seq((Goldilocks::Element(&)[CAPACITY])cursor[nextIndex + (pending + i) * CAPACITY], pol_input); + } + nextIndex += pending * CAPACITY; + pending = pending / 2; + nextN = floor((pending - 1) / 2) + 1; + } +} +void Poseidon2Goldilocks::merkletree_batch_seq(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, uint64_t batch_size, int nThreads, uint64_t dim) +{ + if (num_rows == 0) + { + return; + } + + Goldilocks::Element *cursor = tree; + uint64_t nbatches = 1; + if (num_cols > 0) + { + nbatches = (num_cols + batch_size - 1) / batch_size; + } + uint64_t nlastb = num_cols - (nbatches - 1) * batch_size; + + if (nThreads == 0) + nThreads = omp_get_max_threads(); + +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < num_rows; i++) + { + Goldilocks::Element buff0[nbatches * CAPACITY]; + for (uint64_t j = 0; j < nbatches; j++) + { + uint64_t nn = batch_size; + if (j == nbatches - 1) + nn = nlastb; + linear_hash_seq(&buff0[j * CAPACITY], &input[i * num_cols * dim + j * batch_size * dim], nn * dim); + } + linear_hash_seq(&cursor[i * CAPACITY], buff0, nbatches * CAPACITY); + } + + // Build the merkle tree + uint64_t pending = num_rows; + uint64_t nextN = floor((pending - 1) / 2) + 1; + uint64_t nextIndex = 0; + + while (pending > 1) + { +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < nextN; i++) + { + Goldilocks::Element pol_input[SPONGE_WIDTH]; + memset(pol_input, 0, SPONGE_WIDTH * sizeof(Goldilocks::Element)); + std::memcpy(pol_input, &cursor[nextIndex + i * RATE], RATE * sizeof(Goldilocks::Element)); + hash_seq((Goldilocks::Element(&)[CAPACITY])cursor[nextIndex + (pending + i) * CAPACITY], pol_input); + } + nextIndex += pending * CAPACITY; + pending = pending / 2; + nextN = floor((pending - 1) / 2) + 1; + } +} + + +void Poseidon2Goldilocks::hash_full_result(Goldilocks::Element *state, const Goldilocks::Element *input) +{ + const int length = SPONGE_WIDTH * sizeof(Goldilocks::Element); + std::memcpy(state, input, length); + __m256i st0, st1, st2; + Goldilocks::load_avx(st0, &(state[0])); + Goldilocks::load_avx(st1, &(state[4])); + Goldilocks::load_avx(st2, &(state[8])); + + matmul_external_avx(st0, st1, st2); + + for (int r = 0; r < HALF_N_FULL_ROUNDS; r++) + { + add_avx_small(st0, st1, st2, &(Poseidon2GoldilocksConstants::C[r * SPONGE_WIDTH])); + pow7_avx(st0, st1, st2); + matmul_external_avx(st0, st1, st2); + } + + Goldilocks::store_avx(&(state[0]), st0); + Goldilocks::Element state0_ = state[0]; + + __m256i d0, d1, d2; + Goldilocks::load_avx(d0, &(Poseidon2GoldilocksConstants::D[0])); + Goldilocks::load_avx(d1, &(Poseidon2GoldilocksConstants::D[4])); + Goldilocks::load_avx(d2, &(Poseidon2GoldilocksConstants::D[8])); + + __m256i part_sum; + Goldilocks::Element partial_sum[4]; + Goldilocks::Element aux = state0_; + for (int r = 0; r < N_PARTIAL_ROUNDS; r++) + { + Goldilocks::add_avx(part_sum, st1, st2); + Goldilocks::add_avx(part_sum, part_sum, st0); + Goldilocks::store_avx(partial_sum, part_sum); + Goldilocks::Element sum = partial_sum[0] + partial_sum[1] + partial_sum[2] + partial_sum[3]; + sum = sum - aux; + + state0_ = state0_ + Poseidon2GoldilocksConstants::C[HALF_N_FULL_ROUNDS * SPONGE_WIDTH + r]; + pow7(state0_); + + sum = sum + state0_; + + __m256i scalar1 = _mm256_set1_epi64x(sum.fe); + Goldilocks::mult_avx(st0, st0, d0); + Goldilocks::mult_avx(st1, st1, d1); + Goldilocks::mult_avx(st2, st2, d2); + Goldilocks::add_avx(st0, st0, scalar1); + Goldilocks::add_avx(st1, st1, scalar1); + Goldilocks::add_avx(st2, st2, scalar1); + state0_ = state0_ * Poseidon2GoldilocksConstants::D[0] + sum; + aux = aux * Poseidon2GoldilocksConstants::D[0] + sum; + } + + Goldilocks::store_avx(&(state[0]), st0); + state[0] = state0_; + Goldilocks::load_avx(st0, &(state[0])); + + for (int r = 0; r < HALF_N_FULL_ROUNDS; r++) + { + add_avx_small(st0, st1, st2, &(Poseidon2GoldilocksConstants::C[HALF_N_FULL_ROUNDS * SPONGE_WIDTH + N_PARTIAL_ROUNDS + r * SPONGE_WIDTH])); + pow7_avx(st0, st1, st2); + + matmul_external_avx(st0, st1, st2); + } + + Goldilocks::store_avx(&(state[0]), st0); + Goldilocks::store_avx(&(state[4]), st1); + Goldilocks::store_avx(&(state[8]), st2); +} + +void Poseidon2Goldilocks::linear_hash(Goldilocks::Element *output, Goldilocks::Element *input, uint64_t size) +{ + uint64_t remaining = size; + Goldilocks::Element state[SPONGE_WIDTH]; + + if (size <= CAPACITY) + { + std::memcpy(output, input, size * sizeof(Goldilocks::Element)); + std::memset(&output[size], 0, (CAPACITY - size) * sizeof(Goldilocks::Element)); + return; // no need to hash + } + while (remaining) + { + if (remaining == size) + { + memset(state + RATE, 0, CAPACITY * sizeof(Goldilocks::Element)); + } + else + { + std::memcpy(state + RATE, state, CAPACITY * sizeof(Goldilocks::Element)); + } + + uint64_t n = (remaining < RATE) ? remaining : RATE; + memset(&state[n], 0, (RATE - n) * sizeof(Goldilocks::Element)); + std::memcpy(state, input + (size - remaining), n * sizeof(Goldilocks::Element)); + hash_full_result(state, state); + remaining -= n; + } + if (size > 0) + { + std::memcpy(output, state, CAPACITY * sizeof(Goldilocks::Element)); + } + else + { + memset(output, 0, CAPACITY * sizeof(Goldilocks::Element)); + } +} +void Poseidon2Goldilocks::merkletree_avx(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, int nThreads, uint64_t dim) +{ + if (num_rows == 0) + { + return; + } + Goldilocks::Element *cursor = tree; + // memset(cursor, 0, num_rows * CAPACITY * sizeof(Goldilocks::Element)); + if (nThreads == 0) + nThreads = omp_get_max_threads(); + +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < num_rows; i++) + { + linear_hash(&cursor[i * CAPACITY], &input[i * num_cols * dim], num_cols * dim); + } + + // Build the merkle tree + uint64_t pending = num_rows; + uint64_t nextN = floor((pending - 1) / 2) + 1; + uint64_t nextIndex = 0; + while (pending > 1) + { +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < nextN; i++) + { + Goldilocks::Element pol_input[SPONGE_WIDTH]; + memset(pol_input, 0, SPONGE_WIDTH * sizeof(Goldilocks::Element)); + std::memcpy(pol_input, &cursor[nextIndex + i * RATE], RATE * sizeof(Goldilocks::Element)); + hash((Goldilocks::Element(&)[CAPACITY])cursor[nextIndex + (pending + i) * CAPACITY], pol_input); + } + nextIndex += pending * CAPACITY; + pending = pending / 2; + nextN = floor((pending - 1) / 2) + 1; + } +} +void Poseidon2Goldilocks::merkletree_batch_avx(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, uint64_t batch_size, int nThreads, uint64_t dim) +{ + if (num_rows == 0) + { + return; + } + Goldilocks::Element *cursor = tree; + uint64_t nbatches = 1; + if (num_cols > 0) + { + nbatches = (num_cols + batch_size - 1) / batch_size; + } + uint64_t nlastb = num_cols - (nbatches - 1) * batch_size; + + if (nThreads == 0) + nThreads = omp_get_max_threads(); + +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < num_rows; i++) + { + Goldilocks::Element buff0[nbatches * CAPACITY]; + for (uint64_t j = 0; j < nbatches; j++) + { + uint64_t nn = batch_size; + if (j == nbatches - 1) + nn = nlastb; + linear_hash(&buff0[j * CAPACITY], &input[i * num_cols * dim + j * batch_size * dim], nn * dim); + } + linear_hash(&cursor[i * CAPACITY], buff0, nbatches * CAPACITY); + } + + // Build the merkle tree + uint64_t pending = num_rows; + uint64_t nextN = floor((pending - 1) / 2) + 1; + uint64_t nextIndex = 0; + + while (pending > 1) + { +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < nextN; i++) + { + Goldilocks::Element pol_input[SPONGE_WIDTH]; + memset(pol_input, 0, SPONGE_WIDTH * sizeof(Goldilocks::Element)); + std::memcpy(pol_input, &cursor[nextIndex + i * RATE], RATE * sizeof(Goldilocks::Element)); + hash((Goldilocks::Element(&)[CAPACITY])cursor[nextIndex + (pending + i) * CAPACITY], pol_input); + } + nextIndex += pending * CAPACITY; + pending = pending / 2; + nextN = floor((pending - 1) / 2) + 1; + } +} + +/* +#ifdef __AVX512__ +void Poseidon2Goldilocks::hash_full_result_avx512(Goldilocks::Element *state, const Goldilocks::Element *input) +{ + + const int length = 2 * SPONGE_WIDTH * sizeof(Goldilocks::Element); + std::memcpy(state, input, length); + __m512i st0, st1, st2; + Goldilocks::load_avx512(st0, &(state[0])); + Goldilocks::load_avx512(st1, &(state[8])); + Goldilocks::load_avx512(st2, &(state[16])); + add_avx512_small(st0, st1, st2, &(Poseidon2GoldilocksConstants::C[0])); + + for (int r = 0; r < HALF_N_FULL_ROUNDS - 1; r++) + { + pow7_avx512(st0, st1, st2); + add_avx512_small(st0, st1, st2, &(Poseidon2GoldilocksConstants::C[(r + 1) * SPONGE_WIDTH])); // rick + Goldilocks::mmult_avx512_8(st0, st1, st2, &(Poseidon2GoldilocksConstants::M_[0])); + } + pow7_avx512(st0, st1, st2); + add_avx512(st0, st1, st2, &(Poseidon2GoldilocksConstants::C[(HALF_N_FULL_ROUNDS * SPONGE_WIDTH)])); + Goldilocks::mmult_avx512(st0, st1, st2, &(Poseidon2GoldilocksConstants::P_[0])); + + Goldilocks::store_avx512(&(state[0]), st0); + Goldilocks::Element s04_[2] = {state[0], state[4]}; + Goldilocks::Element s04[2]; + + __m512i mask = _mm512_set_epi64(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0); // rick, not better to define where u use it? + for (int r = 0; r < N_PARTIAL_ROUNDS; r++) + { + s04[0] = s04_[0]; + s04[1] = s04_[1]; + pow7(s04[0]); + pow7(s04[1]); + s04[0] = s04[0] + Poseidon2GoldilocksConstants::C[(HALF_N_FULL_ROUNDS + 1) * SPONGE_WIDTH + r]; + s04[1] = s04[1] + Poseidon2GoldilocksConstants::C[(HALF_N_FULL_ROUNDS + 1) * SPONGE_WIDTH + r]; + s04_[0] = s04[0] * Poseidon2GoldilocksConstants::S[(SPONGE_WIDTH * 2 - 1) * r]; + s04_[1] = s04[1] * Poseidon2GoldilocksConstants::S[(SPONGE_WIDTH * 2 - 1) * r]; + st0 = _mm512_and_si512(st0, mask); // rick, do we need a new one? + Goldilocks::Element aux[2]; + Goldilocks::dot_avx512(aux, st0, st1, st2, &(Poseidon2GoldilocksConstants::S[(SPONGE_WIDTH * 2 - 1) * r])); + s04_[0] = s04_[0] + aux[0]; + s04_[1] = s04_[1] + aux[1]; + __m512i scalar1 = _mm512_set_epi64(s04[1].fe, s04[1].fe, s04[1].fe, s04[1].fe, s04[0].fe, s04[0].fe, s04[0].fe, s04[0].fe); + __m512i w0, w1, w2; + + const Goldilocks::Element *auxS = &(Poseidon2GoldilocksConstants::S[(SPONGE_WIDTH * 2 - 1) * r + SPONGE_WIDTH - 1]); + __m512i s0 = _mm512_set4_epi64(auxS[3].fe, auxS[2].fe, auxS[1].fe, auxS[0].fe); + __m512i s1 = _mm512_set4_epi64(auxS[7].fe, auxS[6].fe, auxS[5].fe, auxS[4].fe); + __m512i s2 = _mm512_set4_epi64(auxS[11].fe, auxS[10].fe, auxS[9].fe, auxS[8].fe); + + Goldilocks::mult_avx512(w0, scalar1, s0); + Goldilocks::mult_avx512(w1, scalar1, s1); + Goldilocks::mult_avx512(w2, scalar1, s2); + Goldilocks::add_avx512(st0, st0, w0); + Goldilocks::add_avx512(st1, st1, w1); + Goldilocks::add_avx512(st2, st2, w2); + s04[0] = s04[0] + Poseidon2GoldilocksConstants::S[(SPONGE_WIDTH * 2 - 1) * r + SPONGE_WIDTH - 1]; + s04[1] = s04[1] + Poseidon2GoldilocksConstants::S[(SPONGE_WIDTH * 2 - 1) * r + SPONGE_WIDTH - 1]; + } + + Goldilocks::store_avx512(&(state[0]), st0); + state[0] = s04_[0]; + state[4] = s04_[1]; + Goldilocks::load_avx512(st0, &(state[0])); + + for (int r = 0; r < HALF_N_FULL_ROUNDS - 1; r++) + { + pow7_avx512(st0, st1, st2); + add_avx512_small(st0, st1, st2, &(Poseidon2GoldilocksConstants::C[(HALF_N_FULL_ROUNDS + 1) * SPONGE_WIDTH + N_PARTIAL_ROUNDS + r * SPONGE_WIDTH])); + Goldilocks::mmult_avx512_8(st0, st1, st2, &(Poseidon2GoldilocksConstants::M_[0])); + } + pow7_avx512(st0, st1, st2); + Goldilocks::mmult_avx512_8(st0, st1, st2, &(Poseidon2GoldilocksConstants::M_[0])); + + Goldilocks::store_avx512(&(state[0]), st0); + Goldilocks::store_avx512(&(state[8]), st1); + Goldilocks::store_avx512(&(state[16]), st2); +} +void Poseidon2Goldilocks::linear_hash_avx512(Goldilocks::Element *output, Goldilocks::Element *input, uint64_t size) +{ + uint64_t remaining = size; + Goldilocks::Element state[2 * SPONGE_WIDTH]; + + if (size <= CAPACITY) + { + std::memcpy(output, input, size * sizeof(Goldilocks::Element)); + std::memset(output + size, 0, (CAPACITY - size) * sizeof(Goldilocks::Element)); + std::memcpy(output + CAPACITY, input + size, size * sizeof(Goldilocks::Element)); + std::memset(output + CAPACITY + size, 0, (CAPACITY - size) * sizeof(Goldilocks::Element)); + return; // no need to hash + } + while (remaining) + { + if (remaining == size) + { + memset(state + 2 * RATE, 0, 2 * CAPACITY * sizeof(Goldilocks::Element)); + } + else + { + std::memcpy(state + 2 * RATE, state, 2 * CAPACITY * sizeof(Goldilocks::Element)); + } + + uint64_t n = (remaining < RATE) ? remaining : RATE; + memset(state, 0, 2 * RATE * sizeof(Goldilocks::Element)); + + if (n <= 4) + { + std::memcpy(state, input + (size - remaining), n * sizeof(Goldilocks::Element)); + std::memcpy(state + 4, input + size + (size - remaining), n * sizeof(Goldilocks::Element)); + } + else + { + std::memcpy(state, input + (size - remaining), 4 * sizeof(Goldilocks::Element)); + std::memcpy(state + 4, input + size + (size - remaining), 4 * sizeof(Goldilocks::Element)); + std::memcpy(state + 8, input + (size - remaining) + 4, (n - 4) * sizeof(Goldilocks::Element)); + std::memcpy(state + 12, input + size + (size - remaining) + 4, (n - 4) * sizeof(Goldilocks::Element)); + } + + hash_full_result_avx512(state, state); + remaining -= n; + } + if (size > 0) + { + std::memcpy(output, state, 2 * CAPACITY * sizeof(Goldilocks::Element)); + } + else + { + memset(output, 0, 2 * CAPACITY * sizeof(Goldilocks::Element)); + } +} +void Poseidon2Goldilocks::merkletree_avx512(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, int nThreads, uint64_t dim) +{ + if (num_rows == 0) + { + return; + } + Goldilocks::Element *cursor = tree; + // memset(cursor, 0, num_rows * CAPACITY * sizeof(Goldilocks::Element)); + if (nThreads == 0) + nThreads = omp_get_max_threads(); + +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < num_rows; i += 2) + { + linear_hash_avx512(&cursor[i * CAPACITY], &input[i * num_cols * dim], num_cols * dim); + } + + // Build the merkle tree + uint64_t pending = num_rows; + uint64_t nextN = floor((pending - 1) / 2) + 1; + uint64_t nextIndex = 0; + + while (pending > 1) + { +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < nextN; i++) + { + Goldilocks::Element pol_input[SPONGE_WIDTH]; + memset(pol_input, 0, SPONGE_WIDTH * sizeof(Goldilocks::Element)); + std::memcpy(pol_input, &cursor[nextIndex + i * RATE], RATE * sizeof(Goldilocks::Element)); + hash((Goldilocks::Element(&)[CAPACITY])cursor[nextIndex + (pending + i) * CAPACITY], pol_input); + } + nextIndex += pending * CAPACITY; + pending = pending / 2; + nextN = floor((pending - 1) / 2) + 1; + } +} +void Poseidon2Goldilocks::merkletree_batch_avx512(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, uint64_t batch_size, int nThreads, uint64_t dim) +{ + if (num_rows == 0) + { + return; + } + Goldilocks::Element *cursor = tree; + uint64_t nbatches = 1; + if (num_cols > 0) + { + nbatches = (num_cols + batch_size - 1) / batch_size; + } + uint64_t nlastb = num_cols - (nbatches - 1) * batch_size; + + if (nThreads == 0) + nThreads = omp_get_max_threads(); + +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < num_rows; i += 2) + { + Goldilocks::Element buff0[2 * nbatches * CAPACITY]; + for (uint64_t j = 0; j < nbatches; ++j) + { + uint64_t nn = batch_size; + if (j == nbatches - 1) + nn = nlastb; + Goldilocks::Element buff1[2 * nn * dim]; + Goldilocks::Element buff2[2 * CAPACITY]; + std::memcpy(&buff1[0], &input[i * num_cols * dim + j * batch_size * dim], dim * nn * sizeof(Goldilocks::Element)); + std::memcpy(&buff1[nn * dim], &input[(i + 1) * num_cols * dim + j * batch_size * dim], dim * nn * sizeof(Goldilocks::Element)); + linear_hash_avx512(buff2, buff1, nn * dim); + memcpy(&buff0[j * CAPACITY], buff2, CAPACITY * sizeof(Goldilocks::Element)); + memcpy(&buff0[(j + nbatches) * CAPACITY], &buff2[CAPACITY], CAPACITY * sizeof(Goldilocks::Element)); + } + linear_hash_avx512(&cursor[i * CAPACITY], buff0, nbatches * CAPACITY); + } + + // Build the merkle tree + uint64_t pending = num_rows; + uint64_t nextN = floor((pending - 1) / 2) + 1; + uint64_t nextIndex = 0; + + while (pending > 1) + { +#pragma omp parallel for num_threads(nThreads) + for (uint64_t i = 0; i < nextN; i++) + { + Goldilocks::Element pol_input[SPONGE_WIDTH]; + memset(pol_input, 0, SPONGE_WIDTH * sizeof(Goldilocks::Element)); + std::memcpy(pol_input, &cursor[nextIndex + i * RATE], RATE * sizeof(Goldilocks::Element)); + hash((Goldilocks::Element(&)[CAPACITY])cursor[nextIndex + (pending + i) * CAPACITY], pol_input); + } + nextIndex += pending * CAPACITY; + pending = pending / 2; + nextN = floor((pending - 1) / 2) + 1; + } +} +#endif */ \ No newline at end of file diff --git a/pil2-stark/src/goldilocks/src/poseidon2_goldilocks.hpp b/pil2-stark/src/goldilocks/src/poseidon2_goldilocks.hpp new file mode 100644 index 000000000..98ef9baf5 --- /dev/null +++ b/pil2-stark/src/goldilocks/src/poseidon2_goldilocks.hpp @@ -0,0 +1,199 @@ +#ifndef POSEIDON2_GOLDILOCKS +#define POSEIDON2_GOLDILOCKS + +#include "poseidon2_goldilocks_constants.hpp" +#include "goldilocks_base_field.hpp" +#include + +#define RATE 8 +#define CAPACITY 4 +#define HASH_SIZE 4 +#define SPONGE_WIDTH (RATE + CAPACITY) +#define HALF_N_FULL_ROUNDS 4 +#define N_FULL_ROUNDS_TOTAL (2 * HALF_N_FULL_ROUNDS) +#define N_PARTIAL_ROUNDS 22 +#define N_ROUNDS (N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS) + +class Poseidon2Goldilocks +{ +private: + inline void static pow7(Goldilocks::Element &x); + inline void static pow7_(Goldilocks::Element *x); + inline void static add_(Goldilocks::Element &x, const Goldilocks::Element *st); + inline void static pow7add_(Goldilocks::Element *x, const Goldilocks::Element C[SPONGE_WIDTH]); + inline void static prodadd_(Goldilocks::Element *x, const Goldilocks::Element D[SPONGE_WIDTH], const Goldilocks::Element &sum); + inline void static matmul_m4_(Goldilocks::Element *x); + inline void static matmul_external_(Goldilocks::Element *x); + + inline void static add_avx(__m256i &st0, __m256i &st1, __m256i &st2, const Goldilocks::Element C[SPONGE_WIDTH]); + inline void static pow7_avx(__m256i &st0, __m256i &st1, __m256i &st2); + inline void static add_avx_a(__m256i &st0, __m256i &st1, __m256i &st2, const Goldilocks::Element C[SPONGE_WIDTH]); + inline void static add_avx_small(__m256i &st0, __m256i &st1, __m256i &st2, const Goldilocks::Element C[SPONGE_WIDTH]); + inline void static matmul_external_avx(__m256i &st0, __m256i &st1, __m256i &st2); +#ifdef __AVX512__ + inline void static pow7_avx512(__m512i &st0, __m512i &st1, __m512i &st2); + inline void static add_avx512(__m512i &st0, __m512i &st1, __m512i &st2, const Goldilocks::Element C[SPONGE_WIDTH]); + inline void static add_avx512_a(__m512i &st0, __m512i &st1, __m512i &st2, const Goldilocks::Element C[SPONGE_WIDTH]); + inline void static add_avx512_small(__m512i &st0, __m512i &st1, __m512i &st2, const Goldilocks::Element C[SPONGE_WIDTH]); +#endif + +public: + // Wrapper: + void static merkletree(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, int nThreads = 0, uint64_t dim = 1); + void static merkletree_batch(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, uint64_t batch_size, int nThreads = 0, uint64_t dim = 1); + + // Non-vectorized: + void static hash_full_result_seq(Goldilocks::Element *, const Goldilocks::Element *); + void static linear_hash_seq(Goldilocks::Element *output, Goldilocks::Element *input, uint64_t size); + void static merkletree_seq(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, int nThreads = 0, uint64_t dim = 1); + void static hash_seq(Goldilocks::Element (&state)[CAPACITY], const Goldilocks::Element (&input)[SPONGE_WIDTH]); + void static merkletree_batch_seq(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, uint64_t batch_size, int nThreads = 0, uint64_t dim = 1); + + // Vectorized AVX: + // Note, the functions that do not have the _avx suffix are the default ones to + // be used in the prover, they implement avx vectorixation though. + void static hash_full_result(Goldilocks::Element *, const Goldilocks::Element *); + void static hash(Goldilocks::Element (&state)[CAPACITY], const Goldilocks::Element (&input)[SPONGE_WIDTH]); + void static linear_hash(Goldilocks::Element *output, Goldilocks::Element *input, uint64_t size); + void static merkletree_avx(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, int nThreads = 0, uint64_t dim = 1); + void static merkletree_batch_avx(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, uint64_t batch_size, int nThreads = 0, uint64_t dim = 1); + +#ifdef __AVX512__ + // Vectorized AVX512: + void static hash_full_result_avx512(Goldilocks::Element *, const Goldilocks::Element *); + void static hash_avx512(Goldilocks::Element (&state)[2 * CAPACITY], const Goldilocks::Element (&input)[2 * SPONGE_WIDTH]); + void static linear_hash_avx512(Goldilocks::Element *output, Goldilocks::Element *input, uint64_t size); + void static merkletree_avx512(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, int nThreads = 0, uint64_t dim = 1); + void static merkletree_batch_avx512(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, uint64_t batch_size, int nThreads = 0, uint64_t dim = 1); +#endif + +#ifdef __USE_CUDA__ + void static merkletree_cuda(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, int nThreads = 0, uint64_t dim = 1); + void static merkletree_cuda_gpudata(Goldilocks::Element *tree, uint64_t *gpu_input, uint64_t num_cols, uint64_t num_rows, int nThreads = 0, uint64_t dim = 1); + void static partial_hash_init_gpu(uint64_t **state, uint32_t num_rows, uint32_t ngpus); + void static partial_hash_gpu(uint64_t *input, uint32_t num_cols, uint32_t num_rows, uint64_t *state); + void static merkletree_cuda_multi_gpu_full(Goldilocks::Element *tree, uint64_t** gpu_inputs, uint64_t** gpu_trees, void* gpu_streams, uint64_t num_cols, uint64_t num_rows, uint64_t num_rows_device, uint32_t const ngpu, uint64_t dim = 1); + void static merkletree_cuda_multi_gpu_steps(uint64_t** gpu_inputs, uint64_t** gpu_trees, void* v_gpu_streams, uint64_t num_cols, uint64_t num_rows_device, uint32_t const ngpu, uint64_t dim = 1); + void static merkletree_cuda_multi_gpu_final(Goldilocks::Element *tree, uint64_t* final_tree, void* v_gpu_streams, uint64_t num_rows); + + void static merkletree_cuda_async(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows); + +#endif +}; + +// WRAPPERS + +inline void Poseidon2Goldilocks::merkletree(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, int nThreads, uint64_t dim) +{ +#ifdef __AVX512__ + merkletree_avx512(tree, input, num_cols, num_rows, nThreads, dim); +#else + merkletree_avx(tree, input, num_cols, num_rows, nThreads, dim); +#endif +} +inline void Poseidon2Goldilocks::merkletree_batch(Goldilocks::Element *tree, Goldilocks::Element *input, uint64_t num_cols, uint64_t num_rows, uint64_t batch_size, int nThreads, uint64_t dim) +{ +#ifdef __AVX512__ + merkletree_batch_avx512(tree, input, num_cols, num_rows, batch_size, nThreads, dim); +#else + merkletree_batch_avx(tree, input, num_cols, num_rows, batch_size, nThreads, dim); +#endif +} + +inline void Poseidon2Goldilocks::pow7(Goldilocks::Element &x) +{ + Goldilocks::Element x2 = x * x; + Goldilocks::Element x3 = x * x2; + Goldilocks::Element x4 = x2 * x2; + x = x3 * x4; +}; +inline void Poseidon2Goldilocks::pow7_(Goldilocks::Element *x) +{ + Goldilocks::Element x2[SPONGE_WIDTH], x3[SPONGE_WIDTH], x4[SPONGE_WIDTH]; + for (int i = 0; i < SPONGE_WIDTH; ++i) + { + x2[i] = x[i] * x[i]; + x3[i] = x[i] * x2[i]; + x4[i] = x2[i] * x2[i]; + x[i] = x3[i] * x4[i]; + } +}; + +inline void Poseidon2Goldilocks::add_(Goldilocks::Element &x, const Goldilocks::Element *st) +{ + for (int i = 0; i < SPONGE_WIDTH; ++i) + { + x = x + st[i]; + } +} +inline void Poseidon2Goldilocks::prodadd_(Goldilocks::Element *x, const Goldilocks::Element D[SPONGE_WIDTH], const Goldilocks::Element &sum) +{ + for (int i = 0; i < SPONGE_WIDTH; ++i) + { + x[i] = x[i]*D[i] + sum; + } +} + +inline void Poseidon2Goldilocks::pow7add_(Goldilocks::Element *x, const Goldilocks::Element C[SPONGE_WIDTH]) +{ + Goldilocks::Element x2[SPONGE_WIDTH], x3[SPONGE_WIDTH], x4[SPONGE_WIDTH]; + + for (int i = 0; i < SPONGE_WIDTH; ++i) + { + Goldilocks::Element xi = x[i] + C[i]; + x2[i] = xi * xi; + x3[i] = xi * x2[i]; + x4[i] = x2[i] * x2[i]; + x[i] = x3[i] * x4[i]; + } +}; + +inline void Poseidon2Goldilocks::matmul_m4_(Goldilocks::Element *x) { + Goldilocks::Element t0 = x[0] + x[1]; + Goldilocks::Element t1 = x[2] + x[3]; + Goldilocks::Element t2 = x[1] + x[1] + t1; + Goldilocks::Element t3 = x[3] + x[3] + t0; + Goldilocks::Element t1_2 = t1 + t1; + Goldilocks::Element t0_2 = t0 + t0; + Goldilocks::Element t4 = t1_2 + t1_2 + t3; + Goldilocks::Element t5 = t0_2 + t0_2 + t2; + Goldilocks::Element t6 = t3 + t5; + Goldilocks::Element t7 = t2 + t4; + + x[0] = t6; + x[1] = t5; + x[2] = t7; + x[3] = t4; +} + +inline void Poseidon2Goldilocks::matmul_external_(Goldilocks::Element *x) { + matmul_m4_(&x[0]); + matmul_m4_(&x[4]); + matmul_m4_(&x[8]); + + Goldilocks::Element stored[4] = { + x[0] + x[4] + x[8], + x[1] + x[5] + x[9], + x[2] + x[6] + x[10], + x[3] + x[7] + x[11], + }; + + for (int i = 0; i < SPONGE_WIDTH; ++i) + { + x[i] = x[i] + stored[i % 4]; + } +} + +inline void Poseidon2Goldilocks::hash_seq(Goldilocks::Element (&state)[CAPACITY], Goldilocks::Element const (&input)[SPONGE_WIDTH]) +{ + Goldilocks::Element aux[SPONGE_WIDTH]; + hash_full_result_seq(aux, input); + std::memcpy(state, aux, CAPACITY * sizeof(Goldilocks::Element)); +} + +#include "poseidon2_goldilocks_avx.hpp" + +#ifdef __AVX512__ +#include "poseidon2_goldilocks_avx512.hpp" +#endif +#endif \ No newline at end of file diff --git a/pil2-stark/src/goldilocks/src/poseidon2_goldilocks_avx.hpp b/pil2-stark/src/goldilocks/src/poseidon2_goldilocks_avx.hpp new file mode 100644 index 000000000..909dec937 --- /dev/null +++ b/pil2-stark/src/goldilocks/src/poseidon2_goldilocks_avx.hpp @@ -0,0 +1,118 @@ +#ifndef POSEIDON2_GOLDILOCKS_AVX +#define POSEIDON2_GOLDILOCKS_AVX + +#include "poseidon2_goldilocks.hpp" +#include "goldilocks_base_field.hpp" +#include + +const __m256i zero = _mm256_setzero_si256(); + +inline void Poseidon2Goldilocks::hash(Goldilocks::Element (&state)[CAPACITY], Goldilocks::Element const (&input)[SPONGE_WIDTH]) +{ + Goldilocks::Element aux[SPONGE_WIDTH]; + hash_full_result(aux, input); + std::memcpy(state, aux, CAPACITY * sizeof(Goldilocks::Element)); +} + +inline void Poseidon2Goldilocks::matmul_external_avx(__m256i &st0, __m256i &st1, __m256i &st2) +{ + + __m256i t0_ = _mm256_permute2f128_si256(st0, st2, 0b00100000); + __m256i t1_ = _mm256_permute2f128_si256(st1, zero, 0b00100000); + __m256i t2_ = _mm256_permute2f128_si256(st0, st2, 0b00110001); + __m256i t3_ = _mm256_permute2f128_si256(st1, zero, 0b00110001); + __m256i c0 = _mm256_castpd_si256(_mm256_unpacklo_pd(_mm256_castsi256_pd(t0_), _mm256_castsi256_pd(t1_))); + __m256i c1 = _mm256_castpd_si256(_mm256_unpackhi_pd(_mm256_castsi256_pd(t0_), _mm256_castsi256_pd(t1_))); + __m256i c2 = _mm256_castpd_si256(_mm256_unpacklo_pd(_mm256_castsi256_pd(t2_), _mm256_castsi256_pd(t3_))); + __m256i c3 = _mm256_castpd_si256(_mm256_unpackhi_pd(_mm256_castsi256_pd(t2_), _mm256_castsi256_pd(t3_))); + + __m256i t0, t0_2, t1, t1_2, t2, t3, t4, t5, t6, t7; + Goldilocks::add_avx(t0, c0, c1); + Goldilocks::add_avx(t1, c2, c3); + Goldilocks::add_avx(t2, c1, c1); + Goldilocks::add_avx(t2, t2, t1); + Goldilocks::add_avx(t3, c3, c3); + Goldilocks::add_avx(t3, t3, t0); + Goldilocks::add_avx(t1_2, t1, t1); + Goldilocks::add_avx(t0_2, t0, t0); + Goldilocks::add_avx(t4, t1_2, t1_2); + Goldilocks::add_avx(t4, t4, t3); + Goldilocks::add_avx(t5, t0_2, t0_2); + Goldilocks::add_avx(t5, t5, t2); + Goldilocks::add_avx(t6, t3, t5); + Goldilocks::add_avx(t7, t2, t4); + + // Step 1: Reverse unpacking + t0_ = _mm256_castpd_si256(_mm256_unpacklo_pd(_mm256_castsi256_pd(t6), _mm256_castsi256_pd(t5))); + t1_ = _mm256_castpd_si256(_mm256_unpackhi_pd(_mm256_castsi256_pd(t6), _mm256_castsi256_pd(t5))); + t2_ = _mm256_castpd_si256(_mm256_unpacklo_pd(_mm256_castsi256_pd(t7), _mm256_castsi256_pd(t4))); + t3_ = _mm256_castpd_si256(_mm256_unpackhi_pd(_mm256_castsi256_pd(t7), _mm256_castsi256_pd(t4))); + + // Step 2: Reverse _mm256_permute2f128_si256 + st0 = _mm256_permute2f128_si256(t0_, t2_, 0b00100000); // Combine low halves + st2 = _mm256_permute2f128_si256(t0_, t2_, 0b00110001); // Combine high halves + st1 = _mm256_permute2f128_si256(t1_, t3_, 0b00100000); // Combine low halves + + __m256i stored; + Goldilocks::add_avx(stored, st0, st1); + Goldilocks::add_avx(stored, stored, st2); + + Goldilocks::add_avx(st0, st0, stored); + Goldilocks::add_avx(st1, st1, stored); + Goldilocks::add_avx(st2, st2, stored); +}; + +inline void Poseidon2Goldilocks::pow7_avx(__m256i &st0, __m256i &st1, __m256i &st2) +{ + __m256i pw2_0, pw2_1, pw2_2; + Goldilocks::square_avx(pw2_0, st0); + Goldilocks::square_avx(pw2_1, st1); + Goldilocks::square_avx(pw2_2, st2); + __m256i pw4_0, pw4_1, pw4_2; + Goldilocks::square_avx(pw4_0, pw2_0); + Goldilocks::square_avx(pw4_1, pw2_1); + Goldilocks::square_avx(pw4_2, pw2_2); + __m256i pw3_0, pw3_1, pw3_2; + Goldilocks::mult_avx(pw3_0, pw2_0, st0); + Goldilocks::mult_avx(pw3_1, pw2_1, st1); + Goldilocks::mult_avx(pw3_2, pw2_2, st2); + + Goldilocks::mult_avx(st0, pw3_0, pw4_0); + Goldilocks::mult_avx(st1, pw3_1, pw4_1); + Goldilocks::mult_avx(st2, pw3_2, pw4_2); +}; + +inline void Poseidon2Goldilocks::add_avx(__m256i &st0, __m256i &st1, __m256i &st2, const Goldilocks::Element C_[SPONGE_WIDTH]) +{ + __m256i c0, c1, c2; + Goldilocks::load_avx(c0, &(C_[0])); + Goldilocks::load_avx(c1, &(C_[4])); + Goldilocks::load_avx(c2, &(C_[8])); + Goldilocks::add_avx(st0, st0, c0); + Goldilocks::add_avx(st1, st1, c1); + Goldilocks::add_avx(st2, st2, c2); +} +// Assuming C_a is aligned +inline void Poseidon2Goldilocks::add_avx_a(__m256i &st0, __m256i &st1, __m256i &st2, const Goldilocks::Element C_a[SPONGE_WIDTH]) +{ + __m256i c0, c1, c2; + Goldilocks::load_avx_a(c0, &(C_a[0])); + Goldilocks::load_avx_a(c1, &(C_a[4])); + Goldilocks::load_avx_a(c2, &(C_a[8])); + Goldilocks::add_avx(st0, st0, c0); + Goldilocks::add_avx(st1, st1, c1); + Goldilocks::add_avx(st2, st2, c2); +} +inline void Poseidon2Goldilocks::add_avx_small(__m256i &st0, __m256i &st1, __m256i &st2, const Goldilocks::Element C_small[SPONGE_WIDTH]) +{ + __m256i c0, c1, c2; + Goldilocks::load_avx(c0, &(C_small[0])); + Goldilocks::load_avx(c1, &(C_small[4])); + Goldilocks::load_avx(c2, &(C_small[8])); + + Goldilocks::add_avx_b_small(st0, st0, c0); + Goldilocks::add_avx_b_small(st1, st1, c1); + Goldilocks::add_avx_b_small(st2, st2, c2); +} + +#endif \ No newline at end of file diff --git a/pil2-stark/src/goldilocks/src/poseidon2_goldilocks_avx512.hpp b/pil2-stark/src/goldilocks/src/poseidon2_goldilocks_avx512.hpp new file mode 100644 index 000000000..ef3c9e252 --- /dev/null +++ b/pil2-stark/src/goldilocks/src/poseidon2_goldilocks_avx512.hpp @@ -0,0 +1,107 @@ +#ifndef POSEIDON2_GOLDILOCKS_AVX512 +#define POSEIDON2_GOLDILOCKS_AVX512 +#ifdef __AVX512__ +#include "poseidon2_goldilocks.hpp" +#include "goldilocks_base_field.hpp" +#include + +inline void Poseidon2Goldilocks::hash_avx512(Goldilocks::Element (&state)[2 * CAPACITY], Goldilocks::Element const (&input)[2 * SPONGE_WIDTH]) +{ + Goldilocks::Element aux[2 * SPONGE_WIDTH]; + hash_full_result_avx512(aux, input); + std::memcpy(state, aux, 2 * CAPACITY * sizeof(Goldilocks::Element)); +} + +inline void Poseidon2Goldilocks::matmul_external_avx512(__m512i &st0, __m512i &st1, __m512i &st2) +{ + __m512i indx1 = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + __m512i indx2 = _mm512_set_epi64(15, 14, 7, 6, 11, 10, 3, 2); + + __m512i t0 = _mm512_permutex2var_epi64(st0, indx1, st2); + __m512i t1 = _mm512_permutex2var_epi64(st1, indx1, zero); + __m512i t2 = _mm512_permutex2var_epi64(st0, indx2, st2); + __m512i t3 = _mm512_permutex2var_epi64(st1, indx2, zero); + + __m512i c0 = _mm512_castpd_si512(_mm512_unpacklo_pd(_mm512_castsi512_pd(t0), _mm512_castsi512_pd(t1))); + __m512i c1 = _mm512_castpd_si512(_mm512_unpackhi_pd(_mm512_castsi512_pd(t0), _mm512_castsi512_pd(t1))); + __m512i c2 = _mm512_castpd_si512(_mm512_unpacklo_pd(_mm512_castsi512_pd(t2), _mm512_castsi512_pd(t3))); + __m512i c3 = _mm512_castpd_si512(_mm512_unpackhi_pd(_mm512_castsi512_pd(t2), _mm512_castsi512_pd(t3))); + + __m512i t0, t0_2, t1, t1_2, t2, t3, t4, t5, t6, t7; + Goldilocks::add_avx512(t0, c0, c1); + Goldilocks::add_avx512(t1, c2, c3); + Goldilocks::add_avx512(t2, c1, c1); + Goldilocks::add_avx512(t2, t2, t1); + Goldilocks::add_avx512(t3, c3, c3); + Goldilocks::add_avx512(t3, t3, t0); + Goldilocks::add_avx512(t1_2, t1, t1); + Goldilocks::add_avx512(t0_2, t0, t0); + Goldilocks::add_avx512(t4, t1_2, t1_2); + Goldilocks::add_avx512(t4, t4, t3); + Goldilocks::add_avx512(t5, t0_2, t0_2); + Goldilocks::add_avx512(t5, t5, t2); + Goldilocks::add_avx512(t6, t3, t5); + Goldilocks::add_avx512(t7, t2, t4); + + // Step 1: Reverse unpacking + t0_ = _mm512_castpd_si512(_mm512_unpacklo_pd(_mm512_castsi512_pd(t6), _mm512_castsi512_pd(t5))); + t1_ = _mm512_castpd_si512(_mm512_unpackhi_pd(_mm512_castsi512_pd(t6), _mm512_castsi512_pd(t5))); + t2_ = _mm512_castpd_si512(_mm512_unpacklo_pd(_mm512_castsi512_pd(t7), _mm512_castsi512_pd(t4))); + t3_ = _mm512_castpd_si512(_mm512_unpackhi_pd(_mm512_castsi512_pd(t7), _mm512_castsi512_pd(t4))); + + // Step 2: Reverse _mm512_permutex2var_epi64 + + + __m512i stored; + Goldilocks::add_avx512(stored, st0, st1); + Goldilocks::add_avx512(stored, stored, st2); + + Goldilocks::add_avx512(st0, st0, stored); + Goldilocks::add_avx512(st1, st1, stored); + Goldilocks::add_avx512(st2, st2, stored); +}; + + + +inline void Poseidon2Goldilocks::pow7_avx512(__m512i &st0, __m512i &st1, __m512i &st2) +{ + __m512i pw2_0, pw2_1, pw2_2; + Goldilocks::square_avx512(pw2_0, st0); + Goldilocks::square_avx512(pw2_1, st1); + Goldilocks::square_avx512(pw2_2, st2); + __m512i pw4_0, pw4_1, pw4_2; + Goldilocks::square_avx512(pw4_0, pw2_0); + Goldilocks::square_avx512(pw4_1, pw2_1); + Goldilocks::square_avx512(pw4_2, pw2_2); + __m512i pw3_0, pw3_1, pw3_2; + Goldilocks::mult_avx512(pw3_0, pw2_0, st0); + Goldilocks::mult_avx512(pw3_1, pw2_1, st1); + Goldilocks::mult_avx512(pw3_2, pw2_2, st2); + + Goldilocks::mult_avx512(st0, pw3_0, pw4_0); + Goldilocks::mult_avx512(st1, pw3_1, pw4_1); + Goldilocks::mult_avx512(st2, pw3_2, pw4_2); +}; + +inline void Poseidon2Goldilocks::add_avx512(__m512i &st0, __m512i &st1, __m512i &st2, const Goldilocks::Element C_[SPONGE_WIDTH]) +{ + __m512i c0 = _mm512_set4_epi64(C_[3].fe, C_[2].fe, C_[1].fe, C_[0].fe); + __m512i c1 = _mm512_set4_epi64(C_[7].fe, C_[6].fe, C_[5].fe, C_[4].fe); + __m512i c2 = _mm512_set4_epi64(C_[11].fe, C_[10].fe, C_[9].fe, C_[8].fe); + Goldilocks::add_avx512(st0, st0, c0); + Goldilocks::add_avx512(st1, st1, c1); + Goldilocks::add_avx512(st2, st2, c2); +} + +inline void Poseidon2Goldilocks::add_avx512_small(__m512i &st0, __m512i &st1, __m512i &st2, const Goldilocks::Element C_small[SPONGE_WIDTH]) +{ + __m512i c0 = _mm512_set4_epi64(C_small[3].fe, C_small[2].fe, C_small[1].fe, C_small[0].fe); + __m512i c1 = _mm512_set4_epi64(C_small[7].fe, C_small[6].fe, C_small[5].fe, C_small[4].fe); + __m512i c2 = _mm512_set4_epi64(C_small[11].fe, C_small[10].fe, C_small[9].fe, C_small[8].fe); + + Goldilocks::add_avx512_b_c(st0, st0, c0); + Goldilocks::add_avx512_b_c(st1, st1, c1); + Goldilocks::add_avx512_b_c(st2, st2, c2); +} +#endif +#endif \ No newline at end of file diff --git a/pil2-stark/src/goldilocks/src/poseidon2_goldilocks_constants.hpp b/pil2-stark/src/goldilocks/src/poseidon2_goldilocks_constants.hpp new file mode 100644 index 000000000..b5dfb7053 --- /dev/null +++ b/pil2-stark/src/goldilocks/src/poseidon2_goldilocks_constants.hpp @@ -0,0 +1,145 @@ +#ifndef POSEIDON2_GOLDILOCKS_CONSTANTS +#define POSEIDON2_GOLDILOCKS_CONSTANTS +#endif // POSEIDON2_GOLDILOCKS_CONSTANTS +#include "goldilocks_base_field.hpp" + +namespace Poseidon2GoldilocksConstants +{ + + inline constexpr static Goldilocks::Element C[118] = { + {0x13dcf33aba214f46}, + {0x30b3b654a1da6d83}, + {0x1fc634ada6159b56}, + {0x937459964dc03466}, + {0xedd2ef2ca7949924}, + {0xede9affde0e22f68}, + {0x8515b9d6bac9282d}, + {0x6b5c07b4e9e900d8}, + {0x1ec66368838c8a08}, + {0x9042367d80d1fbab}, + {0x400283564a3c3799}, + {0x4a00be0466bca75e}, + {0x7913beee58e3817f}, + {0xf545e88532237d90}, + {0x22f8cb8736042005}, + {0x6f04990e247a2623}, + {0xfe22e87ba37c38cd}, + {0xd20e32c85ffe2815}, + {0x117227674048fe73}, + {0x4e9fb7ea98a6b145}, + {0xe0866c232b8af08b}, + {0x00bbc77916884964}, + {0x7031c0fb990d7116}, + {0x240a9e87cf35108f}, + {0x2e6363a5a12244b3}, + {0x5e1c3787d1b5011c}, + {0x4132660e2a196e8b}, + {0x3a013b648d3d4327}, + {0xf79839f49888ea43}, + {0xfe85658ebafe1439}, + {0xb6889825a14240bd}, + {0x578453605541382b}, + {0x4508cda8f6b63ce9}, + {0x9c3ef35848684c91}, + {0x0812bde23c87178c}, + {0xfe49638f7f722c14}, + {0x8e3f688ce885cbf5}, + {0xb8e110acf746a87d}, + {0xb4b2e8973a6dabef}, + {0x9e714c5da3d462ec}, + {0x6438f9033d3d0c15}, + {0x24312f7cf1a27199}, + {0x23f843bb47acbf71}, + {0x9183f11a34be9f01}, + {0x839062fbb9d45dbf}, + {0x24b56e7e6c2e43fa}, + {0xe1683da61c962a72}, + {0xa95c63971a19bfa7}, + {0x4adf842aa75d4316}, + {0xf8fbb871aa4ab4eb}, + {0x68e85b6eb2dd6aeb}, + {0x07a0b06b2d270380}, + {0xd94e0228bd282de4}, + {0x8bdd91d3250c5278}, + {0x209c68b88bba778f}, + {0xb5e18cdab77f3877}, + {0xb296a3e808da93fa}, + {0x8370ecbda11a327e}, + {0x3f9075283775dad8}, + {0xb78095bb23c6aa84}, + {0x3f36b9fe72ad4e5f}, + {0x69bc96780b10b553}, + {0x3f1d341f2eb7b881}, + {0x4e939e9815838818}, + {0xda366b3ae2a31604}, + {0xbc89db1e7287d509}, + {0x6102f411f9ef5659}, + {0x58725c5e7ac1f0ab}, + {0x0df5856c798883e7}, + {0xf7bb62a8da4c961b}, + {0xc68be7c94882a24d}, + {0xaf996d5d5cdaedd9}, + {0x9717f025e7daf6a5}, + {0x6436679e6e7216f4}, + {0x8a223d99047af267}, + {0xbb512e35a133ba9a}, + {0xfbbf44097671aa03}, + {0xf04058ebf6811e61}, + {0x5cca84703fac7ffb}, + {0x9b55c7945de6469f}, + {0x8e05bf09808e934f}, + {0x2ea900de876307d7}, + {0x7748fff2b38dfb89}, + {0x6b99a676dd3b5d81}, + {0xac4bb7c627cf7c13}, + {0xadb6ebe5e9e2f5ba}, + {0x2d33378cafa24ae3}, + {0x1e5b73807543f8c2}, + {0x09208814bfebb10f}, + {0x782e64b6bb5b93dd}, + {0xadd5a48eac90b50f}, + {0xadd4c54c736ea4b1}, + {0xd58dbb86ed817fd8}, + {0x6d5ed1a533f34ddd}, + {0x28686aa3e36b7cb9}, + {0x591abd3476689f36}, + {0x047d766678f13875}, + {0xa2a11112625f5b49}, + {0x21fd10a3f8304958}, + {0xf9b40711443b0280}, + {0xd2697eb8b2bde88e}, + {0x3493790b51731b3f}, + {0x11caf9dd73764023}, + {0x7acfb8f72878164e}, + {0x744ec4db23cefc26}, + {0x1e00e58f422c6340}, + {0x21dd28d906a62dda}, + {0xf32a46ab5f465b5f}, + {0xbfce13201f3f7e6b}, + {0xf30d2e7adb5304e2}, + {0xecdf4ee4abad48e9}, + {0xf94e82182d395019}, + {0x4ee52e3744d887c5}, + {0xa1341c7cac0083b2}, + {0x2302fb26c30c834a}, + {0xaea3c587273bf7d3}, + {0xf798e24961823ec7}, + {0x962deba3e9a2cd94} + }; + + inline constexpr static Goldilocks::Element D[12] = { + {0xc3b6c08e23ba9300}, + {0xd84b5de94a324fb6}, + {0x0d0c371c5b35b84f}, + {0x7964f570e7188037}, + {0x5daf18bbd996604b}, + {0x6743bc47b9595257}, + {0x5528b9362c59bb70}, + {0xac45e25b7127b68b}, + {0xa2077d7dfbb606b5}, + {0xf3faac6faee378ae}, + {0x0c6388b51545e883}, + {0xd27dbb6944917b60}, + }; + +} \ No newline at end of file diff --git a/pil2-stark/src/goldilocks/tests/tests.cpp b/pil2-stark/src/goldilocks/tests/tests.cpp index 17968a9a9..043e746f3 100644 --- a/pil2-stark/src/goldilocks/tests/tests.cpp +++ b/pil2-stark/src/goldilocks/tests/tests.cpp @@ -4,6 +4,7 @@ #include "../src/goldilocks_base_field.hpp" #include "../src/goldilocks_cubic_extension.hpp" #include "../src/poseidon_goldilocks.hpp" +#include "../src/poseidon2_goldilocks.hpp" #include "../src/ntt_goldilocks.hpp" #include "../src/merklehash_goldilocks.hpp" #include @@ -1486,7 +1487,43 @@ TEST(GOLDILOCKS_TEST, inv) ASSERT_EQ(Goldilocks::inv(inE1), Goldilocks::inv(inE1_plus_p)); } -TEST(GOLDILOCKS_TEST, poseidon_avx_seq) +TEST(GOLDILOCKS_TEST, poseidon2_seq) +{ + Goldilocks::Element x[SPONGE_WIDTH]; + Goldilocks::Element result[CAPACITY]; + + for (uint64_t i = 0; i < SPONGE_WIDTH; i++) + { + x[i] = Goldilocks::fromU64(i); + } + + Poseidon2Goldilocks::hash_seq(result, x); + + ASSERT_EQ(Goldilocks::toU64(result[0]), 0X1EAEF96BDF1C0C1 ); + ASSERT_EQ(Goldilocks::toU64(result[1]), 0X1F0D2CC525B2540C); + ASSERT_EQ(Goldilocks::toU64(result[2]), 0X6282C1DFE1E0358D); + ASSERT_EQ(Goldilocks::toU64(result[3]), 0XE780D721F698E1E6); +} + +TEST(GOLDILOCKS_TEST, poseidon2_avx) +{ + Goldilocks::Element x[SPONGE_WIDTH]; + Goldilocks::Element result[CAPACITY]; + + for (uint64_t i = 0; i < SPONGE_WIDTH; i++) + { + x[i] = Goldilocks::fromU64(i); + } + + Poseidon2Goldilocks::hash(result, x); + + ASSERT_EQ(Goldilocks::toU64(result[0]), 0X1EAEF96BDF1C0C1 ); + ASSERT_EQ(Goldilocks::toU64(result[1]), 0X1F0D2CC525B2540C); + ASSERT_EQ(Goldilocks::toU64(result[2]), 0X6282C1DFE1E0358D); + ASSERT_EQ(Goldilocks::toU64(result[3]), 0XE780D721F698E1E6); +} + +TEST(GOLDILOCKS_TEST, poseidon_seq) { Goldilocks::Element fibonacci[SPONGE_WIDTH]; diff --git a/pil2-stark/src/starkpil/const_pols.hpp b/pil2-stark/src/starkpil/const_pols.hpp index 9653c07c4..4e22d3a25 100644 --- a/pil2-stark/src/starkpil/const_pols.hpp +++ b/pil2-stark/src/starkpil/const_pols.hpp @@ -16,45 +16,21 @@ class ConstTree { public: ConstTree () {}; - uint64_t getNumNodes(StarkInfo& starkInfo) { - uint64_t merkleTreeArity = starkInfo.starkStruct.verificationHashType == std::string("BN128") ? starkInfo.starkStruct.merkleTreeArity : 2; - uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; - uint n_tmp = NExtended; - uint64_t nextN = floor(((double)(n_tmp - 1) / merkleTreeArity) + 1); - uint64_t acc = nextN * merkleTreeArity; - while (n_tmp > 1) - { - // FIll with zeros if n nodes in the leve is not even - n_tmp = nextN; - nextN = floor((n_tmp - 1) / merkleTreeArity) + 1; - if (n_tmp > 1) - { - acc += nextN * merkleTreeArity; - } - else - { - acc += 1; - } - } - - return acc; - } - - uint64_t getConstTreeSizeBytesBN128(StarkInfo& starkInfo) + uint64_t getConstTreeSizeBN128(StarkInfo& starkInfo) { uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; - uint64_t acc = getNumNodes(starkInfo); - return 16 + (NExtended * starkInfo.nConstants) * sizeof(Goldilocks::Element) + acc * sizeof(RawFr::Element); + MerkleTreeBN128 mt(starkInfo.starkStruct.merkleTreeArity, starkInfo.starkStruct.merkleTreeCustom, NExtended, starkInfo.nConstants); + return 2 + (NExtended * starkInfo.nConstants) + mt.numNodes * (sizeof(RawFr::Element) / sizeof(Goldilocks::Element)); } - uint64_t getConstTreeSizeBytesGL(StarkInfo& starkInfo) + uint64_t getConstTreeSizeGL(StarkInfo& starkInfo) { uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; - uint64_t acc = getNumNodes(starkInfo); - return (2 + (NExtended * starkInfo.nConstants) + acc * HASH_SIZE) * sizeof(Goldilocks::Element); + MerkleTreeGL mt(2, true, NExtended, starkInfo.nConstants); + return (2 + (NExtended * starkInfo.nConstants) + mt.numNodes); } - void calculateConstTreeGL(StarkInfo& starkInfo, Goldilocks::Element *pConstPolsAddress, void *treeAddress, std::string constTreeFile) { + void calculateConstTreeGL(StarkInfo& starkInfo, Goldilocks::Element *pConstPolsAddress, void *treeAddress) { uint64_t N = 1 << starkInfo.starkStruct.nBits; uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; NTT_Goldilocks ntt(N); @@ -68,15 +44,16 @@ class ConstTree { treeAddressGL[0] = Goldilocks::fromU64(starkInfo.nConstants); treeAddressGL[1] = Goldilocks::fromU64(NExtended); + } - if(constTreeFile != "") { - TimerStart(WRITING_TREE_FILE); - mt.writeFile(constTreeFile); - TimerStopAndLog(WRITING_TREE_FILE); - } + void writeConstTreeFileGL(StarkInfo& starkInfo, void *treeAddress, std::string constTreeFile) { + TimerStart(WRITING_TREE_FILE); + MerkleTreeGL mt(2, true, (Goldilocks::Element *)treeAddress); + mt.writeFile(constTreeFile); + TimerStopAndLog(WRITING_TREE_FILE); } - void calculateConstTreeBN128(StarkInfo& starkInfo, Goldilocks::Element *pConstPolsAddress, void *treeAddress, std::string constTreeFile) { + void calculateConstTreeBN128(StarkInfo& starkInfo, Goldilocks::Element *pConstPolsAddress, void *treeAddress) { uint64_t N = 1 << starkInfo.starkStruct.nBits; uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; NTT_Goldilocks ntt(N); @@ -89,16 +66,47 @@ class ConstTree { treeAddressGL[0] = Goldilocks::fromU64(starkInfo.nConstants); treeAddressGL[1] = Goldilocks::fromU64(NExtended); + } - if(constTreeFile != "") { - TimerStart(WRITING_TREE_FILE); - mt.writeFile(constTreeFile); - TimerStopAndLog(WRITING_TREE_FILE); - } + void writeConstTreeFileBN128(StarkInfo& starkInfo, void *treeAddress, std::string constTreeFile) { + TimerStart(WRITING_TREE_FILE); + MerkleTreeBN128 mt(starkInfo.starkStruct.merkleTreeArity, starkInfo.starkStruct.merkleTreeCustom, (Goldilocks::Element *)treeAddress); + mt.writeFile(constTreeFile); + TimerStopAndLog(WRITING_TREE_FILE); } - void loadConstTree(void *constTreePols, std::string constTreeFile, uint64_t constTreeSize) { - loadFileParallel(constTreePols, constTreeFile, constTreeSize); + bool loadConstTree(StarkInfo &starkInfo, void *constTreePols, std::string constTreeFile, uint64_t constTreeSize, std::string verkeyFile) { + bool fileLoaded = loadFileParallel(constTreePols, constTreeFile, constTreeSize, false); + if(!fileLoaded) { + return false; + } + + json verkeyJson; + file2json(verkeyFile, verkeyJson); + + if (starkInfo.starkStruct.verificationHashType == "BN128") { + MerkleTreeBN128 mt(starkInfo.starkStruct.merkleTreeArity, starkInfo.starkStruct.merkleTreeCustom, (Goldilocks::Element *)constTreePols); + RawFr::Element root[1]; + mt.getRoot(root); + if(RawFr::field.toString(root[0], 10) != verkeyJson) { + return false; + } + } else { + MerkleTreeGL mt(2, true, (Goldilocks::Element *)constTreePols); + Goldilocks::Element root[4]; + mt.getRoot(root); + + if (Goldilocks::toU64(root[0]) != verkeyJson[0] || + Goldilocks::toU64(root[1]) != verkeyJson[1] || + Goldilocks::toU64(root[2]) != verkeyJson[2] || + Goldilocks::toU64(root[3]) != verkeyJson[3]) + { + return false; + } + + } + + return true; } void loadConstPols(void *constPols, std::string constPolsFile, uint64_t constPolsSize) { diff --git a/pil2-stark/src/starkpil/fixed_cols.hpp b/pil2-stark/src/starkpil/fixed_cols.hpp new file mode 100644 index 000000000..e75263538 --- /dev/null +++ b/pil2-stark/src/starkpil/fixed_cols.hpp @@ -0,0 +1,50 @@ +#include +#include +#include "binfile_utils.hpp" +#include "binfile_writer.hpp" +#include "polinomial.hpp" +#include "goldilocks_base_field.hpp" +#include "goldilocks_base_field_avx.hpp" +#include "goldilocks_base_field_avx512.hpp" +#include "goldilocks_base_field_pack.hpp" +#include "goldilocks_cubic_extension.hpp" +#include "goldilocks_cubic_extension_pack.hpp" +#include "goldilocks_cubic_extension_avx.hpp" +#include "goldilocks_cubic_extension_avx512.hpp" +#include "stark_info.hpp" +#include +#include + +const int FIXED_POLS_SECTION = 1; +const int N_SECTIONS = 1; + +struct FixedPolsInfo { + uint64_t name_size; + uint8_t *name; + uint64_t n_lengths; + uint64_t *lengths; + Goldilocks::Element *values; +}; + +void writeFixedColsBin(string binFileName, string airgroupName, string airName, uint64_t N, uint64_t nFixedPols, FixedPolsInfo* fixedPolsInfo) { + BinFileUtils::BinFileWriter binFile(binFileName, "cnst", 1, N_SECTIONS); + + binFile.startWriteSection(FIXED_POLS_SECTION); + + binFile.writeString(airgroupName); + binFile.writeString(airName); + binFile.writeU64LE(N); + binFile.writeU32LE(nFixedPols); + for(uint64_t i = 0; i < nFixedPols; ++i) { + std::string name = std::string((char *)fixedPolsInfo[i].name, fixedPolsInfo[i].name_size); + binFile.writeString(name); + binFile.writeU32LE(fixedPolsInfo[i].n_lengths); + for(uint64_t j = 0; j < fixedPolsInfo[i].n_lengths; ++j) { + binFile.writeU32LE(fixedPolsInfo[i].lengths[j]); + } + + binFile.write((void *)fixedPolsInfo[i].values, N * sizeof(Goldilocks::Element)); + } + + binFile.endWriteSection(); +} diff --git a/pil2-stark/src/starkpil/gen_recursive_proof.hpp b/pil2-stark/src/starkpil/gen_recursive_proof.hpp index de7c66b89..fe3b59c64 100644 --- a/pil2-stark/src/starkpil/gen_recursive_proof.hpp +++ b/pil2-stark/src/starkpil/gen_recursive_proof.hpp @@ -1,10 +1,10 @@ #include "starks.hpp" template -void *genRecursiveProof(SetupCtx& setupCtx, json& globalInfo, uint64_t airgroupId, Goldilocks::Element *witness, Goldilocks::Element *aux_trace, Goldilocks::Element *pConstPols, Goldilocks::Element *pConstTree, Goldilocks::Element *publicInputs, std::string proofFile, bool vadcop) { +void *genRecursiveProof(SetupCtx& setupCtx, json& globalInfo, uint64_t airgroupId, uint64_t airId, uint64_t instanceId, Goldilocks::Element *witness, Goldilocks::Element *aux_trace, Goldilocks::Element *pConstPols, Goldilocks::Element *pConstTree, Goldilocks::Element *publicInputs, std::string proofFile, bool vadcop) { TimerStart(STARK_PROOF); - FRIProof proof(setupCtx.starkInfo, 0); + FRIProof proof(setupCtx.starkInfo, airgroupId, airId, instanceId); using TranscriptType = std::conditional_t::value, TranscriptGL, TranscriptBN128>; diff --git a/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.cpp b/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.cpp index 9aad00ffd..b11a71b12 100644 --- a/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.cpp +++ b/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.cpp @@ -3,12 +3,23 @@ #include // std::max #include -MerkleTreeBN128::MerkleTreeBN128(uint64_t _arity, bool _custom, uint64_t _height, uint64_t _width) : height(_height), width(_width) +MerkleTreeBN128::MerkleTreeBN128(uint64_t _arity, bool _custom, uint64_t _height, uint64_t _width, bool allocateSource, bool allocateNodes) : height(_height), width(_width) { - numNodes = getNumNodes(height); arity = _arity; custom = _custom; + numNodes = getNumNodes(height); + + if(allocateSource) { + source = (Goldilocks::Element *)calloc(height * width, sizeof(Goldilocks::Element)); + isSourceAllocated = true; + } + + if(allocateNodes) { + nodes = (RawFr::Element *)calloc(numNodes, sizeof(RawFr::Element)); + isNodesAllocated = true; + } + } MerkleTreeBN128::MerkleTreeBN128(uint64_t _arity, bool _custom, Goldilocks::Element *tree) @@ -19,10 +30,20 @@ MerkleTreeBN128::MerkleTreeBN128(uint64_t _arity, bool _custom, Goldilocks::Elem arity = _arity; custom = _custom; numNodes = getNumNodes(height); - nodes = (RawFr::Element *)&source[width * height]; } +MerkleTreeBN128::~MerkleTreeBN128() +{ + if(isSourceAllocated) { + free(source); + } + + if(isNodesAllocated) { + free(nodes); + } +} + uint64_t MerkleTreeBN128::getNumSiblings() { return arity * nFieldElements; @@ -76,11 +97,21 @@ void MerkleTreeBN128::getRoot(RawFr::Element *root) void MerkleTreeBN128::setSource(Goldilocks::Element *_source) { + if(isSourceAllocated) { + zklog.error("MerkleTreeBN128: Source was allocated when initializing"); + exitProcess(); + exit(-1); + } source = _source; } void MerkleTreeBN128::setNodes(RawFr::Element *_nodes) { + if(isNodesAllocated) { + zklog.error("MerkleTreeBN128: Nodes were allocated when initializing"); + exitProcess(); + exit(-1); + } nodes = _nodes; } diff --git a/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.hpp b/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.hpp index 11a957c3b..2840d8321 100644 --- a/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.hpp +++ b/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.hpp @@ -21,8 +21,8 @@ class MerkleTreeBN128 public: MerkleTreeBN128(){}; MerkleTreeBN128(uint64_t arity, bool custom, Goldilocks::Element *tree); - MerkleTreeBN128(uint64_t arity, bool custom, uint64_t _height, uint64_t _width); - ~MerkleTreeBN128(){}; + MerkleTreeBN128(uint64_t arity, bool custom, uint64_t _height, uint64_t _width, bool allocateSource = false, bool allocateNodes = false); + ~MerkleTreeBN128(); uint64_t numNodes; uint64_t height; diff --git a/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.cpp b/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.cpp index 33cab74f6..722ca65a7 100644 --- a/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.cpp +++ b/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.cpp @@ -3,11 +3,21 @@ #include // std::max -MerkleTreeGL::MerkleTreeGL(uint64_t _arity, bool _custom, uint64_t _height, uint64_t _width) : height(_height), width(_width) +MerkleTreeGL::MerkleTreeGL(uint64_t _arity, bool _custom, uint64_t _height, uint64_t _width, bool allocateSource, bool allocateNodes) : height(_height), width(_width) { numNodes = getNumNodes(height); arity = _arity; custom = _custom; + + if(allocateSource) { + source = (Goldilocks::Element *)calloc(height * width, sizeof(Goldilocks::Element)); + isSourceAllocated = true; + } + + if(allocateNodes) { + nodes = (Goldilocks::Element *)calloc(numNodes, sizeof(Goldilocks::Element)); + isNodesAllocated = true; + } }; MerkleTreeGL::MerkleTreeGL(uint64_t _arity, bool _custom, Goldilocks::Element *tree) @@ -21,6 +31,17 @@ MerkleTreeGL::MerkleTreeGL(uint64_t _arity, bool _custom, Goldilocks::Element *t nodes = &tree[2 + height * width]; }; +MerkleTreeGL::~MerkleTreeGL() +{ + if(isSourceAllocated) { + free(source); + } + + if(isNodesAllocated) { + free(nodes); + } +} + uint64_t MerkleTreeGL::getNumSiblings() { return (arity - 1) * nFieldElements; @@ -55,11 +76,23 @@ void MerkleTreeGL::getRoot(Goldilocks::Element *root) void MerkleTreeGL::setSource(Goldilocks::Element *_source) { + if(isSourceAllocated) { + zklog.error("MerkleTreeGL: Source was allocated when initializing"); + exitProcess(); + exit(-1); + } source = _source; } void MerkleTreeGL::setNodes(Goldilocks::Element *_nodes) { + if(isNodesAllocated) { + if(isNodesAllocated) { + zklog.error("MerkleTreeGL: Nodes were allocated when initializing"); + exitProcess(); + exit(-1); + } + } nodes = _nodes; } @@ -97,12 +130,11 @@ void MerkleTreeGL::genMerkleProof(Goldilocks::Element *proof, uint64_t idx, uint bool MerkleTreeGL::verifyGroupProof(Goldilocks::Element* root, std::vector> &mp, uint64_t idx, std::vector &v) { Goldilocks::Element value[4] = { Goldilocks::zero(), Goldilocks::zero(), Goldilocks::zero(), Goldilocks::zero() }; - PoseidonGoldilocks::linear_hash_seq(value, v.data(), v.size()); + Poseidon2Goldilocks::linear_hash_seq(value, v.data(), v.size()); calculateRootFromProof(value, mp, idx, 0); for(uint64_t i = 0; i < 4; ++i) { if(Goldilocks::toU64(value[i]) != Goldilocks::toU64(root[i])) { - cout << Goldilocks::toU64(value[0]) << " " << Goldilocks::toU64(value[1]) << " " << Goldilocks::toU64(value[2]) << " " << Goldilocks::toU64(value[3]) << endl; return false; } } @@ -130,7 +162,7 @@ void MerkleTreeGL::calculateRootFromProof(Goldilocks::Element (&value)[4], std:: inputs[i] = Goldilocks::zero(); } - PoseidonGoldilocks::hash_seq(value, inputs); + Poseidon2Goldilocks::hash_seq(value, inputs); calculateRootFromProof(value, mp, nextIdx, offset + 1); } @@ -139,11 +171,12 @@ void MerkleTreeGL::calculateRootFromProof(Goldilocks::Element (&value)[4], std:: void MerkleTreeGL::merkelize() { #ifdef __AVX512__ - PoseidonGoldilocks::merkletree_avx512(nodes, source, width, height); + // Poseidon2Goldilocks::merkletree_avx512(nodes, source, width, height); // AVX512 is not supported yet + Poseidon2Goldilocks::merkletree_avx(nodes, source, width, height); #elif defined(__AVX2__) - PoseidonGoldilocks::merkletree_avx(nodes, source, width, height); + Poseidon2Goldilocks::merkletree_avx(nodes, source, width, height); #else - PoseidonGoldilocks::merkletree_seq(nodes, source, width, height); + Poseidon2Goldilocks::merkletree_seq(nodes, source, width, height); #endif } diff --git a/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.hpp b/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.hpp index ae4ddbd97..4effb35da 100644 --- a/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.hpp +++ b/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.hpp @@ -2,7 +2,7 @@ #define MERKLETREEGL #include "goldilocks_base_field.hpp" -#include "poseidon_goldilocks.hpp" +#include "poseidon2_goldilocks.hpp" #include "zklog.hpp" #include @@ -16,8 +16,8 @@ class MerkleTreeGL public: MerkleTreeGL(){}; MerkleTreeGL(uint64_t _arity, bool custom, Goldilocks::Element *tree); - MerkleTreeGL(uint64_t _arity, bool custom, uint64_t _height, uint64_t _width); - ~MerkleTreeGL(){}; + MerkleTreeGL(uint64_t _arity, bool custom, uint64_t _height, uint64_t _width, bool allocateSource = false, bool allocateNodes = false); + ~MerkleTreeGL(); uint64_t numNodes; uint64_t height; @@ -29,6 +29,9 @@ class MerkleTreeGL uint64_t arity; bool custom; + bool isSourceAllocated = false; + bool isNodesAllocated = false; + uint64_t nFieldElements = HASH_SIZE; uint64_t getNumSiblings(); diff --git a/pil2-stark/src/starkpil/proof_stark.hpp b/pil2-stark/src/starkpil/proof_stark.hpp index 63057f907..b009bfcc4 100644 --- a/pil2-stark/src/starkpil/proof_stark.hpp +++ b/pil2-stark/src/starkpil/proof_stark.hpp @@ -2,7 +2,6 @@ #define PROOF #include "goldilocks_base_field.hpp" -#include "poseidon_goldilocks.hpp" #include "stark_info.hpp" #include "fr.hpp" #include @@ -99,8 +98,6 @@ class Proofs uint64_t nStages; uint64_t nCustomCommits; uint64_t nFieldElements; - uint64_t airId; - uint64_t airgroupId; ElementType **roots; Fri fri; std::vector> evals; @@ -119,8 +116,6 @@ class Proofs nCustomCommits = starkInfo_.customCommits.size(); roots = new ElementType*[nStages + nCustomCommits]; nFieldElements = starkInfo_.starkStruct.verificationHashType == "GL" ? HASH_SIZE : 1; - airId = starkInfo_.airId; - airgroupId = starkInfo_.airgroupId; for(uint64_t i = 0; i < nStages + nCustomCommits; i++) { roots[i] = new ElementType[nFieldElements]; @@ -286,12 +281,16 @@ class Proofs } for(uint64_t step = 1; step < starkInfo.starkStruct.steps.size(); ++step) { - j["s" + std::to_string(step) + "_root"] = json::array(); - for(uint64_t i = 0; i < nFieldElements; i++) { - j["s" + std::to_string(step) + "_root"][i] = toString(fri.treesFRI[step - 1].root[i]); + if(nFieldElements == 1) { + j["s" + std::to_string(step) + "_root"] = toString(fri.treesFRI[step - 1].root[0]); + } else { + j["s" + std::to_string(step) + "_root"] = json::array(); + for(uint64_t i = 0; i < nFieldElements; i++) { + j["s" + std::to_string(step) + "_root"][i] = toString(fri.treesFRI[step - 1].root[i]); + } + j["s" + std::to_string(step) + "_vals"] = json::array(); + j["s" + std::to_string(step) + "_siblings"] = json::array(); } - j["s" + std::to_string(step) + "_vals"] = json::array(); - j["s" + std::to_string(step) + "_siblings"] = json::array(); } for(uint64_t i = 0; i < starkInfo.starkStruct.nQueries; i++) { @@ -334,15 +333,16 @@ class FRIProof Proofs proof; std::vector publics; - uint64_t airId; uint64_t airgroupId; + uint64_t airId; uint64_t instanceId; - FRIProof(StarkInfo &starkInfo, uint64_t _instanceId) : proof(starkInfo), publics(starkInfo.nPublics) { - airId = starkInfo.airId; - airgroupId = starkInfo.airgroupId; - instanceId = _instanceId; - }; + FRIProof(StarkInfo &starkInfo, uint64_t _airgroupId, uint64_t _airId, uint64_t _instanceId) : + proof(starkInfo), + publics(starkInfo.nPublics), + airgroupId(_airgroupId), + airId(_airId), + instanceId(_instanceId) {}; }; #endif \ No newline at end of file diff --git a/pil2-stark/src/starkpil/stark_info.cpp b/pil2-stark/src/starkpil/stark_info.cpp index 3a97997b6..713403454 100644 --- a/pil2-stark/src/starkpil/stark_info.cpp +++ b/pil2-stark/src/starkpil/stark_info.cpp @@ -46,9 +46,6 @@ void StarkInfo::load(json j, bool verify_) starkStruct.steps.push_back(step); } - airId = j["airId"]; - airgroupId = j["airgroupId"]; - nPublics = j["nPublics"]; nConstants = j["nConstants"]; @@ -288,21 +285,28 @@ void StarkInfo::setMapOffsets() { mapOffsets[std::make_pair("evals", true)] = mapTotalN; mapTotalN += evMap.size() * omp_get_max_threads() * FIELD_EXTENSION; - // Merkle tree nodes sizes - for (uint64_t i = 0; i < nStages + 1; i++) { - uint64_t numNodes = getNumNodesMT(1 << starkStruct.nBitsExt); - mapOffsets[std::make_pair("mt" + to_string(i + 1), true)] = mapTotalN; - mapTotalN += numNodes; - } - for(uint64_t step = 0; step < starkStruct.steps.size() - 1; ++step) { uint64_t height = 1 << starkStruct.steps[step + 1].nBits; uint64_t width = ((1 << starkStruct.steps[step].nBits) / height) * FIELD_EXTENSION; - uint64_t numNodes = getNumNodesMT(height); mapOffsets[std::make_pair("fri_" + to_string(step + 1), true)] = mapTotalN; mapTotalN += height * width; - mapOffsets[std::make_pair("mt_fri_" + to_string(step + 1), true)] = mapTotalN; - mapTotalN += numNodes; + } + + if(starkStruct.verificationHashType == "GL") { + // Merkle tree nodes sizes + for (uint64_t i = 0; i < nStages + 1; i++) { + uint64_t numNodes = getNumNodesMT(1 << starkStruct.nBitsExt); + mapOffsets[std::make_pair("mt" + to_string(i + 1), true)] = mapTotalN; + mapTotalN += numNodes; + } + + + for(uint64_t step = 0; step < starkStruct.steps.size() - 1; ++step) { + uint64_t height = 1 << starkStruct.steps[step + 1].nBits; + uint64_t numNodes = getNumNodesMT(height); + mapOffsets[std::make_pair("mt_fri_" + to_string(step + 1), true)] = mapTotalN; + mapTotalN += numNodes; + } } } @@ -324,28 +328,7 @@ void StarkInfo::getPolynomial(Polinomial &pol, Goldilocks::Element *pAddress, st } uint64_t StarkInfo::getNumNodesMT(uint64_t height) { - if(starkStruct.verificationHashType == "BN128") { - uint n_tmp = height; - uint64_t nextN = floor(((double)(n_tmp - 1) / starkStruct.merkleTreeArity) + 1); - uint64_t acc = nextN * starkStruct.merkleTreeArity; - while (n_tmp > 1) - { - // FIll with zeros if n nodes in the leve is not even - n_tmp = nextN; - nextN = floor((n_tmp - 1) / starkStruct.merkleTreeArity) + 1; - if (n_tmp > 1) - { - acc += nextN * starkStruct.merkleTreeArity; - } - else - { - acc += 1; - } - } - return acc * sizeof(RawFr::Element) / sizeof(Goldilocks::Element); - } else { - return height * HASH_SIZE + (height - 1) * HASH_SIZE; - } + return height * HASH_SIZE + (height - 1) * HASH_SIZE; } opType string2opType(const string s) diff --git a/pil2-stark/src/starkpil/starks.cpp b/pil2-stark/src/starkpil/starks.cpp index 7076f5551..6988c5fde 100644 --- a/pil2-stark/src/starkpil/starks.cpp +++ b/pil2-stark/src/starkpil/starks.cpp @@ -13,7 +13,6 @@ void Starks::extendAndMerkelizeCustomCommit(uint64_t commitId, uint uint64_t nCols = setupCtx.starkInfo.mapSectionsN[section]; Goldilocks::Element *pBuff = buffer; Goldilocks::Element *pBuffExtended = bufferExt; - ElementType *pBuffNodes = (ElementType *)(&bufferExt[NExtended * nCols]); NTT_Goldilocks ntt(N); if(pBuffHelper != nullptr) { @@ -24,7 +23,11 @@ void Starks::extendAndMerkelizeCustomCommit(uint64_t commitId, uint uint64_t pos = setupCtx.starkInfo.nStages + 2 + commitId; treesGL[pos]->setSource(pBuffExtended); - treesGL[pos]->setNodes(pBuffNodes); + if(setupCtx.starkInfo.starkStruct.verificationHashType == "GL") { + Goldilocks::Element *pBuffNodesGL = &bufferExt[NExtended * nCols]; + ElementType *pBuffNodes = (ElementType *)pBuffNodesGL; + treesGL[pos]->setNodes(pBuffNodes); + } treesGL[pos]->merkelize(); treesGL[pos]->getRoot(&proof.proof.roots[pos - 1][0]); @@ -47,18 +50,23 @@ void Starks::loadCustomCommit(uint64_t commitId, uint64_t step, Gol uint64_t nCols = setupCtx.starkInfo.mapSectionsN[section]; Goldilocks::Element *pBuff = buffer; Goldilocks::Element *pBuffExtended = bufferExt; - ElementType *pBuffNodes = (ElementType *)(&bufferExt[NExtended * nCols]); + uint64_t pos = setupCtx.starkInfo.nStages + 2 + commitId; Goldilocks::Element* tmpBuff = (Goldilocks::Element *)loadFileParallel(bufferFile, ((N + NExtended) * nCols + treesGL[pos]->getNumNodes(NExtended)) * sizeof(Goldilocks::Element)); memcpy(pBuff, &tmpBuff[0], N * nCols * sizeof(Goldilocks::Element)); memcpy(pBuffExtended, &tmpBuff[N * nCols], NExtended * nCols * sizeof(Goldilocks::Element)); - ElementType *tmpBuffNodes = (ElementType *)(&tmpBuff[N * nCols + NExtended * nCols]); - memcpy(pBuffNodes, tmpBuffNodes, treesGL[pos]->numNodes * sizeof(ElementType)); + treesGL[pos]->setSource(pBuffExtended); - treesGL[pos]->setNodes(pBuffNodes); + if(setupCtx.starkInfo.starkStruct.verificationHashType == "GL") { + Goldilocks::Element *pBuffNodesGL = &bufferExt[NExtended * nCols]; + ElementType *pBuffNodes = (ElementType *)pBuffNodesGL; + ElementType *tmpBuffNodes = (ElementType *)(&tmpBuff[(N + NExtended) * nCols]); + memcpy(pBuffNodes, tmpBuffNodes, treesGL[pos]->numNodes * sizeof(ElementType)); + treesGL[pos]->setNodes(pBuffNodes); + } treesGL[pos]->getRoot(&proof.proof.roots[pos - 1][0]); } @@ -74,7 +82,7 @@ void Starks::extendAndMerkelize(uint64_t step, Goldilocks::Element Goldilocks::Element *pBuff = step == 1 ? trace : &aux_trace[setupCtx.starkInfo.mapOffsets[make_pair(section, false)]]; Goldilocks::Element *pBuffExtended = &aux_trace[setupCtx.starkInfo.mapOffsets[make_pair(section, true)]]; - ElementType *pBuffNodes = (ElementType *)(&aux_trace[setupCtx.starkInfo.mapOffsets[make_pair("mt" + to_string(step), true)]]); + NTT_Goldilocks ntt(N); if(pBuffHelper != nullptr) { @@ -84,7 +92,11 @@ void Starks::extendAndMerkelize(uint64_t step, Goldilocks::Element } treesGL[step - 1]->setSource(pBuffExtended); - treesGL[step - 1]->setNodes(pBuffNodes); + if(setupCtx.starkInfo.starkStruct.verificationHashType == "GL") { + Goldilocks::Element *pBuffNodesGL = &aux_trace[setupCtx.starkInfo.mapOffsets[make_pair("mt" + to_string(step), true)]]; + ElementType *pBuffNodes = (ElementType *)pBuffNodesGL; + treesGL[step - 1]->setNodes(pBuffNodes); + } treesGL[step - 1]->merkelize(); treesGL[step - 1]->getRoot(&proof.proof.roots[step - 1][0]); } @@ -140,7 +152,12 @@ void Starks::computeQ(uint64_t step, Goldilocks::Element *buffer, F } treesGL[step - 1]->setSource(&buffer[setupCtx.starkInfo.mapOffsets[std::make_pair("cm" + to_string(step), true)]]); - treesGL[step - 1]->setNodes((ElementType *)(&buffer[setupCtx.starkInfo.mapOffsets[std::make_pair("mt" + to_string(step), true)]])); + if(setupCtx.starkInfo.starkStruct.verificationHashType == "GL") { + Goldilocks::Element *pBuffNodesGL = &buffer[setupCtx.starkInfo.mapOffsets[std::make_pair("mt" + to_string(step), true)]]; + ElementType *pBuffNodes = (ElementType *)pBuffNodesGL; + treesGL[step - 1]->setNodes(pBuffNodes); + } + treesGL[step - 1]->merkelize(); treesGL[step - 1]->getRoot(&proof.proof.roots[step - 1][0]); @@ -338,12 +355,6 @@ void Starks::ffi_treesGL_get_root(uint64_t index, ElementType *dst) treesGL[index]->getRoot(dst); } -template -void Starks::ffi_treesGL_set_root(uint64_t index, FRIProof &proof) -{ - treesGL[index]->getRoot(&proof.proof.roots[index][0]); -} - template void Starks::calculateImPolsExpressions(uint64_t step, StepsParams ¶ms) { std::vector dests; @@ -397,7 +408,10 @@ void Starks::calculateFRIPolynomial(StepsParams ¶ms) { Goldilocks::Element *src = ¶ms.aux_trace[setupCtx.starkInfo.mapOffsets[std::make_pair("fri_" + to_string(step + 1), true)]]; treesFRI[step]->setSource(src); - ElementType *nodes = (ElementType *)(¶ms.aux_trace[setupCtx.starkInfo.mapOffsets[std::make_pair("mt_fri_" + to_string(step + 1), true)]]); - treesFRI[step]->setNodes(nodes); + if(setupCtx.starkInfo.starkStruct.verificationHashType == "GL") { + Goldilocks::Element *pBuffNodesGL = ¶ms.aux_trace[setupCtx.starkInfo.mapOffsets[std::make_pair("mt_fri_" + to_string(step + 1), true)]]; + ElementType *pBuffNodes = (ElementType *)pBuffNodesGL; + treesFRI[step]->setNodes(pBuffNodes); + } } } \ No newline at end of file diff --git a/pil2-stark/src/starkpil/starks.hpp b/pil2-stark/src/starkpil/starks.hpp index 75d5e3ce9..ab33ea3d9 100644 --- a/pil2-stark/src/starkpil/starks.hpp +++ b/pil2-stark/src/starkpil/starks.hpp @@ -34,20 +34,21 @@ class Starks public: Starks(SetupCtx& setupCtx_, Goldilocks::Element *pConstPolsExtendedTreeAddress) : setupCtx(setupCtx_) { + bool allocateNodes = setupCtx.starkInfo.starkStruct.verificationHashType == "GL" ? false : true; treesGL = new MerkleTreeType*[setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size() + 2]; if (pConstPolsExtendedTreeAddress != nullptr) treesGL[setupCtx.starkInfo.nStages + 1] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, pConstPolsExtendedTreeAddress); for (uint64_t i = 0; i < setupCtx.starkInfo.nStages + 1; i++) { std::string section = "cm" + to_string(i + 1); uint64_t nCols = setupCtx.starkInfo.mapSectionsN[section]; - treesGL[i] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, 1 << setupCtx.starkInfo.starkStruct.nBitsExt, nCols); + treesGL[i] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, 1 << setupCtx.starkInfo.starkStruct.nBitsExt, nCols, false, allocateNodes); } for(uint64_t i = 0; i < setupCtx.starkInfo.customCommits.size(); i++) { uint64_t nCols = setupCtx.starkInfo.mapSectionsN[setupCtx.starkInfo.customCommits[i].name + "0"]; - treesGL[setupCtx.starkInfo.nStages + 2 + i] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, 1 << setupCtx.starkInfo.starkStruct.nBitsExt, nCols); + treesGL[setupCtx.starkInfo.nStages + 2 + i] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, 1 << setupCtx.starkInfo.starkStruct.nBitsExt, nCols, false, allocateNodes); } treesFRI = new MerkleTreeType*[setupCtx.starkInfo.starkStruct.steps.size() - 1]; @@ -55,7 +56,7 @@ class Starks uint64_t nGroups = 1 << setupCtx.starkInfo.starkStruct.steps[step + 1].nBits; uint64_t groupSize = (1 << setupCtx.starkInfo.starkStruct.steps[step].nBits) / nGroups; - treesFRI[step] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, nGroups, groupSize * FIELD_EXTENSION); + treesFRI[step] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, nGroups, groupSize * FIELD_EXTENSION, false, allocateNodes); } }; ~Starks() @@ -97,7 +98,6 @@ class Starks // Following function are created to be used by the ffi interface void ffi_treesGL_get_root(uint64_t index, ElementType *dst); - void ffi_treesGL_set_root(uint64_t index, FRIProof &proof); void evmap(StepsParams& params, Goldilocks::Element *LEv); }; diff --git a/pil2-stark/src/starkpil/transcript/transcriptGL.cpp b/pil2-stark/src/starkpil/transcript/transcriptGL.cpp index 06e07c234..7230b2d9d 100644 --- a/pil2-stark/src/starkpil/transcript/transcriptGL.cpp +++ b/pil2-stark/src/starkpil/transcript/transcriptGL.cpp @@ -18,7 +18,7 @@ void TranscriptGL::_updateState() Goldilocks::Element inputs[TRANSCRIPT_OUT_SIZE]; std::memcpy(inputs, pending, TRANSCRIPT_PENDING_SIZE * sizeof(Goldilocks::Element)); std::memcpy(&inputs[TRANSCRIPT_PENDING_SIZE], state, TRANSCRIPT_STATE_SIZE * sizeof(Goldilocks::Element)); - PoseidonGoldilocks::hash_full_result(out, inputs); + Poseidon2Goldilocks::hash_full_result(out, inputs); out_cursor = TRANSCRIPT_OUT_SIZE; std::memset(pending, 0, TRANSCRIPT_PENDING_SIZE * sizeof(Goldilocks::Element)); pending_cursor = 0; diff --git a/pil2-stark/src/starkpil/transcript/transcriptGL.hpp b/pil2-stark/src/starkpil/transcript/transcriptGL.hpp index c7454c930..6729c2c3f 100644 --- a/pil2-stark/src/starkpil/transcript/transcriptGL.hpp +++ b/pil2-stark/src/starkpil/transcript/transcriptGL.hpp @@ -3,7 +3,7 @@ #include "goldilocks_base_field.hpp" #include "goldilocks_cubic_extension.hpp" -#include "poseidon_goldilocks.hpp" +#include "poseidon2_goldilocks.hpp" #include "zklog.hpp" #define TRANSCRIPT_STATE_SIZE 4 diff --git a/pil2-stark/src/utils/utils.cpp b/pil2-stark/src/utils/utils.cpp index 7869e2daa..3bdf9d082 100644 --- a/pil2-stark/src/utils/utils.cpp +++ b/pil2-stark/src/utils/utils.cpp @@ -209,17 +209,21 @@ uint64_t fileSize (const string &fileName) } -void loadFileParallel(void* buffer, const string &fileName, uint64_t size) { +bool loadFileParallel(void* buffer, const string &fileName, uint64_t size, bool exit) { // Check file size struct stat sb; if (lstat(fileName.c_str(), &sb) == -1) { zklog.error("loadFileParallel() failed calling lstat() of file " + fileName); - exitProcess(); + if(exit) exitProcess(); + return false; } if ((uint64_t)sb.st_size != size) { - zklog.error("loadFileParallel() found size of file " + fileName + " to be " + to_string(sb.st_size) + " B instead of " + to_string(size) + " B"); - exitProcess(); + if(exit) { + zklog.error("loadFileParallel() found size of file " + fileName + " to be " + to_string(sb.st_size) + " B instead of " + to_string(size) + " B"); + exitProcess(); + } + return false; } // Determine the number of chunks and the size of each chunk @@ -245,6 +249,8 @@ void loadFileParallel(void* buffer, const string &fileName, uint64_t size) { } fclose(file); } + + return true; } void* loadFileParallel(const string &fileName, uint64_t size) { diff --git a/pil2-stark/src/utils/utils.hpp b/pil2-stark/src/utils/utils.hpp index dbf50e917..8723ab5dd 100644 --- a/pil2-stark/src/utils/utils.hpp +++ b/pil2-stark/src/utils/utils.hpp @@ -47,7 +47,7 @@ uint64_t fileSize (const std::string &fileName); // Load file in parallel void * loadFileParallel(const std::string &fileName, uint64_t size); -void loadFileParallel(void *buffer, const std::string &fileName, uint64_t size); +bool loadFileParallel(void *buffer, const std::string &fileName, uint64_t size, bool exit = true); void writeFileParallel(const std::string &fileName, const void* buffer, uint64_t size, uint64_t offset = 0); diff --git a/proofman/Cargo.toml b/proofman/Cargo.toml index 55a4a40b7..9f3d4bec2 100644 --- a/proofman/Cargo.toml +++ b/proofman/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] proofman-common = { path = "../common" } +proofman-macros = { path = "../macros" } proofman-hints = { path = "../hints" } proofman-util = { path = "../util" } stark = { path = "../provers/stark" } @@ -27,3 +28,4 @@ rayon = "1.7" default = [] no_lib_link = ["proofman-starks-lib-c/no_lib_link"] distributed = ["proofman-common/distributed", "dep:mpi"] +debug = ["proofman-macros/debug"] diff --git a/proofman/src/proofman.rs b/proofman/src/proofman.rs index 7e5975f84..f80559b8b 100644 --- a/proofman/src/proofman.rs +++ b/proofman/src/proofman.rs @@ -26,9 +26,7 @@ use crate::{ verify_constraints_proof, verify_proof, }; -use proofman_common::{ - format_bytes, skip_prover_instance, ProofCtx, ProofOptions, ProofType, Prover, SetupCtx, SetupsVadcop, -}; +use proofman_common::{format_bytes, skip_prover_instance, ProofCtx, ProofOptions, Prover, SetupCtx, SetupsVadcop}; use std::os::raw::c_void; @@ -48,15 +46,17 @@ impl ProofMan { witness_lib_path: PathBuf, rom_path: Option, public_inputs_path: Option, + input_data_path: Option, proving_key_path: PathBuf, output_dir_path: PathBuf, options: ProofOptions, ) -> Result<(), Box> { - timer_start_info!(INITIALIZING_PROOFMAN); + timer_start_info!(INITIALIZING_PROOFMAN_1); Self::check_paths( &witness_lib_path, &rom_path, &public_inputs_path, + &input_data_path, &proving_key_path, &output_dir_path, options.verify_constraints, @@ -64,49 +64,102 @@ impl ProofMan { let mut pctx: ProofCtx = ProofCtx::create_ctx(proving_key_path.clone(), options); - let setups = Arc::new(SetupsVadcop::new(&pctx.global_info, pctx.options.aggregation, pctx.options.final_snark)); - let sctx: Arc = setups.sctx.clone(); + let setups = Arc::new(SetupsVadcop::::new( + &pctx.global_info, + pctx.options.verify_constraints, + pctx.options.aggregation, + pctx.options.final_snark, + )); + let sctx: Arc> = setups.sctx.clone(); pctx.set_weights(&sctx); let pctx = Arc::new(pctx); - timer_stop_and_log_info!(INITIALIZING_PROOFMAN); + timer_stop_and_log_info!(INITIALIZING_PROOFMAN_1); + pctx.dctx_barrier(); timer_start_info!(GENERATING_WITNESS); - let wcm = Arc::new(WitnessManager::new(pctx.clone(), sctx.clone(), rom_path, public_inputs_path)); + let wcm = + Arc::new(WitnessManager::new(pctx.clone(), sctx.clone(), rom_path, public_inputs_path, input_data_path)); Self::initialize_witness(witness_lib_path, wcm.clone())?; - wcm.calculate_witness(1); + pctx.dctx_barrier(); - Self::initialize_fixed_pols(setups.clone(), pctx.clone(), true); + Self::print_summary_info(pctx.clone(), setups.clone()); - let mpi_rank = pctx.dctx_get_rank(); - let n_processes = pctx.dctx_get_n_processes(); + pctx.dctx_assign_instances(); + + Self::initialize_fixed_pols(setups.clone(), pctx.clone()); + pctx.dctx_barrier(); + + wcm.calculate_witness(1); pctx.dctx_close(); - if n_processes > 1 { - let (average_weight, max_weight, min_weight, max_deviation) = pctx.dctx_load_balance_info(); - log::info!( - "{}: Load balance. Average: {} max: {} min: {} deviation: {}", - Self::MY_NAME, - average_weight, - max_weight, - min_weight, - max_deviation - ); - } + pctx.dctx_barrier(); - if mpi_rank == 0 { - Self::print_global_summary(pctx.clone(), setups.sctx.clone()); - } + timer_stop_and_log_info!(GENERATING_WITNESS); - if n_processes > 1 { - Self::print_summary(pctx.clone(), setups.sctx.clone()); + #[cfg(feature = "debug")] + { + let air_instances = pctx.air_instance_repo.air_instances.read().unwrap(); + let instances = pctx.dctx_get_instances(); + let my_instances = pctx.dctx_get_my_instances(); + let mut missing_initialization = false; + for instance_id in my_instances.iter() { + let (airgroup_id, air_id) = instances[*instance_id]; + let air_instance = air_instances.get(instance_id).unwrap(); + let air_instance_id = pctx.dctx_find_air_instance_id(*instance_id); + let air_name = pctx.global_info.airs[airgroup_id][air_id].clone().name; + let setup = setups.sctx.get_setup(airgroup_id, air_id); + let cm_pols_map = setup.stark_info.cm_pols_map.as_ref().unwrap(); + let n_cols = *setup.stark_info.map_sections_n.get("cm1").unwrap() as usize; + + let len = air_instance.trace.len(); + let vals = unsafe { std::slice::from_raw_parts(air_instance.get_trace_ptr() as *mut u64, len) }; + + for (pos, val) in vals.iter().enumerate() { + if *val == u64::MAX - 1 { + let row = pos / n_cols; + let col_id = pos % n_cols; + let col = cm_pols_map.get(col_id).unwrap(); + let col_name = if !col.lengths.is_empty() { + let lengths = col.lengths.iter().map(|l| format!("[{}]", l)).collect::(); + &format!("{}{}", col.name, lengths) + } else { + &col.name + }; + log::warn!( + "{}: Missing initialization {} at row {} of {} in instance {}", + Self::MY_NAME, + col_name, + row, + air_name, + air_instance_id, + ); + missing_initialization = true; + } + } + } + if missing_initialization { + return Err("Missing initialization".into()); + } else { + log::info!("{}: Witness Initialization is done properly", Self::MY_NAME); + return Ok(()); + } } - timer_stop_and_log_info!(GENERATING_WITNESS); + if !pctx.options.verify_constraints { + timer_start_info!(INITIALIZING_PROOFMAN_2); + + Self::initialize_fixed_pols_tree(setups.clone(), pctx.clone()); + pctx.dctx_barrier(); + Self::write_fixed_pols_tree(setups.clone(), pctx.clone()); + pctx.dctx_barrier(); + + timer_stop_and_log_info!(INITIALIZING_PROOFMAN_2); + } timer_start_info!(GENERATING_VADCOP_PROOF); @@ -156,10 +209,9 @@ impl ProofMan { } } - wcm.end_proof(); + wcm.debug(); if pctx.options.verify_constraints { - wcm.debug(); return verify_constraints_proof(pctx.clone(), sctx.clone(), &mut provers); } @@ -167,6 +219,9 @@ impl ProofMan { std::thread::spawn(move || { pctx_.free_traces(); }); + std::thread::spawn(move || { + drop(wcm); + }); // Compute Quotient polynomial Self::get_challenges(num_commit_stages + 1, &mut provers, pctx.clone(), &transcript); @@ -177,6 +232,7 @@ impl ProofMan { // Compute openings Self::opening_stages(&mut provers, pctx.clone(), sctx.clone(), &mut transcript); + //pctx.dctx_barrier(); timer_stop_and_log_info!(GENERATING_PROOF); //Generate proves_out @@ -197,10 +253,22 @@ impl ProofMan { } } + // let pctx_aggregation: Arc> = Arc::new(ProofCtx::create_ctx_agg( + // &pctx.global_info, + // pctx.options.clone(), + // pctx.get_publics().clone(), + // pctx.get_challenges().clone(), + // pctx.get_proof_values().clone(), + // pctx.dctx.read().unwrap().clone(), + // pctx.weights.clone(), + // )); + + let pctx_aggregation = pctx.clone(); + info!("{}: ··· Generating aggregated proofs", Self::MY_NAME); let (circom_witness_size, publics_size, trace_size, prover_buffer_size) = - get_buff_sizes(pctx.clone(), setups.clone())?; + get_buff_sizes(pctx_aggregation.clone(), setups.clone())?; let mut circom_witness: Vec = create_buffer_fast(circom_witness_size); let publics: Vec = create_buffer_fast(publics_size); let trace: Vec = create_buffer_fast(trace_size); @@ -209,7 +277,7 @@ impl ProofMan { timer_start_info!(GENERATING_AGGREGATION_PROOFS); timer_start_info!(GENERATING_COMPRESSOR_AND_RECURSIVE1_PROOFS); let recursive1_proofs = generate_vadcop_recursive1_proof( - &pctx, + &pctx_aggregation, setups.clone(), &proves_out, &mut circom_witness, @@ -221,11 +289,11 @@ impl ProofMan { timer_stop_and_log_info!(GENERATING_COMPRESSOR_AND_RECURSIVE1_PROOFS); info!("{}: Compressor and recursive1 proofs generated successfully", Self::MY_NAME); - pctx.dctx.read().unwrap().barrier(); + pctx_aggregation.dctx.read().unwrap().barrier(); timer_start_info!(GENERATING_RECURSIVE2_PROOFS); let sctx_recursive2 = setups.sctx_recursive2.clone(); let recursive2_proof = generate_vadcop_recursive2_proof( - &pctx, + &pctx_aggregation, sctx_recursive2.as_ref().unwrap().clone(), &recursive1_proofs, &mut circom_witness, @@ -237,12 +305,13 @@ impl ProofMan { timer_stop_and_log_info!(GENERATING_RECURSIVE2_PROOFS); info!("{}: Recursive2 proofs generated successfully", Self::MY_NAME); - pctx.dctx.read().unwrap().barrier(); + pctx_aggregation.dctx.read().unwrap().barrier(); + let mpi_rank = pctx.dctx_get_rank(); if mpi_rank == 0 { let setup_final = setups.setup_vadcop_final.as_ref().unwrap().clone(); timer_start_info!(GENERATING_VADCOP_FINAL_PROOF); let final_proof = generate_vadcop_final_proof( - &pctx, + &pctx_aggregation, setup_final.clone(), recursive2_proof, &mut circom_witness, @@ -256,10 +325,10 @@ impl ProofMan { timer_stop_and_log_info!(GENERATING_AGGREGATION_PROOFS); - if pctx.options.final_snark { + if pctx_aggregation.options.final_snark { timer_start_info!(GENERATING_RECURSIVE_F_PROOF); let recursivef_proof = generate_recursivef_proof( - &pctx, + &pctx_aggregation, setups.setup_recursivef.as_ref().unwrap().clone(), final_proof, &mut circom_witness, @@ -271,10 +340,10 @@ impl ProofMan { timer_stop_and_log_info!(GENERATING_RECURSIVE_F_PROOF); timer_start_info!(GENERATING_FFLONK_SNARK_PROOF); - let _ = generate_fflonk_snark_proof(&pctx, recursivef_proof, output_dir_path.clone()); + let _ = generate_fflonk_snark_proof(&pctx_aggregation, recursivef_proof, output_dir_path.clone()); timer_stop_and_log_info!(GENERATING_FFLONK_SNARK_PROOF); } else { - let setup_path = pctx.global_info.get_setup_path("vadcop_final"); + let setup_path = pctx_aggregation.global_info.get_setup_path("vadcop_final"); let stark_info_path = setup_path.display().to_string() + ".starkinfo.json"; let expressions_bin_path = setup_path.display().to_string() + ".verifier.bin"; let verkey_path = setup_path.display().to_string() + ".verkey.json"; @@ -285,7 +354,7 @@ impl ProofMan { stark_info_path, expressions_bin_path, verkey_path, - Some(pctx.get_publics().clone()), + Some(pctx_aggregation.get_publics().clone()), None, None, ); @@ -306,9 +375,10 @@ impl ProofMan { } } } + pctx_aggregation.dctx_barrier(); timer_stop_and_log_info!(GENERATING_VADCOP_PROOF); info!("{}: Proofs generated successfully", Self::MY_NAME); - pctx.dctx.read().unwrap().barrier(); + pctx_aggregation.dctx.read().unwrap().barrier(); Ok(()) } @@ -333,15 +403,14 @@ impl ProofMan { Ok(()) } - fn initialize_provers(sctx: Arc, provers: &mut Vec>>, pctx: Arc>) { + fn initialize_provers(sctx: Arc>, provers: &mut Vec>>, pctx: Arc>) { timer_start_debug!(INITIALIZE_PROVERS); let instances = pctx.dctx_get_instances(); let my_instances = pctx.dctx_get_my_instances(); for instance_id in my_instances.iter() { let (airgroup_id, air_id) = instances[*instance_id]; let air_instance_id = pctx.dctx_find_air_instance_id(*instance_id); - let (skip, constraints_skip) = - skip_prover_instance(pctx.options.clone(), airgroup_id, air_id, air_instance_id); + let (skip, constraints_skip) = skip_prover_instance(&pctx, *instance_id); if skip { continue; }; @@ -376,7 +445,88 @@ impl ProofMan { timer_stop_and_log_debug!(INITIALIZE_PROVERS); } - fn initialize_fixed_pols(setups: Arc, pctx: Arc>, save_file: bool) { + fn initialize_fixed_pols(setups: Arc>, pctx: Arc>) { + info!("{}: Initializing setup fixed pols", Self::MY_NAME); + timer_start_info!(INITIALIZE_CONST_POLS); + + let instances = pctx.dctx_get_instances(); + let my_instances = pctx.dctx_get_my_instances(); + + let mut airs = Vec::new(); + let mut seen = HashSet::new(); + + for instance_id in my_instances.iter() { + let (airgroup_id, air_id) = instances[*instance_id]; + if seen.insert((airgroup_id, air_id)) { + airs.push((airgroup_id, air_id)); + } + } + + airs.iter().for_each(|&(airgroup_id, air_id)| { + let setup = setups.sctx.get_setup(airgroup_id, air_id); + setup.load_const_pols(); + }); + + timer_stop_and_log_info!(INITIALIZE_CONST_POLS); + + if pctx.options.aggregation { + timer_start_info!(INITIALIZE_CONST_POLS_AGGREGATION); + + info!("{}: Initializing setup fixed pols aggregation", Self::MY_NAME); + + let global_info = pctx.global_info.clone(); + + let sctx_compressor = setups.sctx_compressor.as_ref().unwrap().clone(); + info!("{}: ··· Initializing setup fixed pols compressor", Self::MY_NAME); + timer_start_trace!(INITIALIZE_CONST_POLS_COMPRESSOR); + + airs.iter().for_each(|&(airgroup_id, air_id)| { + if global_info.get_air_has_compressor(airgroup_id, air_id) { + let setup = sctx_compressor.get_setup(airgroup_id, air_id); + setup.load_const_pols(); + } + }); + timer_stop_and_log_trace!(INITIALIZE_CONST_POLS_COMPRESSOR); + + let sctx_recursive1 = setups.sctx_recursive1.as_ref().unwrap().clone(); + timer_start_trace!(INITIALIZE_CONST_POLS_RECURSIVE1); + info!("{}: ··· Initializing setup fixed pols recursive1", Self::MY_NAME); + airs.iter().for_each(|&(airgroup_id, air_id)| { + let setup = sctx_recursive1.get_setup(airgroup_id, air_id); + setup.load_const_pols(); + }); + timer_stop_and_log_trace!(INITIALIZE_CONST_POLS_RECURSIVE1); + + let sctx_recursive2 = setups.sctx_recursive2.as_ref().unwrap().clone(); + timer_start_trace!(INITIALIZE_CONST_POLS_RECURSIVE2); + info!("{}: ··· Initializing setup fixed pols recursive2", Self::MY_NAME); + let n_airgroups = global_info.air_groups.len(); + for airgroup in 0..n_airgroups { + let setup = sctx_recursive2.get_setup(airgroup, 0); + setup.load_const_pols(); + } + timer_stop_and_log_trace!(INITIALIZE_CONST_POLS_RECURSIVE2); + + if pctx.dctx_get_rank() == 0 { + let setup_vadcop_final = setups.setup_vadcop_final.as_ref().unwrap().clone(); + timer_start_trace!(INITIALIZE_CONST_POLS_VADCOP_FINAL); + info!("{}: ··· Initializing setup fixed pols vadcop final", Self::MY_NAME); + setup_vadcop_final.load_const_pols(); + timer_stop_and_log_trace!(INITIALIZE_CONST_POLS_VADCOP_FINAL); + + if pctx.options.final_snark { + let setup_recursivef = setups.setup_recursivef.as_ref().unwrap().clone(); + timer_start_trace!(INITIALIZE_CONST_POLS_RECURSIVE_FINAL); + info!("{}: ··· Initializing setup fixed pols recursive final", Self::MY_NAME); + setup_recursivef.load_const_pols(); + timer_stop_and_log_trace!(INITIALIZE_CONST_POLS_RECURSIVE_FINAL); + } + } + timer_stop_and_log_info!(INITIALIZE_CONST_POLS_AGGREGATION); + } + } + + fn initialize_fixed_pols_tree(setups: Arc>, pctx: Arc>) { info!("{}: Initializing setup fixed pols", Self::MY_NAME); timer_start_info!(INITIALIZE_CONST_POLS); @@ -395,8 +545,7 @@ impl ProofMan { airs.iter().for_each(|&(airgroup_id, air_id)| { let setup = setups.sctx.get_setup(airgroup_id, air_id); - setup.load_const_pols(&pctx.global_info, &ProofType::Basic); - setup.load_const_pols_tree(&pctx.global_info, &ProofType::Basic, save_file); + setup.load_const_pols_tree(); }); timer_stop_and_log_info!(INITIALIZE_CONST_POLS); @@ -415,8 +564,7 @@ impl ProofMan { airs.iter().for_each(|&(airgroup_id, air_id)| { if global_info.get_air_has_compressor(airgroup_id, air_id) { let setup = sctx_compressor.get_setup(airgroup_id, air_id); - setup.load_const_pols(&global_info, &ProofType::Compressor); - setup.load_const_pols_tree(&global_info, &ProofType::Compressor, save_file); + setup.load_const_pols_tree(); } }); timer_stop_and_log_trace!(INITIALIZE_CONST_POLS_COMPRESSOR); @@ -426,8 +574,7 @@ impl ProofMan { info!("{}: ··· Initializing setup fixed pols recursive1", Self::MY_NAME); airs.iter().for_each(|&(airgroup_id, air_id)| { let setup = sctx_recursive1.get_setup(airgroup_id, air_id); - setup.load_const_pols(&global_info, &ProofType::Recursive1); - setup.load_const_pols_tree(&global_info, &ProofType::Recursive1, save_file); + setup.load_const_pols_tree(); }); timer_stop_and_log_trace!(INITIALIZE_CONST_POLS_RECURSIVE1); @@ -437,8 +584,7 @@ impl ProofMan { let n_airgroups = global_info.air_groups.len(); for airgroup in 0..n_airgroups { let setup = sctx_recursive2.get_setup(airgroup, 0); - setup.load_const_pols(&global_info, &ProofType::Recursive2); - setup.load_const_pols_tree(&global_info, &ProofType::Recursive2, save_file); + setup.load_const_pols_tree(); } timer_stop_and_log_trace!(INITIALIZE_CONST_POLS_RECURSIVE2); @@ -446,17 +592,14 @@ impl ProofMan { let setup_vadcop_final = setups.setup_vadcop_final.as_ref().unwrap().clone(); timer_start_trace!(INITIALIZE_CONST_POLS_VADCOP_FINAL); info!("{}: ··· Initializing setup fixed pols vadcop final", Self::MY_NAME); - setup_vadcop_final.load_const_pols(&global_info, &ProofType::VadcopFinal); - setup_vadcop_final.load_const_pols_tree(&global_info, &ProofType::VadcopFinal, save_file); + setup_vadcop_final.load_const_pols_tree(); timer_stop_and_log_trace!(INITIALIZE_CONST_POLS_VADCOP_FINAL); if pctx.options.final_snark { - let global_info = pctx.global_info.clone(); let setup_recursivef = setups.setup_recursivef.as_ref().unwrap().clone(); timer_start_trace!(INITIALIZE_CONST_POLS_RECURSIVE_FINAL); info!("{}: ··· Initializing setup fixed pols recursive final", Self::MY_NAME); - setup_recursivef.load_const_pols(&global_info, &ProofType::RecursiveF); - setup_recursivef.load_const_pols_tree(&global_info, &ProofType::RecursiveF, save_file); + setup_recursivef.load_const_pols_tree(); timer_stop_and_log_trace!(INITIALIZE_CONST_POLS_RECURSIVE_FINAL); } } @@ -464,10 +607,78 @@ impl ProofMan { } } + fn write_fixed_pols_tree(setups: Arc>, pctx: Arc>) { + timer_start_info!(WRITE_CONST_TREE); + let instances = pctx.dctx_get_instances(); + let my_instances = pctx.dctx_get_my_instances(); + + let mut airs = Vec::new(); + let mut seen = HashSet::new(); + + for instance_id in my_instances.iter() { + let (airgroup_id, air_id) = instances[*instance_id]; + if seen.insert((airgroup_id, air_id)) { + airs.push((airgroup_id, air_id)); + } + } + + airs.iter().for_each(|&(airgroup_id, air_id)| { + let setup = setups.sctx.get_setup(airgroup_id, air_id); + if setup.to_write_tree() { + setup.write_const_tree(); + } + }); + + if pctx.options.aggregation { + let global_info = pctx.global_info.clone(); + let sctx_compressor = setups.sctx_compressor.as_ref().unwrap().clone(); + airs.iter().for_each(|&(airgroup_id, air_id)| { + if global_info.get_air_has_compressor(airgroup_id, air_id) { + let setup = sctx_compressor.get_setup(airgroup_id, air_id); + if pctx.dctx_is_min_rank_owner(airgroup_id, air_id) && setup.to_write_tree() { + setup.write_const_tree(); + } + } + }); + let sctx_recursive1 = setups.sctx_recursive1.as_ref().unwrap().clone(); + airs.iter().for_each(|&(airgroup_id, air_id)| { + let setup = sctx_recursive1.get_setup(airgroup_id, air_id); + if pctx.dctx_is_min_rank_owner(airgroup_id, air_id) && setup.to_write_tree() { + setup.write_const_tree(); + } + }); + + if pctx.dctx_get_rank() == 0 { + let sctx_recursive2 = setups.sctx_recursive2.as_ref().unwrap().clone(); + let n_airgroups = global_info.air_groups.len(); + for airgroup in 0..n_airgroups { + let setup = sctx_recursive2.get_setup(airgroup, 0); + if pctx.dctx_is_min_rank_owner(airgroup, 0) && setup.to_write_tree() { + setup.write_const_tree(); + } + } + + let setup_vadcop_final = setups.setup_vadcop_final.as_ref().unwrap().clone(); + if setup_vadcop_final.to_write_tree() { + setup_vadcop_final.write_const_tree(); + } + + if pctx.options.final_snark { + let setup_recursivef = setups.setup_recursivef.as_ref().unwrap().clone(); + if setup_recursivef.to_write_tree() { + setup_recursivef.write_const_tree(); + } + } + } + } + + timer_stop_and_log_info!(WRITE_CONST_TREE); + } + pub fn calculate_stage( stage: u32, provers: &mut [Box>], - sctx: Arc, + sctx: Arc>, pctx: Arc>, ) { if stage as usize == pctx.global_info.n_challenges.len() + 1 { @@ -596,14 +807,14 @@ impl ProofMan { provers: &mut [Box>], pctx: Arc>, transcript: &FFITranscript, - ) { - provers[0].get_challenges(stage, pctx, transcript); // Any prover can get the challenges which are common among them + ) -> Vec> { + provers[0].get_challenges(stage, pctx, transcript) } pub fn opening_stages( provers: &mut [Box>], pctx: Arc>, - sctx: Arc, + sctx: Arc>, transcript: &mut FFITranscript, ) { @@ -611,9 +822,11 @@ impl ProofMan { // Calculate evals timer_start_debug!(CALCULATING_EVALS); - Self::get_challenges(pctx.global_info.n_challenges.len() as u32 + 2, provers, pctx.clone(), transcript); + let challenges = + Self::get_challenges(pctx.global_info.n_challenges.len() as u32 + 2, provers, pctx.clone(), transcript); + let xi_challenge = challenges[0].clone(); for group_idx in pctx.dctx_get_my_air_groups() { - provers[group_idx[0]].calculate_lev(pctx.clone()); + provers[group_idx[0]].calculate_lev(pctx.clone(), xi_challenge.clone()); for idx in group_idx.iter() { provers[*idx].opening_stage(1, sctx.clone(), pctx.clone()); } @@ -627,7 +840,7 @@ impl ProofMan { info!("{}: Calculating FRI Polynomials", Self::MY_NAME); timer_start_info!(CALCULATING_FRI_POLINOMIAL); for group_idx in pctx.dctx_get_my_air_groups().iter() { - provers[group_idx[0]].calculate_xdivxsub(pctx.clone()); + provers[group_idx[0]].calculate_xdivxsub(pctx.clone(), xi_challenge.clone()); for idx in group_idx.iter() { provers[*idx].opening_stage(2, sctx.clone(), pctx.clone()); } @@ -746,36 +959,104 @@ impl ProofMan { pctx.options.aggregation, ); - if pctx.options.aggregation { - std::thread::spawn(move || { - pctx.free_instances(); - }); - } else { - pctx.free_instances(); - } + // if pctx.options.aggregation { + // std::thread::spawn(move || { + // pctx.free_instances(); + // }); + // } else { + // pctx.free_instances(); + // } timer_stop_and_log_info!(FREE_PROVERS); } - fn print_global_summary(pctx: Arc>, sctx: Arc) { + pub fn print_summary_info(pctx: Arc>, setups: Arc>) { + let mpi_rank = pctx.dctx_get_rank(); + let n_processes = pctx.dctx_get_n_processes(); + + if n_processes > 1 { + let (average_weight, max_weight, min_weight, max_deviation) = pctx.dctx_load_balance_info(); + log::info!( + "{}: Load balance. Average: {} max: {} min: {} deviation: {}", + Self::MY_NAME, + average_weight, + max_weight, + min_weight, + max_deviation + ); + } + + if mpi_rank == 0 { + Self::print_summary(pctx.clone(), setups.clone(), true); + } + + if n_processes > 1 { + Self::print_summary(pctx.clone(), setups.clone(), false); + } + } + + pub fn print_summary(pctx: Arc>, setups: Arc>, global: bool) { let mut air_info = HashMap::new(); let mut air_instances = HashMap::new(); let instances = pctx.dctx_get_instances(); - for (airgroup_id, air_id) in &instances { + let mut print = vec![global; instances.len()]; + + if !global { + let my_instances = pctx.dctx_get_my_instances(); + for instance_id in my_instances.iter() { + print[*instance_id] = true; + } + } + + for (instance_id, (airgroup_id, air_id)) in instances.iter().enumerate() { + if !print[instance_id] { + continue; + } let air_name = pctx.global_info.airs[*airgroup_id][*air_id].clone().name; let air_group_name = pctx.global_info.air_groups[*airgroup_id].clone(); let air_instance_map = air_instances.entry(air_group_name).or_insert_with(HashMap::new); if !air_instance_map.contains_key(&air_name.clone()) { - let setup = sctx.get_setup(*airgroup_id, *air_id); + let setup = setups.sctx.get_setup(*airgroup_id, *air_id); let n_bits = setup.stark_info.stark_struct.n_bits; let memory_instance = setup.prover_buffer_size as f64 * 8.0; - let memory_fixed = (setup.stark_info.n_constants * (1 << (setup.stark_info.stark_struct.n_bits)) - + setup.stark_info.n_constants * (1 << (setup.stark_info.stark_struct.n_bits_ext))) - as f64 - * 8.0; + let mut memory_fixed = + (setup.stark_info.n_constants * (1 << (setup.stark_info.stark_struct.n_bits))) as f64; + if !pctx.options.verify_constraints { + memory_fixed += (setup.stark_info.n_constants * (1 << (setup.stark_info.stark_struct.n_bits_ext)) + + (1 << (setup.stark_info.stark_struct.n_bits_ext)) + + ((2 * (1 << (setup.stark_info.stark_struct.n_bits_ext)) - 1) * 4)) + as f64; + } + memory_fixed *= 8.0; + let mut memory_fixed_aggregation = 0f64; + if pctx.options.aggregation { + if pctx.global_info.get_air_has_compressor(*airgroup_id, *air_id) { + let setup_compressor = + setups.sctx_compressor.as_ref().unwrap().get_setup(*airgroup_id, *air_id); + memory_fixed_aggregation += (setup_compressor.stark_info.n_constants + * (1 << (setup_compressor.stark_info.stark_struct.n_bits)) + + setup_compressor.stark_info.n_constants + * (1 << (setup_compressor.stark_info.stark_struct.n_bits_ext)) + + (1 << (setup_compressor.stark_info.stark_struct.n_bits_ext)) + + ((2 * (1 << (setup_compressor.stark_info.stark_struct.n_bits_ext)) - 1) * 4)) + as f64 + * 8.0; + } + + let setup_recursive1 = setups.sctx_recursive1.as_ref().unwrap().get_setup(*airgroup_id, *air_id); + memory_fixed_aggregation += (setup_recursive1.stark_info.n_constants + * (1 << (setup_recursive1.stark_info.stark_struct.n_bits)) + + setup_recursive1.stark_info.n_constants + * (1 << (setup_recursive1.stark_info.stark_struct.n_bits_ext)) + + (1 << (setup_recursive1.stark_info.stark_struct.n_bits_ext)) + + ((2 * (1 << (setup_recursive1.stark_info.stark_struct.n_bits_ext)) - 1) * 4)) + as f64 + * 8.0; + } + let memory_helpers = setup.stark_info.get_buff_helper_size() as f64 * 8.0; let total_cols: u64 = setup .stark_info @@ -784,10 +1065,9 @@ impl ProofMan { .filter(|(key, _)| *key != "const") .map(|(_, value)| *value) .sum(); - let cols_witness: u64 = setup.stark_info.map_sections_n["cm1"]; air_info.insert( air_name.clone(), - (n_bits, total_cols, cols_witness, memory_fixed, memory_helpers, memory_instance), + (n_bits, total_cols, memory_fixed, memory_fixed_aggregation, memory_helpers, memory_instance), ); } let air_instance_map_key = air_instance_map.entry(air_name).or_insert(0); @@ -804,8 +1084,6 @@ impl ProofMan { .bold() ); info!("{}: ► {} Air instances found:", Self::MY_NAME, instances.len()); - let mut total_cells = 0f64; - let mut total_cells_witness = 0f64; for air_group in air_groups.clone() { let air_group_instances = air_instances.get(air_group).unwrap(); let mut air_names: Vec<_> = air_group_instances.keys().collect(); @@ -814,9 +1092,7 @@ impl ProofMan { info!("{}: Air Group [{}]", Self::MY_NAME, air_group); for air_name in air_names { let count = air_group_instances.get(air_name).unwrap(); - let (n_bits, total_cols, cols_witness, _, _, _) = air_info.get(air_name).unwrap(); - total_cells += *total_cols as f64 * *count as f64 * (1 << *n_bits) as f64; - total_cells_witness += *cols_witness as f64 * *count as f64 * (1 << *n_bits) as f64; + let (n_bits, total_cols, _, _, _, _) = air_info.get(air_name).unwrap(); info!( "{}: {}", Self::MY_NAME, @@ -824,113 +1100,67 @@ impl ProofMan { ); } } - info!("{} TOTAL CELLS WITNESS: {} and TOTAL CELLS {}", Self::MY_NAME, total_cells_witness, total_cells); info!("{}: ----------------------------------------------------------", Self::MY_NAME); info!( "{}", - format!("{}: --- TOTAL PROVER MEMORY USAGE ----------------------------", Self::MY_NAME) + format!("{}: --- TOTAL SETUP MEMORY USAGE ----------------------------", Self::MY_NAME) .bright_white() .bold() ); let mut total_memory = 0f64; - let mut memory_helper_size = 0f64; - for air_group in air_groups { + for air_group in air_groups.clone() { let air_group_instances = air_instances.get(air_group).unwrap(); let mut air_names: Vec<_> = air_group_instances.keys().collect(); air_names.sort(); for air_name in air_names { - let count = air_group_instances.get(air_name).unwrap(); - let (_, _, _, memory_fixed, memory_helper_instance_size, memory_instance) = - air_info.get(air_name).unwrap(); - let total_memory_instance = memory_fixed + memory_instance * *count as f64; - total_memory += total_memory_instance; - if *memory_helper_instance_size > memory_helper_size { - memory_helper_size = *memory_helper_instance_size; - } - info!( - "{}: {}", - Self::MY_NAME, - format!( - "· {}: {} fixed cols | {} per each of {} instance | Total {}", - air_name, - format_bytes(*memory_fixed), - format_bytes(*memory_instance), - count, - format_bytes(total_memory_instance) - ) - ); - } - } - total_memory += memory_helper_size; - info!("{}: {}", Self::MY_NAME, format!("Extra helper memory: {}", format_bytes(memory_helper_size))); - info!( - "{}: {}", - Self::MY_NAME, - format!("Total prover memory required: {}", format_bytes(total_memory)).bright_white().bold() - ); - info!("{}: ----------------------------------------------------------", Self::MY_NAME); - } - - fn print_summary(pctx: Arc>, sctx: Arc) { - let mut air_info = HashMap::new(); + let (_, _, memory_fixed, memory_fixed_aggregation, _, _) = air_info.get(air_name).unwrap(); + total_memory += memory_fixed; - let mut air_instances = HashMap::new(); - - let instances = pctx.dctx_get_instances(); - let my_instances = pctx.dctx_get_my_instances(); - - for instance_id in my_instances.iter() { - let (airgroup_id, air_id) = instances[*instance_id]; - let air_name = pctx.global_info.airs[airgroup_id][air_id].clone().name; - let air_group_name = pctx.global_info.air_groups[airgroup_id].clone(); - let air_instance_map = air_instances.entry(air_group_name).or_insert_with(HashMap::new); - if !air_instance_map.contains_key(&air_name.clone()) { - let setup = sctx.get_setup(airgroup_id, air_id); - let n_bits = setup.stark_info.stark_struct.n_bits; - let memory_instance = setup.prover_buffer_size as f64 * 8.0; - let memory_fixed = (setup.stark_info.n_constants * (1 << (setup.stark_info.stark_struct.n_bits)) - + setup.stark_info.n_constants * (1 << (setup.stark_info.stark_struct.n_bits_ext))) - as f64 - * 8.0; - let memory_helpers = setup.stark_info.get_buff_helper_size() as f64 * 8.0; - let total_cols: u64 = setup - .stark_info - .map_sections_n - .iter() - .filter(|(key, _)| *key != "const") - .map(|(_, value)| *value) - .sum(); - air_info.insert(air_name.clone(), (n_bits, total_cols, memory_fixed, memory_helpers, memory_instance)); + if !pctx.options.aggregation { + info!( + "{}: {}", + Self::MY_NAME, + format!("· {}: {} fixed cols", air_name, format_bytes(*memory_fixed),) + ); + } else { + total_memory += memory_fixed_aggregation; + info!( + "{}: {}", + Self::MY_NAME, + format!( + "· {}: {} fixed cols | {} fixed cols aggregation | Total: {}", + air_name, + format_bytes(*memory_fixed), + format_bytes(*memory_fixed_aggregation), + format_bytes(*memory_fixed + *memory_fixed_aggregation), + ) + ); + } } - let air_instance_map_key = air_instance_map.entry(air_name).or_insert(0); - *air_instance_map_key += 1; + info!( + "{}: {}", + Self::MY_NAME, + format!("Total setup memory required: {}", format_bytes(total_memory)).bright_white().bold() + ); } - let mut air_groups: Vec<_> = air_instances.keys().collect(); - air_groups.sort(); - - info!("{}: --- PROOF INSTANCES SUMMARY ------------------------", Self::MY_NAME); - info!("{}: ► {} Air instances found:", Self::MY_NAME, my_instances.len()); - for air_group in air_groups.clone() { - let air_group_instances = air_instances.get(air_group).unwrap(); - let mut air_names: Vec<_> = air_group_instances.keys().collect(); - air_names.sort(); - - info!("{}: Air Group [{}]", Self::MY_NAME, air_group); - for air_name in air_names { - let count = air_group_instances.get(air_name).unwrap(); - let (n_bits, total_cols, _, _, _) = air_info.get(air_name).unwrap(); - info!( - "{}: {}", - Self::MY_NAME, - format!("· {} x Air [{}] ({} x 2^{})", count, air_name, total_cols, n_bits).bright_white().bold() - ); - } + info!("{}: ----------------------------------------------------------", Self::MY_NAME); + if pctx.options.verify_constraints { + info!( + "{}", + format!("{}: --- TOTAL CONSTRAINT CHECKER MEMORY USAGE ----------------------------", Self::MY_NAME) + .bright_white() + .bold() + ); + } else { + info!( + "{}", + format!("{}: --- TOTAL PROVER MEMORY USAGE ----------------------------", Self::MY_NAME) + .bright_white() + .bold() + ); } - info!("{}: ------------------------------------------------", Self::MY_NAME); - info!("{}: --- PROVER MEMORY USAGE ------------------------", Self::MY_NAME); - info!("{}: ► {} Air instances found:", Self::MY_NAME, my_instances.len()); let mut total_memory = 0f64; let mut memory_helper_size = 0f64; for air_group in air_groups { @@ -940,9 +1170,8 @@ impl ProofMan { for air_name in air_names { let count = air_group_instances.get(air_name).unwrap(); - let (_, _, memory_fixed, memory_helper_instance_size, memory_instance) = - air_info.get(air_name).unwrap(); - let total_memory_instance = memory_fixed + memory_instance * *count as f64; + let (_, _, _, _, memory_helper_instance_size, memory_instance) = air_info.get(air_name).unwrap(); + let total_memory_instance = memory_instance * *count as f64; total_memory += total_memory_instance; if *memory_helper_instance_size > memory_helper_size { memory_helper_size = *memory_helper_instance_size; @@ -951,9 +1180,8 @@ impl ProofMan { "{}: {}", Self::MY_NAME, format!( - "· {}: {} fixed cols | {} per each of {} instance | Total {}", + "· {}: {} per each of {} instance | Total {}", air_name, - format_bytes(*memory_fixed), format_bytes(*memory_instance), count, format_bytes(total_memory_instance) @@ -968,13 +1196,14 @@ impl ProofMan { Self::MY_NAME, format!("Total prover memory required: {}", format_bytes(total_memory)).bright_white().bold() ); - info!("{}: ------------------------------------------------", Self::MY_NAME); + info!("{}: ----------------------------------------------------------", Self::MY_NAME); } fn check_paths( witness_lib_path: &PathBuf, rom_path: &Option, public_inputs_path: &Option, + input_data_path: &Option, proving_key_path: &PathBuf, output_dir_path: &PathBuf, verify_constraints: bool, @@ -998,6 +1227,12 @@ impl ProofMan { } } + if let Some(input_path) = input_data_path { + if !input_path.exists() { + return Err(format!("Input data file not found at path: {:?}", input_path).into()); + } + } + // Check proving_key_path exists if !proving_key_path.exists() { return Err(format!("Proving key folder not found at path: {:?}", proving_key_path).into()); diff --git a/proofman/src/recursion.rs b/proofman/src/recursion.rs index fbd965e6e..ac86632d2 100644 --- a/proofman/src/recursion.rs +++ b/proofman/src/recursion.rs @@ -21,7 +21,7 @@ type GetSizeWitnessFunc = unsafe extern "C" fn() -> u64; #[allow(clippy::too_many_arguments)] pub fn generate_vadcop_recursive1_proof( pctx: &ProofCtx, - setups: Arc, + setups: Arc>, proofs: &[*mut c_void], circom_witness: &mut [F], publics: &[F], @@ -81,6 +81,8 @@ pub fn generate_vadcop_recursive1_proof( &proof_file, global_info_file, airgroup_id as u64, + air_id as u64, + air_instance_id as u64, true, ); @@ -129,6 +131,8 @@ pub fn generate_vadcop_recursive1_proof( &proof_file, global_info_file, airgroup_id as u64, + air_id as u64, + air_instance_id as u64, true, ); proofs_out.push(p_prove); @@ -143,7 +147,7 @@ pub fn generate_vadcop_recursive1_proof( #[allow(clippy::too_many_arguments)] pub fn generate_vadcop_recursive2_proof( pctx: &ProofCtx, - sctx: Arc, + sctx: Arc>, proofs: &[*mut c_void], circom_witness: &mut [F], publics: &[F], @@ -274,6 +278,8 @@ pub fn generate_vadcop_recursive2_proof( &proof_file, global_info_file, airgroup as u64, + 0, + 0, true, ); @@ -330,7 +336,7 @@ pub fn generate_vadcop_recursive2_proof( #[allow(clippy::too_many_arguments)] pub fn generate_vadcop_final_proof( pctx: &ProofCtx, - setup: Arc, + setup: Arc>, proof: *mut c_void, circom_witness: &mut [F], publics: &[F], @@ -364,6 +370,8 @@ pub fn generate_vadcop_final_proof( &proof_file, global_info_file, 0, + 0, + 0, false, ); log::info!("{}: ··· Vadcop final Proof generated.", MY_NAME); @@ -375,7 +383,7 @@ pub fn generate_vadcop_final_proof( #[allow(clippy::too_many_arguments)] pub fn generate_recursivef_proof( pctx: &ProofCtx, - setup: Arc, + setup: Arc>, proof: *mut c_void, circom_witness: &mut [F], publics: &[F], @@ -412,6 +420,8 @@ pub fn generate_recursivef_proof( &proof_file, global_info_file, 0, + 0, + 0, false, ); log::info!("{}: ··· RecursiveF Proof generated.", MY_NAME); @@ -475,7 +485,7 @@ fn generate_witness( buffer: &[F], publics: &[F], setup_path: &Path, - setup: &Setup, + setup: &Setup, zkin: *mut c_void, n_cols: usize, ) -> Result<(), Box> { @@ -528,7 +538,7 @@ fn generate_witness( pub fn get_buff_sizes( pctx: Arc>, - setups: Arc, + setups: Arc>, ) -> Result<(usize, usize, usize, usize), Box> { let mut witness_size = 0; let mut publics = 0; @@ -582,7 +592,7 @@ pub fn get_buff_sizes( if pctx.options.final_snark { let setup_recursivef = setups.setup_recursivef.as_ref().unwrap(); let setup_path = pctx.global_info.get_setup_path("recursivef"); - let sizes = get_size(&setup_path, setup_recursivef, 12)?; + let sizes = get_size::(&setup_path, setup_recursivef, 12)?; witness_size = witness_size.max(sizes.0); publics = publics.max(sizes.1); buffer = buffer.max(sizes.2); @@ -592,9 +602,9 @@ pub fn get_buff_sizes( Ok((witness_size, publics, buffer, prover_size as usize)) } -fn get_size( +fn get_size( setup_path: &Path, - setup: &Setup, + setup: &Setup, n_cols: usize, ) -> Result<(usize, usize, usize), Box> { // Load the symbol (function) from the library diff --git a/proofman/src/verify.rs b/proofman/src/verify.rs index 866521389..dbc088181 100644 --- a/proofman/src/verify.rs +++ b/proofman/src/verify.rs @@ -57,7 +57,7 @@ pub fn verify_basic_proofs( provers: &mut [Box>], proves: Vec<*mut c_void>, pctx: Arc>, - sctx: Arc, + sctx: Arc>, ) -> bool { const MY_NAME: &str = "Verify "; timer_start_info!(VERIFYING_BASIC_PROOFS); diff --git a/proofman/src/verify_constraints.rs b/proofman/src/verify_constraints.rs index be344e290..70f685f2d 100644 --- a/proofman/src/verify_constraints.rs +++ b/proofman/src/verify_constraints.rs @@ -15,7 +15,7 @@ use colored::*; pub fn verify_global_constraints_proof( pctx: Arc>, - sctx: Arc, + sctx: Arc>, airgroupvalues: Vec>, ) -> Vec { const MY_NAME: &str = "GlCstVfy"; @@ -51,7 +51,7 @@ pub fn verify_global_constraints_proof( pub fn verify_constraints_proof( pctx: Arc>, - sctx: Arc, + sctx: Arc>, provers: &mut [Box>], ) -> Result<(), Box> { const MY_NAME: &str = "CstrVrfy"; @@ -73,7 +73,7 @@ pub fn verify_constraints_proof( let (airgroup_id, air_id) = instances[*instance_id]; let air_name = &pctx.global_info.airs[airgroup_id][air_id].name; let air_instance_id = pctx.dctx_find_air_instance_id(*instance_id); - let (skip, _) = skip_prover_instance(pctx.options.clone(), airgroup_id, air_id, air_instance_id); + let (skip, _) = skip_prover_instance(&pctx, *instance_id); if skip { log::info!( "{}", diff --git a/provers/stark/src/stark_prover.rs b/provers/stark/src/stark_prover.rs index 8626abc31..00991b6be 100644 --- a/provers/stark/src/stark_prover.rs +++ b/provers/stark/src/stark_prover.rs @@ -43,7 +43,7 @@ impl StarkProver { const FIELD_EXTENSION: usize = 3; pub fn new( - sctx: Arc, + sctx: Arc>, airgroup_id: usize, air_id: usize, air_instance_id: usize, @@ -65,7 +65,8 @@ impl StarkProver { let p_stark_info = setup.p_setup.p_stark_info; - let p_proof = fri_proof_new_c((&setup.p_setup).into(), air_instance_id as u64); + let p_proof = + fri_proof_new_c((&setup.p_setup).into(), airgroup_id as u64, air_id as u64, air_instance_id as u64); Self { global_idx, @@ -141,9 +142,9 @@ impl Prover for StarkProver { self.stark_info.n_stages } - fn verify_constraints(&self, sctx: Arc, pctx: Arc>) -> Vec { - let mut air_instances = pctx.air_instance_repo.air_instances.write().unwrap(); - let air_instance = air_instances.get_mut(&self.global_idx).unwrap(); + fn verify_constraints(&self, sctx: Arc>, pctx: Arc>) -> Vec { + let air_instances = pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&self.global_idx).unwrap(); let setup = sctx.get_setup(self.airgroup_id, self.air_id); @@ -181,7 +182,7 @@ impl Prover for StarkProver { constraints_info } - fn calculate_stage(&mut self, stage_id: u32, sctx: Arc, pctx: Arc>) { + fn calculate_stage(&mut self, stage_id: u32, sctx: Arc>, pctx: Arc>) { let mut air_instances = pctx.air_instance_repo.air_instances.write().unwrap(); let air_instance = air_instances.get_mut(&self.global_idx).unwrap(); @@ -280,7 +281,7 @@ impl Prover for StarkProver { } } - fn opening_stage(&mut self, opening_id: u32, sctx: Arc, pctx: Arc>) -> ProverStatus { + fn opening_stage(&mut self, opening_id: u32, sctx: Arc>, pctx: Arc>) -> ProverStatus { let steps_fri: Vec = pctx.global_info.steps_fri.iter().map(|step| step.n_bits).collect(); let last_stage_id = steps_fri.len() as u32 + 3; if opening_id == 1 { @@ -378,38 +379,12 @@ impl Prover for StarkProver { custom_publics } - fn calculate_xdivxsub(&mut self, pctx: Arc>) { - let challenges_guard = pctx.challenges.values.read().unwrap(); - - let challenges_map = self.stark_info.challenges_map.as_ref().unwrap(); - - let mut xi_challenge_index: usize = 0; - for (i, challenge) in challenges_map.iter().enumerate() { - if challenge.stage == (Self::num_stages(self) + 2) as u64 && challenge.stage_id == 0_u64 { - xi_challenge_index = i; - break; - } - } - - let xi_challenge = &(*challenges_guard)[xi_challenge_index * Self::FIELD_EXTENSION] as *const F as *mut c_void; - calculate_xdivxsub_c(self.p_stark, xi_challenge, pctx.get_buff_helper_ptr()); + fn calculate_xdivxsub(&mut self, pctx: Arc>, challenge: Vec) { + calculate_xdivxsub_c(self.p_stark, challenge.as_ptr() as *mut c_void, pctx.get_buff_helper_ptr()); } - fn calculate_lev(&mut self, pctx: Arc>) { - let challenges_guard = pctx.challenges.values.read().unwrap(); - - let challenges_map = self.stark_info.challenges_map.as_ref().unwrap(); - - let mut xi_challenge_index: usize = 0; - for (i, challenge) in challenges_map.iter().enumerate() { - if challenge.stage == (Self::num_stages(self) + 2) as u64 && challenge.stage_id == 0_u64 { - xi_challenge_index = i; - break; - } - } - - let xi_challenge = &(*challenges_guard)[xi_challenge_index * Self::FIELD_EXTENSION] as *const F as *mut c_void; - compute_lev_c(self.p_stark, xi_challenge, pctx.get_buff_helper_ptr()); + fn calculate_lev(&mut self, pctx: Arc>, challenge: Vec) { + compute_lev_c(self.p_stark, challenge.as_ptr() as *mut c_void, pctx.get_buff_helper_ptr()); } fn get_buff_helper_size(&self, _proof_ctx: Arc>) -> usize { @@ -482,8 +457,8 @@ impl Prover for StarkProver { let index = if stage == 1 { self.n_field_elements + j } else { j }; values_hash[index] = root_value; } - let mut air_instances = pctx.air_instance_repo.air_instances.write().unwrap(); - let air_instance = air_instances.get_mut(&self.global_idx).unwrap(); + let air_instances = pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&self.global_idx).unwrap(); let airvalues_map = self.stark_info.airvalues_map.as_ref().unwrap(); let mut p = 0; let mut count = 0; @@ -518,8 +493,8 @@ impl Prover for StarkProver { treesGL_get_root_c(p_stark, stage - 1, value.as_mut_ptr() as *mut u8); } } else if stage == (Self::num_stages(self) + 2) as u64 { - let mut air_instances = pctx.air_instance_repo.air_instances.write().unwrap(); - let air_instance = air_instances.get_mut(&self.global_idx).unwrap(); + let air_instances = pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&self.global_idx).unwrap(); calculate_hash_c( p_stark, value.as_mut_ptr() as *mut u8, @@ -541,8 +516,8 @@ impl Prover for StarkProver { let p_proof = self.p_proof; fri_proof_get_tree_root_c(p_proof, value.as_mut_ptr() as *mut u8, step_index as u64); } else { - let mut air_instances = pctx.air_instance_repo.air_instances.write().unwrap(); - let air_instance = air_instances.get_mut(&self.global_idx).unwrap(); + let air_instances = pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&self.global_idx).unwrap(); let n_hash = (1 << (steps[n_steps].n_bits)) * Self::FIELD_EXTENSION as u64; let fri_pol = get_fri_pol_c(self.p_stark_info, air_instance.get_aux_trace_ptr()); calculate_hash_c(p_stark, value.as_mut_ptr() as *mut u8, fri_pol as *mut u8, n_hash); @@ -556,11 +531,13 @@ impl Prover for StarkProver { value64 } - fn get_challenges(&self, stage_id: u32, pctx: Arc>, transcript: &FFITranscript) { + fn get_challenges(&self, stage_id: u32, pctx: Arc>, transcript: &FFITranscript) -> Vec> { if stage_id == 1 { - return; + return Vec::new(); } + let mut challenges_calculated = Vec::new(); + let mpi_rank = pctx.dctx_get_rank(); if stage_id <= self.num_stages() + 3 { @@ -582,6 +559,11 @@ impl Prover for StarkProver { challenges[i * Self::FIELD_EXTENSION + 2], ); } + challenges_calculated.push(vec![ + challenges[i * Self::FIELD_EXTENSION], + challenges[i * Self::FIELD_EXTENSION + 1], + challenges[i * Self::FIELD_EXTENSION + 2], + ]); } } } else { @@ -599,7 +581,13 @@ impl Prover for StarkProver { challenges_guard[challenges_guard.len() - 1], ); } + challenges_calculated.push(vec![ + challenges_guard[challenges_guard.len() - 3], + challenges_guard[challenges_guard.len() - 2], + challenges_guard[challenges_guard.len() - 1], + ]); } + challenges_calculated } fn get_proof(&self) -> *mut c_void { @@ -659,7 +647,7 @@ impl Prover for StarkProver { } impl StarkProver { - fn compute_evals(&mut self, _opening_id: u32, sctx: Arc, pctx: Arc>) { + fn compute_evals(&mut self, _opening_id: u32, sctx: Arc>, pctx: Arc>) { let air_name = &pctx.global_info.airs[self.airgroup_id][self.air_id].name; debug!("{}: ··· Calculating evals of instance {} of {}", Self::MY_NAME, self.air_instance_id, air_name); let mut air_instances = pctx.air_instance_repo.air_instances.write().unwrap(); @@ -689,7 +677,7 @@ impl StarkProver { compute_evals_c(p_stark, (&steps_params).into(), pctx.get_buff_helper_ptr(), p_proof); } - fn compute_fri_pol(&mut self, _opening_id: u32, sctx: Arc, pctx: Arc>) { + fn compute_fri_pol(&mut self, _opening_id: u32, sctx: Arc>, pctx: Arc>) { let air_name = &pctx.global_info.airs[self.airgroup_id][self.air_id].name; debug!( "{}: ··· Calculating FRI polynomial of instance {} of {}", diff --git a/provers/starks-lib-c/bindings_starks.rs b/provers/starks-lib-c/bindings_starks.rs index d0130ef18..8f7513fc2 100644 --- a/provers/starks-lib-c/bindings_starks.rs +++ b/provers/starks-lib-c/bindings_starks.rs @@ -23,8 +23,13 @@ extern "C" { ); } extern "C" { - #[link_name = "\u{1}_Z13fri_proof_newPvm"] - pub fn fri_proof_new(pSetupCtx: *mut ::std::os::raw::c_void, instanceId: u64) -> *mut ::std::os::raw::c_void; + #[link_name = "\u{1}_Z13fri_proof_newPvmmm"] + pub fn fri_proof_new( + pSetupCtx: *mut ::std::os::raw::c_void, + airgroupId: u64, + airId: u64, + instanceId: u64, + ) -> *mut ::std::os::raw::c_void; } extern "C" { #[link_name = "\u{1}_Z23fri_proof_get_tree_rootPvS_m"] @@ -120,12 +125,14 @@ extern "C" { pub fn prover_helpers_free(pProverHelpers: *mut ::std::os::raw::c_void); } extern "C" { - #[link_name = "\u{1}_Z15load_const_treePvPcm"] + #[link_name = "\u{1}_Z15load_const_treePvS_PcmS0_"] pub fn load_const_tree( + pStarkInfo: *mut ::std::os::raw::c_void, pConstTree: *mut ::std::os::raw::c_void, treeFilename: *mut ::std::os::raw::c_char, constTreeSize: u64, - ); + verkeyFilename: *mut ::std::os::raw::c_char, + ) -> bool; } extern "C" { #[link_name = "\u{1}_Z15load_const_polsPvPcm"] @@ -144,11 +151,18 @@ extern "C" { pub fn get_const_size(pStarkInfo: *mut ::std::os::raw::c_void) -> u64; } extern "C" { - #[link_name = "\u{1}_Z20calculate_const_treePvS_S_Pc"] + #[link_name = "\u{1}_Z20calculate_const_treePvS_S_"] pub fn calculate_const_tree( pStarkInfo: *mut ::std::os::raw::c_void, pConstPolsAddress: *mut ::std::os::raw::c_void, pConstTree: *mut ::std::os::raw::c_void, + ); +} +extern "C" { + #[link_name = "\u{1}_Z16write_const_treePvS_Pc"] + pub fn write_const_tree( + pStarkInfo: *mut ::std::os::raw::c_void, + pConstTreeAddress: *mut ::std::os::raw::c_void, treeFilename: *mut ::std::os::raw::c_char, ); } @@ -280,10 +294,6 @@ extern "C" { #[link_name = "\u{1}_Z16treesGL_get_rootPvmS_"] pub fn treesGL_get_root(pStarks: *mut ::std::os::raw::c_void, index: u64, root: *mut ::std::os::raw::c_void); } -extern "C" { - #[link_name = "\u{1}_Z16treesGL_set_rootPvmS_"] - pub fn treesGL_set_root(pStarks: *mut ::std::os::raw::c_void, index: u64, pProof: *mut ::std::os::raw::c_void); -} extern "C" { #[link_name = "\u{1}_Z18calculate_xdivxsubPvS_S_"] pub fn calculate_xdivxsub( @@ -556,11 +566,13 @@ extern "C" { pub fn print_row(pSetupCtx: *mut ::std::os::raw::c_void, buffer: *mut ::std::os::raw::c_void, stage: u64, row: u64); } extern "C" { - #[link_name = "\u{1}_Z19gen_recursive_proofPvPcmS_S_S_S_S_S0_b"] + #[link_name = "\u{1}_Z19gen_recursive_proofPvPcmmmS_S_S_S_S_S0_b"] pub fn gen_recursive_proof( pSetupCtx: *mut ::std::os::raw::c_void, globalInfoFile: *mut ::std::os::raw::c_char, airgroupId: u64, + airId: u64, + instanceId: u64, witness: *mut ::std::os::raw::c_void, aux_trace: *mut ::std::os::raw::c_void, pConstPols: *mut ::std::os::raw::c_void, @@ -689,3 +701,23 @@ extern "C" { #[link_name = "\u{1}_Z11free_bufferPv"] pub fn free_buffer(buffer: *mut ::std::os::raw::c_void); } + +extern "C" { + #[link_name = "\u{1}_Z20write_fixed_cols_binPcS_S_mmPv"] + pub fn write_fixed_cols_bin( + binFile: *mut ::std::os::raw::c_char, + airgroupName: *mut ::std::os::raw::c_char, + airName: *mut ::std::os::raw::c_char, + N: u64, + nFixedPols: u64, + fixedPolsInfo: *mut ::std::os::raw::c_void, + ); +} +extern "C" { + #[link_name = "\u{1}_Z19get_omp_max_threadsv"] + pub fn get_omp_max_threads() -> u64; +} +extern "C" { + #[link_name = "\u{1}_Z19set_omp_num_threadsm"] + pub fn set_omp_num_threads(num_threads: u64); +} \ No newline at end of file diff --git a/provers/starks-lib-c/src/ffi_starks.rs b/provers/starks-lib-c/src/ffi_starks.rs index dde9f6fdf..2f833275b 100644 --- a/provers/starks-lib-c/src/ffi_starks.rs +++ b/provers/starks-lib-c/src/ffi_starks.rs @@ -55,8 +55,8 @@ pub fn save_proof_values_c(proof_values: *mut u8, global_info_file: &str, output } #[cfg(not(feature = "no_lib_link"))] -pub fn fri_proof_new_c(p_setup_ctx: *mut c_void, instance_id: u64) -> *mut c_void { - unsafe { fri_proof_new(p_setup_ctx, instance_id) } +pub fn fri_proof_new_c(p_setup_ctx: *mut c_void, airgroup_id: u64, air_id: u64, instance_id: u64) -> *mut c_void { + unsafe { fri_proof_new(p_setup_ctx, airgroup_id, air_id, instance_id) } } #[cfg(not(feature = "no_lib_link"))] @@ -207,31 +207,45 @@ pub fn get_const_tree_size_c(pStarkInfo: *mut c_void) -> u64 { } #[cfg(not(feature = "no_lib_link"))] -pub fn load_const_tree_c(pConstPolsTreeAddress: *mut u8, tree_filename: &str, const_tree_size: u64) { +pub fn load_const_tree_c( + pStarkInfo: *mut c_void, + pConstPolsTreeAddress: *mut u8, + tree_filename: &str, + const_tree_size: u64, + verkey_filename: &str, +) -> bool { unsafe { let tree_filename: CString = CString::new(tree_filename).unwrap(); + let verkey_filename: CString = CString::new(verkey_filename).unwrap(); load_const_tree( + pStarkInfo, pConstPolsTreeAddress as *mut std::os::raw::c_void, tree_filename.as_ptr() as *mut std::os::raw::c_char, const_tree_size, + verkey_filename.as_ptr() as *mut std::os::raw::c_char, + ) + } +} + +#[cfg(not(feature = "no_lib_link"))] +pub fn calculate_const_tree_c(pStarkInfo: *mut c_void, pConstPols: *mut u8, pConstPolsTreeAddress: *mut u8) { + unsafe { + calculate_const_tree( + pStarkInfo, + pConstPols as *mut std::os::raw::c_void, + pConstPolsTreeAddress as *mut std::os::raw::c_void, ); } } #[cfg(not(feature = "no_lib_link"))] -pub fn calculate_const_tree_c( - pStarkInfo: *mut c_void, - pConstPols: *mut u8, - pConstPolsTreeAddress: *mut u8, - tree_filename: &str, -) { +pub fn write_const_tree_c(pStarkInfo: *mut c_void, pConstPolsTreeAddress: *mut u8, tree_filename: &str) { unsafe { let tree_filename: CString = CString::new(tree_filename).unwrap(); - calculate_const_tree( + write_const_tree( pStarkInfo, - pConstPols as *mut std::os::raw::c_void, pConstPolsTreeAddress as *mut std::os::raw::c_void, tree_filename.as_ptr() as *mut std::os::raw::c_char, ); @@ -494,13 +508,6 @@ pub fn treesGL_get_root_c(pStark: *mut c_void, index: u64, root: *mut u8) { } } -#[cfg(not(feature = "no_lib_link"))] -pub fn treesGL_set_root_c(pStark: *mut c_void, index: u64, pProof: *mut c_void) { - unsafe { - treesGL_set_root(pStark, index, pProof); - } -} - #[cfg(not(feature = "no_lib_link"))] pub fn calculate_xdivxsub_c(p_stark: *mut c_void, xi_challenge: *mut c_void, xdivxsub: *mut u8) { unsafe { @@ -931,6 +938,8 @@ pub fn gen_recursive_proof_c( proof_file: &str, global_info_file: &str, airgroup_id: u64, + air_id: u64, + instance_id: u64, vadcop: bool, ) -> *mut c_void { let proof_file_name = CString::new(proof_file).unwrap(); @@ -944,6 +953,8 @@ pub fn gen_recursive_proof_c( p_setup_ctx, global_info_file_ptr, airgroup_id, + air_id, + instance_id, p_witness as *mut std::os::raw::c_void, p_aux_trace as *mut std::os::raw::c_void, p_const_pols as *mut std::os::raw::c_void, @@ -1159,6 +1170,39 @@ pub fn free_buffer_c(buffer: *mut u8) { } } +#[cfg(not(feature = "no_lib_link"))] +pub fn write_fixed_cols_bin_c( + binfile: &str, + airgroup: &str, + air: &str, + n: u64, + n_fixed_pols: u64, + fixed_pols_info: *mut c_void, +) { + let binfile_name = CString::new(binfile).unwrap(); + let binfile_name_ptr = binfile_name.as_ptr() as *mut std::os::raw::c_char; + + let airgroup_name = CString::new(airgroup).unwrap(); + let airgroup_name_ptr = airgroup_name.as_ptr() as *mut std::os::raw::c_char; + + let air_name = CString::new(air).unwrap(); + let air_name_ptr = air_name.as_ptr() as *mut std::os::raw::c_char; + unsafe { + write_fixed_cols_bin(binfile_name_ptr, airgroup_name_ptr, air_name_ptr, n, n_fixed_pols, fixed_pols_info); + } +} + +#[cfg(not(feature = "no_lib_link"))] +pub fn get_omp_max_threads_c() -> u64 { + unsafe { get_omp_max_threads() } +} + +#[cfg(not(feature = "no_lib_link"))] +pub fn set_omp_num_threads_c(num_threads: u64) { + unsafe { + set_omp_num_threads(num_threads); + } +} // ------------------------ // MOCK METHODS FOR TESTING // ------------------------ @@ -1178,7 +1222,7 @@ pub fn save_proof_values_c(_proof_values: *mut u8, _global_info_file: &str, _out } #[cfg(feature = "no_lib_link")] -pub fn fri_proof_new_c(_p_setup_ctx: *mut c_void, _instance_id: u64) -> *mut c_void { +pub fn fri_proof_new_c(_p_setup_ctx: *mut c_void, _airgroup_id: u64, _air_id: u64, _instance_id: u64) -> *mut c_void { trace!("{}: ··· {}", "ffi ", "fri_proof_new: This is a mock call because there is no linked library"); std::ptr::null_mut() } @@ -1309,20 +1353,27 @@ pub fn get_const_size_c(_pStarkInfo: *mut c_void) -> u64 { } #[cfg(feature = "no_lib_link")] -pub fn load_const_tree_c(_pConstPolsTreeAddress: *mut u8, _tree_filename: &str, _const_tree_size: u64) { +pub fn load_const_tree_c( + _pStarkInfo: *mut c_void, + _pConstPolsTreeAddress: *mut u8, + _tree_filename: &str, + _const_tree_size: u64, + _verkey_path: &str, +) -> bool { trace!("{}: ··· {}", "ffi ", "load_const_tree: This is a mock call because there is no linked library"); + true } #[cfg(feature = "no_lib_link")] -pub fn calculate_const_tree_c( - _pStarkInfo: *mut c_void, - _pConstPols: *mut u8, - _pConstPolsTreeAddress: *mut u8, - _tree_filename: &str, -) { +pub fn calculate_const_tree_c(_pStarkInfo: *mut c_void, _pConstPols: *mut u8, _pConstPolsTreeAddress: *mut u8) { trace!("{}: ··· {}", "ffi ", "calculate_const_tree: This is a mock call because there is no linked library"); } +#[cfg(feature = "no_lib_link")] +pub fn write_const_tree_c(_pStarkInfo: *mut c_void, _pConstPolsTreeAddress: *mut u8, _tree_filename: &str) { + trace!("{}: ··· {}", "ffi ", "write_const_tree: This is a mock call because there is no linked library"); +} + #[cfg(feature = "no_lib_link")] pub fn expressions_bin_new_c(_filename: &str, _global: bool, _verify: bool) -> *mut c_void { std::ptr::null_mut() @@ -1468,11 +1519,6 @@ pub fn treesGL_get_root_c(_pStark: *mut c_void, _index: u64, _root: *mut u8) { trace!("{}: ··· {}", "ffi ", "treesGL_get_root: This is a mock call because there is no linked library"); } -#[cfg(feature = "no_lib_link")] -pub fn treesGL_set_root_c(_pStark: *mut c_void, _index: u64, _pProof: *mut c_void) { - trace!("{}: ··· {}", "ffi ", "treesGL_set_root: This is a mock call because there is no linked library"); -} - #[cfg(feature = "no_lib_link")] pub fn calculate_fri_polynomial_c(_p_starks: *mut c_void, _p_steps_params: *mut u8) { trace!("mckzkevm: ··· {}", "calculate_fri_polynomial: This is a mock call because there is no linked library"); @@ -1808,6 +1854,8 @@ pub fn gen_recursive_proof_c( _proof_file: &str, _global_info_file: &str, _airgroup_id: u64, + _air_id: u64, + _instance_id: u64, _vadcop: bool, ) -> *mut c_void { trace!("{}: ··· {}", "ffi ", "gen_recursive_proof: This is a mock call because there is no linked library"); @@ -1919,3 +1967,26 @@ pub fn stark_verify_c( trace!("{}: ··· {}", "ffi ", "stark_verify_c: This is a mock call because there is no linked library"); true } + +#[cfg(feature = "no_lib_link")] +pub fn write_fixed_cols_bin_c( + _binfile: &str, + _airgroup: &str, + _air: &str, + _n: u64, + _n_fixed_pols: u64, + _fixed_pols_info: *mut c_void, +) { + trace!("{}: ··· {}", "ffi ", "write_fixed_cols_bi: This is a mock call because there is no linked library"); +} + +#[cfg(feature = "no_lib_link")] +pub fn get_omp_max_threads() -> u64 { + trace!("{}: ··· {}", "ffi ", "get_omp_max_threads: This is a mock call because there is no linked library"); + 1 +} + +#[cfg(feature = "no_lib_link")] +pub fn set_omp_num_threads(_num_threads: u64) { + trace!("{}: ··· {}", "ffi ", "set_omp_num_threads: This is a mock call because there is no linked library"); +} diff --git a/util/src/lib.rs b/util/src/lib.rs index ccc1baa00..157004fcc 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -1,10 +1,9 @@ pub mod cli; pub mod timer_macro; -use p3_field::Field; use std::mem::MaybeUninit; -pub fn create_buffer_fast(buffer_size: usize) -> Vec { +pub fn create_buffer_fast(buffer_size: usize) -> Vec { let mut buffer: Vec> = Vec::with_capacity(buffer_size); unsafe { buffer.set_len(buffer_size); diff --git a/witness/src/witness_component.rs b/witness/src/witness_component.rs index d587fb034..ed17b6fd2 100644 --- a/witness/src/witness_component.rs +++ b/witness/src/witness_component.rs @@ -2,14 +2,12 @@ use std::sync::Arc; use proofman_common::{ProofCtx, SetupCtx}; -pub trait WitnessComponent: Send + Sync { - fn start_proof(&self, _pctx: Arc>, _sctx: Arc) {} +pub trait WitnessComponent: Send + Sync { + fn start_proof(&self, _pctx: Arc>, _sctx: Arc>) {} fn execute(&self, _pctx: Arc>) {} - fn debug(&self, _pctx: Arc>) {} + fn debug(&self, _pctx: Arc>, _sctx: Arc>) {} - fn calculate_witness(&self, _stage: u32, _pctx: Arc>, _sctx: Arc) {} - - fn end_proof(&self) {} + fn calculate_witness(&self, _stage: u32, _pctx: Arc>, _sctx: Arc>) {} } diff --git a/witness/src/witness_library.rs b/witness/src/witness_library.rs index 73f114a5a..996869da9 100644 --- a/witness/src/witness_library.rs +++ b/witness/src/witness_library.rs @@ -6,7 +6,7 @@ use proofman_common::VerboseMode; /// This is the type of the function that is used to load a witness library. pub type WitnessLibInitFn = fn(VerboseMode) -> Result>, Box>; -pub trait WitnessLibrary { +pub trait WitnessLibrary { fn register_witness(&mut self, wcm: Arc>); } diff --git a/witness/src/witness_manager.rs b/witness/src/witness_manager.rs index 4713baf9c..624dbfb47 100644 --- a/witness/src/witness_manager.rs +++ b/witness/src/witness_manager.rs @@ -1,39 +1,57 @@ use std::sync::{Arc, RwLock}; use std::path::PathBuf; -use proofman_common::{ProofCtx, SetupCtx}; +use proofman_common::{ModeName, ProofCtx, SetupCtx}; use proofman_util::{timer_start_info, timer_stop_and_log_info}; use crate::WitnessComponent; -pub struct WitnessManager { +pub struct WitnessManager { components: RwLock>>>, + components_std: RwLock>>>, pctx: Arc>, - sctx: Arc, + sctx: Arc>, rom_path: Option, public_inputs_path: Option, + input_data_path: Option, } -impl WitnessManager { +impl WitnessManager { const MY_NAME: &'static str = "WCMnager"; pub fn new( pctx: Arc>, - sctx: Arc, + sctx: Arc>, rom_path: Option, public_inputs_path: Option, + input_data_path: Option, ) -> Self { - WitnessManager { components: RwLock::new(Vec::new()), pctx, sctx, rom_path, public_inputs_path } + WitnessManager { + components: RwLock::new(Vec::new()), + components_std: RwLock::new(Vec::new()), + pctx, + sctx, + rom_path, + public_inputs_path, + input_data_path, + } } pub fn register_component(&self, component: Arc>) { self.components.write().unwrap().push(component); } + pub fn register_component_std(&self, component: Arc>) { + self.components_std.write().unwrap().push(component); + } + pub fn start_proof(&self) { timer_start_info!(START_PROOF); for component in self.components.read().unwrap().iter() { component.start_proof(self.pctx.clone(), self.sctx.clone()); } + for component in self.components_std.read().unwrap().iter() { + component.start_proof(self.pctx.clone(), self.sctx.clone()); + } timer_stop_and_log_info!(START_PROOF); } @@ -42,18 +60,24 @@ impl WitnessManager { for component in self.components.read().unwrap().iter() { component.execute(self.pctx.clone()); } + for component in self.components_std.read().unwrap().iter() { + component.execute(self.pctx.clone()); + } timer_stop_and_log_info!(EXECUTE); } pub fn debug(&self) { - for component in self.components.read().unwrap().iter() { - component.debug(self.pctx.clone()); + if self.pctx.options.debug_info.std_mode.name == ModeName::Debug + || !self.pctx.options.debug_info.debug_instances.is_empty() + { + for component in self.components.read().unwrap().iter() { + component.debug(self.pctx.clone(), self.sctx.clone()); + } } - } - - pub fn end_proof(&self) { - for component in self.components.read().unwrap().iter() { - component.end_proof(); + if self.pctx.options.debug_info.std_mode.name == ModeName::Debug { + for component in self.components_std.read().unwrap().iter() { + component.debug(self.pctx.clone(), self.sctx.clone()); + } } } @@ -67,11 +91,14 @@ impl WitnessManager { timer_start_info!(CALCULATING_WITNESS); - // Call one time all unused components for component in self.components.read().unwrap().iter() { component.calculate_witness(stage, self.pctx.clone(), self.sctx.clone()); } + for component in self.components_std.read().unwrap().iter() { + component.calculate_witness(stage, self.pctx.clone(), self.sctx.clone()); + } + timer_stop_and_log_info!(CALCULATING_WITNESS); } @@ -79,7 +106,7 @@ impl WitnessManager { self.pctx.clone() } - pub fn get_sctx(&self) -> Arc { + pub fn get_sctx(&self) -> Arc> { self.sctx.clone() } @@ -90,4 +117,8 @@ impl WitnessManager { pub fn get_public_inputs_path(&self) -> Option { self.public_inputs_path.clone() } + + pub fn get_input_data_path(&self) -> Option { + self.input_data_path.clone() + } }