diff --git a/Cargo.lock b/Cargo.lock index aec6ed19..184febe2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -723,6 +723,7 @@ dependencies = [ name = "pil-std-lib" version = "0.1.0" dependencies = [ + "colored", "log", "num-bigint", "num-traits", diff --git a/common/src/std_mode.rs b/common/src/std_mode.rs index e4858083..ca631607 100644 --- a/common/src/std_mode.rs +++ b/common/src/std_mode.rs @@ -10,27 +10,28 @@ pub struct StdMode { pub opids: Vec, pub n_vals: usize, pub print_to_file: bool, + pub fast_mode: bool, } impl StdMode { - pub const fn new(name: ModeName, opids: Vec, n_vals: usize, print_to_file: bool) -> Self { + pub const fn new(name: ModeName, opids: Vec, n_vals: usize, print_to_file: bool, fast_mode: bool) -> Self { if name.as_usize() != ModeName::Standard.as_usize() && n_vals == 0 { panic!("n_vals must be greater than 0"); } - Self { name, opids, n_vals, print_to_file } + Self { name, opids, n_vals, print_to_file, fast_mode } } pub fn new_debug() -> Self { - Self::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false) + Self::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false, true) } } impl From for StdMode { fn from(v: u8) -> Self { match v { - 0 => StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false), - 1 => StdMode::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false), + 0 => StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false, false), + 1 => StdMode::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false, true), _ => panic!("Invalid mode"), } } @@ -38,7 +39,7 @@ impl From for StdMode { impl Default for StdMode { fn default() -> Self { - StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false) + StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false, false) } } diff --git a/common/src/utils.rs b/common/src/utils.rs index 569331c4..8e58c29f 100644 --- a/common/src/utils.rs +++ b/common/src/utils.rs @@ -70,6 +70,9 @@ pub fn skip_prover_instance( (true, Vec::new()) } +fn default_fast_mode() -> bool { + true +} #[derive(Debug, Default, Deserialize)] struct StdDebugMode { #[serde(default)] @@ -78,6 +81,8 @@ struct StdDebugMode { n_print: Option, #[serde(default)] print_to_file: bool, + #[serde(default = "default_fast_mode")] + fast_mode: bool, } #[derive(Debug, Deserialize)] @@ -200,14 +205,18 @@ pub fn json_to_debug_instances_map(proving_key_path: PathBuf, json_path: String) let global_constraints = json.global_constraints.unwrap_or_default(); let std_mode = if !airgroup_map.is_empty() { - StdMode::new(ModeName::Standard, Vec::new(), 0, false) + StdMode::new(ModeName::Standard, Vec::new(), 0, false, false) } else { let mode = json.std_mode.unwrap_or_default(); + let fast_mode = + if mode.opids.is_some() && !mode.opids.as_ref().unwrap().is_empty() { false } else { mode.fast_mode }; + StdMode::new( ModeName::Debug, mode.opids.unwrap_or_default(), mode.n_print.unwrap_or(DEFAULT_PRINT_VALS), mode.print_to_file, + fast_mode, ) }; diff --git a/examples/fibonacci-square/src/fibonacci.rs b/examples/fibonacci-square/src/fibonacci.rs index 2563bdec..30c68e15 100644 --- a/examples/fibonacci-square/src/fibonacci.rs +++ b/examples/fibonacci-square/src/fibonacci.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use proofman_common::{add_air_instance, AirInstance, FromTrace, ProofCtx}; +use proofman_common::{add_air_instance, AirInstance, FromTrace, ProofCtx, SetupCtx}; use witness::WitnessComponent; use p3_field::PrimeField64; @@ -70,7 +70,7 @@ impl WitnessComponent for FibonacciSquare { add_air_instance::(air_instance, pctx.clone()); } - fn debug(&self, _pctx: Arc>) { + fn debug(&self, _pctx: Arc>, _sctx: Arc) { // let trace = FibonacciSquareTrace::from_vec(pctx.get_air_instance_trace(0, 0, 0)); // let air_values = FibonacciSquareAirValues::from_vec(pctx.get_air_instance_air_values(0, 0, 0)); // let airgroup_values = FibonacciSquareAirGroupValues::from_vec(pctx.get_air_instance_airgroup_values(0, 0, 0)); diff --git a/hints/src/global_hints.rs b/hints/src/global_hints.rs index 38c4978f..663cdc9d 100644 --- a/hints/src/global_hints.rs +++ b/hints/src/global_hints.rs @@ -24,7 +24,7 @@ pub fn aggregate_airgroupvals(pctx: Arc>) -> Vec> airgroupvalues.push(values); } - for (_, air_instance) in pctx.air_instance_repo.air_instances.write().unwrap().iter() { + for (_, air_instance) in pctx.air_instance_repo.air_instances.read().unwrap().iter() { for (idx, agg_type) in pctx.global_info.agg_types[air_instance.airgroup_id].iter().enumerate() { let mut acc = ExtensionField { value: [ diff --git a/hints/src/hints.rs b/hints/src/hints.rs index 396ee644..e2ded17c 100644 --- a/hints/src/hints.rs +++ b/hints/src/hints.rs @@ -936,7 +936,7 @@ fn get_hint_f( pctx: Option<&ProofCtx>, airgroup_id: usize, air_id: usize, - air_instance: Option<&mut AirInstance>, + air_instance: Option<&AirInstance>, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, @@ -1001,7 +1001,7 @@ fn get_hint_f( pub fn get_hint_field( sctx: &SetupCtx, pctx: &ProofCtx, - air_instance: &mut AirInstance, + air_instance: &AirInstance, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, @@ -1031,7 +1031,7 @@ pub fn get_hint_field( pub fn get_hint_field_a( sctx: &SetupCtx, pctx: &ProofCtx, - air_instance: &mut AirInstance, + air_instance: &AirInstance, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, @@ -1065,7 +1065,7 @@ pub fn get_hint_field_a( pub fn get_hint_field_m( sctx: &SetupCtx, pctx: &ProofCtx, - air_instance: &mut AirInstance, + air_instance: &AirInstance, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, diff --git a/pil2-components/lib/std/rs/Cargo.toml b/pil2-components/lib/std/rs/Cargo.toml index aa58e2cd..07a7a824 100644 --- a/pil2-components/lib/std/rs/Cargo.toml +++ b/pil2-components/lib/std/rs/Cargo.toml @@ -14,3 +14,4 @@ p3-goldilocks.workspace = true p3-field.workspace = true rayon.workspace = true witness.workspace = true +colored.workspace = true diff --git a/pil2-components/lib/std/rs/src/common.rs b/pil2-components/lib/std/rs/src/common.rs index 7d7884fc..a112e8c8 100644 --- a/pil2-components/lib/std/rs/src/common.rs +++ b/pil2-components/lib/std/rs/src/common.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use p3_field::PrimeField; use num_traits::ToPrimitive; -use proofman_common::{AirInstance, ProofCtx, SetupCtx}; +use proofman_common::{ProofCtx, SetupCtx}; use proofman_hints::{ get_hint_field_constant_gc, get_hint_field_constant, get_hint_field_constant_a, HintFieldOptions, HintFieldOutput, HintFieldValue, @@ -13,17 +13,6 @@ pub trait AirComponent { fn new(pctx: Arc>, sctx: Arc, airgroup_id: Option, air_id: Option) -> Arc; - - fn debug_mode( - &self, - _pctx: &ProofCtx, - _sctx: &SetupCtx, - _air_instance: &mut AirInstance, - _air_instance_id: usize, - _num_rows: usize, - _debug_data_hints: Vec, - ) { - } } // Helper to extract hint fields diff --git a/pil2-components/lib/std/rs/src/debug.rs b/pil2-components/lib/std/rs/src/debug.rs index ad13d611..9a3aff69 100644 --- a/pil2-components/lib/std/rs/src/debug.rs +++ b/pil2-components/lib/std/rs/src/debug.rs @@ -1,16 +1,23 @@ use std::{ collections::HashMap, fs::{self, File}, + hash::{DefaultHasher, Hasher, Hash}, io::{self, Write}, path::{Path, PathBuf}, - sync::Mutex, }; use p3_field::PrimeField; use proofman_common::ProofCtx; use proofman_hints::{format_vec, HintFieldOutput}; -pub type DebugData = Mutex>, BusValue>>>; // opid -> val -> BusValue +use num_bigint::BigUint; +use num_traits::Zero; + +use colored::*; + +pub type DebugData = HashMap>, BusValue>>; // opid -> val -> BusValue + +pub type DebugDataFast = HashMap; // opid -> sharedDataFast #[derive(Debug)] pub struct BusValue { @@ -25,6 +32,14 @@ struct SharedData { num_assumes: F, } +#[derive(Clone, Debug)] +pub struct SharedDataFast { + pub num_proves: BigUint, + pub num_assumes: BigUint, + pub num_proves_global: Vec, + pub num_assumes_global: Vec, +} + type AirGroupMap = HashMap; type AirMap = HashMap; @@ -43,9 +58,63 @@ struct InstanceData { row_assumes: Vec, } +#[allow(clippy::too_many_arguments)] +pub fn update_debug_data_fast( + debug_data_fast: &mut DebugDataFast, + opid: F, + val: Vec>, + proves: bool, + times: F, + is_global: bool, +) { + let bus_opid_times = debug_data_fast.entry(opid).or_insert_with(|| SharedDataFast { + num_assumes_global: Vec::new(), + num_proves_global: Vec::new(), + num_proves: BigUint::zero(), + num_assumes: BigUint::zero(), + }); + + let mut values = Vec::new(); + for value in val.iter() { + match value { + HintFieldOutput::Field(f) => values.push(*f), + HintFieldOutput::FieldExtended(ef) => { + values.push(ef.value[0]); + values.push(ef.value[1]); + values.push(ef.value[2]); + } + } + } + + let mut hasher = DefaultHasher::new(); + values.hash(&mut hasher); + + let hash_value = BigUint::from(hasher.finish()); + + if is_global { + if proves { + // Check if bus op id times num proves global contains value + if bus_opid_times.num_proves_global.contains(&hash_value) { + return; + } + bus_opid_times.num_proves_global.push(hash_value * times.as_canonical_biguint()); + } else { + if bus_opid_times.num_assumes_global.contains(&hash_value) { + return; + } + bus_opid_times.num_assumes_global.push(hash_value); + } + } else if proves { + bus_opid_times.num_proves += hash_value * times.as_canonical_biguint(); + } else { + assert!(times.is_one(), "The selector value is invalid: expected 1, but received {:?}.", times); + bus_opid_times.num_assumes += hash_value; + } +} + #[allow(clippy::too_many_arguments)] pub fn update_debug_data( - debug_data: &DebugData, + debug_data: &mut DebugData, name_piop: &str, name_expr: &[String], opid: F, @@ -58,9 +127,7 @@ pub fn update_debug_data( times: F, is_global: bool, ) { - let mut bus = debug_data.lock().expect("Bus values missing"); - - let bus_opid = bus.entry(opid).or_default(); + let bus_opid = debug_data.entry(opid).or_default(); let bus_val = bus_opid.entry(val).or_insert_with(|| BusValue { shared_data: SharedData { direct_was_called: false, num_proves: F::zero(), num_assumes: F::zero() }, @@ -99,18 +166,77 @@ pub fn update_debug_data( } } +pub fn check_invalid_opids( + _pctx: &ProofCtx, + name: &str, + debugs_data_fasts: &mut [DebugDataFast], +) -> Vec { + let mut debug_data_fast = HashMap::new(); + + let mut global_assumes = Vec::new(); + let mut global_proves = Vec::new(); + for map in debugs_data_fasts { + for (opid, bus) in map.iter() { + if debug_data_fast.contains_key(opid) { + let bus_fast: &mut SharedDataFast = debug_data_fast.get_mut(opid).unwrap(); + for assume_global in bus.num_assumes_global.iter() { + if global_assumes.contains(assume_global) { + continue; + } + global_assumes.push(assume_global.clone()); + bus_fast.num_assumes += assume_global; + } + for prove_global in bus.num_proves_global.iter() { + if global_proves.contains(prove_global) { + continue; + } + global_proves.push(prove_global.clone()); + bus_fast.num_proves += prove_global; + } + + bus_fast.num_proves += bus.num_proves.clone(); + bus_fast.num_assumes += bus.num_assumes.clone(); + } else { + debug_data_fast.insert(*opid, bus.clone()); + } + } + } + + // TODO: SINCRONIZATION IN DISTRIBUTED MODE + + let mut invalid_opids = Vec::new(); + + // Check if there are any invalid opids + + for (opid, bus) in debug_data_fast.iter_mut() { + if bus.num_proves != bus.num_assumes { + invalid_opids.push(*opid); + } + } + + if !invalid_opids.is_empty() { + log::error!( + "{}: ··· {}", + name, + format!("\u{2717} The following opids does not match {:?}", invalid_opids).bright_red().bold() + ); + } else { + log::info!("{}: ··· {}", name, "\u{2713} All bus values match.".bright_green().bold()); + } + + invalid_opids +} pub fn print_debug_info( pctx: &ProofCtx, name: &str, max_values_to_print: usize, print_to_file: bool, - debug_data: &DebugData, + debug_data: &mut DebugData, ) { let mut file_path = PathBuf::new(); let mut output: Box = Box::new(io::stdout()); let mut there_are_errors = false; - let mut bus_vals = debug_data.lock().expect("Bus values missing"); - for (opid, bus) in bus_vals.iter_mut() { + for (opid, bus) in debug_data.iter_mut() { if bus.iter().any(|(_, v)| v.shared_data.num_proves != v.shared_data.num_assumes) { if !there_are_errors { // Print to a file if requested @@ -204,6 +330,10 @@ pub fn print_debug_info( } } + if !there_are_errors { + log::info!("{}: ··· {}", name, "\u{2713} All bus values match.".bright_green().bold()); + } + fn print_diffs( pctx: &ProofCtx, val: &[HintFieldOutput], diff --git a/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs b/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs index f16979a4..1ea4829a 100644 --- a/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs +++ b/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs @@ -273,13 +273,13 @@ impl WitnessComponent for SpecifiedRanges { let buffer = create_buffer_fast(buffer_size as usize); // Add a new air instance. Since Specified Ranges is a table, only this air instance is needed - let mut air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); + let air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); let mut mul_columns_guard = self.mul_columns.lock().unwrap(); for hint in hints_guard[1..].iter() { mul_columns_guard.push(get_hint_field::( &sctx, &pctx, - &mut air_instance, + &air_instance, hint.to_usize().unwrap(), "reference", HintFieldOptions::dest_with_zeros(), diff --git a/pil2-components/lib/std/rs/src/range_check/u16air.rs b/pil2-components/lib/std/rs/src/range_check/u16air.rs index 8ab33689..4c6abf2f 100644 --- a/pil2-components/lib/std/rs/src/range_check/u16air.rs +++ b/pil2-components/lib/std/rs/src/range_check/u16air.rs @@ -143,12 +143,12 @@ impl WitnessComponent for U16Air { let buffer = create_buffer_fast(buffer_size); // Add a new air instance. Since U16Air is a table, only this air instance is needed - let mut air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); + let air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); *self.mul_column.lock().unwrap() = get_hint_field::( &sctx, &pctx, - &mut air_instance, + &air_instance, u16air_hints[0] as usize, "reference", HintFieldOptions::dest_with_zeros(), diff --git a/pil2-components/lib/std/rs/src/range_check/u8air.rs b/pil2-components/lib/std/rs/src/range_check/u8air.rs index 8d5a7089..b29145f3 100644 --- a/pil2-components/lib/std/rs/src/range_check/u8air.rs +++ b/pil2-components/lib/std/rs/src/range_check/u8air.rs @@ -142,12 +142,12 @@ impl WitnessComponent for U8Air { let buffer = create_buffer_fast(buffer_size); // Add a new air instance. Since U8Air is a table, only this air instance is needed - let mut air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); + let air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer)); *self.mul_column.lock().unwrap() = get_hint_field::( &sctx, &pctx, - &mut air_instance, + &air_instance, u8air_hints[0] as usize, "reference", HintFieldOptions::dest_with_zeros(), diff --git a/pil2-components/lib/std/rs/src/std_prod.rs b/pil2-components/lib/std/rs/src/std_prod.rs index 88c5d19b..cab4cb64 100644 --- a/pil2-components/lib/std/rs/src/std_prod.rs +++ b/pil2-components/lib/std/rs/src/std_prod.rs @@ -3,33 +3,36 @@ use std::{ sync::{Arc, Mutex}, }; +use rayon::prelude::*; + use num_traits::ToPrimitive; use p3_field::PrimeField; +use proofman_util::{timer_start_info, timer_stop_and_log_info}; use witness::WitnessComponent; -use proofman_common::{AirInstance, ModeName, ProofCtx, SetupCtx}; +use proofman_common::{AirInstance, ProofCtx, SetupCtx}; use proofman_hints::{ get_hint_field_gc_constant_a, get_hint_field, get_hint_field_a, acc_mul_hint_fields, update_airgroupvalue, get_hint_ids_by_name, HintFieldOptions, HintFieldValue, HintFieldValuesVec, }; use crate::{ - extract_field_element_as_usize, get_global_hint_field_constant_as, get_hint_field_constant_a_as_string, - get_hint_field_constant_as_field, get_hint_field_constant_as_string, get_row_field_value, print_debug_info, - update_debug_data, AirComponent, DebugData, + check_invalid_opids, extract_field_element_as_usize, get_global_hint_field_constant_as, + get_hint_field_constant_a_as_string, get_hint_field_constant_as_field, get_hint_field_constant_as_string, + get_row_field_value, print_debug_info, update_debug_data, update_debug_data_fast, AirComponent, DebugData, + DebugDataFast, SharedDataFast, }; pub struct StdProd { - pctx: Arc>, stage_wc: Option>, - debug_data: Option>, + _phantom: std::marker::PhantomData, } impl AirComponent for StdProd { const MY_NAME: &'static str = "STD Prod"; fn new( - pctx: Arc>, + _pctx: Arc>, sctx: Arc, _airgroup_id: Option, _air_id: Option, @@ -39,7 +42,6 @@ impl AirComponent for StdProd { // Initialize std_prod with the extracted data Arc::new(Self { - pctx: pctx.clone(), stage_wc: match std_prod_users_id.is_empty() { true => None, false => { @@ -49,24 +51,26 @@ impl AirComponent for StdProd { Some(Mutex::new(stage_wc)) } }, - debug_data: if pctx.options.debug_info.std_mode.name == ModeName::Debug { - Some(Mutex::new(HashMap::new())) - } else { - None - }, + _phantom: std::marker::PhantomData, }) } +} +impl StdProd { + const MY_NAME: &'static str = "STD Prod"; + #[allow(clippy::too_many_arguments)] fn debug_mode( &self, pctx: &ProofCtx, sctx: &SetupCtx, - air_instance: &mut AirInstance, + air_instance: &AirInstance, air_instance_id: usize, num_rows: usize, debug_data_hints: Vec, + debug_data: &mut DebugData, + debug_data_fast: &mut DebugDataFast, + fast_mode: bool, ) { - let debug_data = self.debug_data.as_ref().expect("Debug data missing"); let airgroup_id = air_instance.airgroup_id; let air_id = air_instance.air_id; @@ -168,7 +172,9 @@ impl AirComponent for StdProd { &expressions, 0, debug_data, + debug_data_fast, is_global.is_one(), + fast_mode, ); } // Otherwise, update the bus for each row @@ -186,7 +192,9 @@ impl AirComponent for StdProd { &expressions, j, debug_data, + debug_data_fast, false, + fast_mode, ); } } @@ -203,8 +211,10 @@ impl AirComponent for StdProd { sel: &HintFieldValue, expressions: &HintFieldValuesVec, row: usize, - debug_data: &DebugData, + debug_data: &mut DebugData, + debug_data_fast: &mut DebugDataFast, is_global: bool, + fast_mode: bool, ) { let mut sel = get_row_field_value(sel, row, "sel"); if sel.is_zero() { @@ -223,20 +233,24 @@ impl AirComponent for StdProd { _ => panic!("Proves hint must be either 0, 1, or -1"), }; - update_debug_data( - debug_data, - name_piop, - name_expr, - opid, - expressions.get(row), - airgroup_id, - air_id, - air_instance_id, - row, - proves, - sel, - is_global, - ); + if fast_mode { + update_debug_data_fast(debug_data_fast, opid, expressions.get(row), proves, sel, is_global); + } else { + update_debug_data( + debug_data, + name_piop, + name_expr, + opid, + expressions.get(row), + airgroup_id, + air_id, + air_instance_id, + row, + proves, + sel, + is_global, + ); + } } } } @@ -283,23 +297,7 @@ impl WitnessComponent for StdProd { log::debug!("{}: ··· Computing witness for AIR '{}' at stage {}", Self::MY_NAME, air_name, stage); - let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; - let gprod_hints = get_hint_ids_by_name(p_expressions_bin, "gprod_col"); - let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gprod_debug_data"); - - // Debugging, if enabled - if pctx.options.debug_info.std_mode.name == ModeName::Debug { - let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); - self.debug_mode( - &pctx, - &sctx, - air_instance, - air_instance_id, - num_rows, - debug_data_hints.clone(), - ); - } // We know that at most one product hint exists let gprod_hint = if gprod_hints.len() > 1 { @@ -344,15 +342,140 @@ impl WitnessComponent for StdProd { } } - fn end_proof(&self) { - // Print debug info if in debug mode - if self.pctx.options.debug_info.std_mode.name == ModeName::Debug { - let pctx = &self.pctx; - let name = Self::MY_NAME; - let max_values_to_print = pctx.options.debug_info.std_mode.n_vals; - let print_to_file = pctx.options.debug_info.std_mode.print_to_file; - let debug_data = self.debug_data.as_ref().expect("Debug data missing"); - print_debug_info(pctx, name, max_values_to_print, print_to_file, debug_data); + fn debug(&self, pctx: Arc>, sctx: Arc) { + timer_start_info!(DEBUG_MODE_PROD); + + let std_prod_users_vec = get_hint_ids_by_name(sctx.get_global_bin(), "std_prod_users"); + + if !std_prod_users_vec.is_empty() { + let std_prod_users = std_prod_users_vec[0]; + + let num_users = get_global_hint_field_constant_as::(sctx.clone(), std_prod_users, "num_users"); + let airgroup_ids = get_hint_field_gc_constant_a::(sctx.clone(), std_prod_users, "airgroup_ids", false); + let air_ids = get_hint_field_gc_constant_a::(sctx.clone(), std_prod_users, "air_ids", false); + + let fast_mode = pctx.options.debug_info.std_mode.fast_mode; + + let mut debug_data = HashMap::new(); + + let mut debugs_data_fasts: Vec> = Vec::new(); + + let mut global_instance_ids = Vec::new(); + + for i in 0..num_users { + let airgroup_id = extract_field_element_as_usize(&airgroup_ids.values[i], "airgroup_id"); + let air_id = extract_field_element_as_usize(&air_ids.values[i], "air_id"); + + // Get all air instances ids for this airgroup and air_id + let global_ids = pctx.air_instance_repo.find_air_instances(airgroup_id, air_id); + + for global_instance_id in global_ids { + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&global_instance_id).unwrap(); + + if air_instance.prover_initialized { + global_instance_ids.push(global_instance_id); + } + } + } + + if fast_mode { + // Process each sum check user + debugs_data_fasts = global_instance_ids + .par_iter() + .map(|&global_instance_id| { + let mut local_debug_data_fast = HashMap::new(); + + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + + let air_instance = air_instances.get(&global_instance_id).unwrap(); + let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); + let air_name = &pctx.global_info.airs[air_instance.airgroup_id][air_instance.air_id].name; + + log::debug!( + "{}: ··· Checking debug mode fast for instance_id {} of {}", + Self::MY_NAME, + air_instance_id, + air_name + ); + + // Get the air associated with the air_instance + let airgroup_id = air_instance.airgroup_id; + let air_id = air_instance.air_id; + + let setup = sctx.get_setup(airgroup_id, air_id); + let p_expressions_bin = setup.p_setup.p_expressions_bin; + + let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; + + let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gprod_debug_data"); + + self.debug_mode( + &pctx, + &sctx, + air_instance, + air_instance_id, + num_rows, + debug_data_hints.clone(), + &mut HashMap::new(), + &mut local_debug_data_fast, + true, + ); + + local_debug_data_fast + }) + .collect(); + } else { + // Process each sum check user + for global_instance_id in global_instance_ids { + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&global_instance_id).unwrap(); + let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); + let air_name = &pctx.global_info.airs[air_instance.airgroup_id][air_instance.air_id].name; + + log::debug!( + "{}: ··· Checking debug mode for instance_id {} of {}", + Self::MY_NAME, + air_instance_id, + air_name + ); + + // Get the air associated with the air_instance + let airgroup_id = air_instance.airgroup_id; + let air_id = air_instance.air_id; + + let setup = sctx.get_setup(airgroup_id, air_id); + let p_expressions_bin = setup.p_setup.p_expressions_bin; + + let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; + + let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gprod_debug_data"); + + self.debug_mode( + &pctx, + &sctx, + air_instance, + air_instance_id, + num_rows, + debug_data_hints.clone(), + &mut debug_data, + &mut HashMap::new(), + false, + ); + } + } + + if fast_mode { + check_invalid_opids(&pctx, Self::MY_NAME, &mut debugs_data_fasts); + } else { + let max_values_to_print = pctx.options.debug_info.std_mode.n_vals; + let print_to_file = pctx.options.debug_info.std_mode.print_to_file; + print_debug_info(&pctx, Self::MY_NAME, max_values_to_print, print_to_file, &mut debug_data); + } } + timer_stop_and_log_info!(DEBUG_MODE_PROD); } } diff --git a/pil2-components/lib/std/rs/src/std_sum.rs b/pil2-components/lib/std/rs/src/std_sum.rs index ef7c5344..51b01273 100644 --- a/pil2-components/lib/std/rs/src/std_sum.rs +++ b/pil2-components/lib/std/rs/src/std_sum.rs @@ -3,33 +3,36 @@ use std::{ sync::{Arc, Mutex}, }; +use rayon::prelude::*; + use num_traits::ToPrimitive; use p3_field::PrimeField; +use proofman_util::{timer_start_info, timer_stop_and_log_info}; use witness::WitnessComponent; -use proofman_common::{AirInstance, ProofCtx, SetupCtx, ModeName}; +use proofman_common::{AirInstance, ProofCtx, SetupCtx}; use proofman_hints::{ get_hint_field_gc_constant_a, get_hint_field, get_hint_field_a, acc_mul_hint_fields, update_airgroupvalue, get_hint_ids_by_name, mul_hint_fields, HintFieldOptions, HintFieldOutput, HintFieldValue, HintFieldValuesVec, }; use crate::{ - extract_field_element_as_usize, get_global_hint_field_constant_as, get_hint_field_constant_a_as_string, - get_hint_field_constant_as_field, get_hint_field_constant_as_string, get_row_field_value, print_debug_info, - update_debug_data, AirComponent, DebugData, + check_invalid_opids, extract_field_element_as_usize, get_global_hint_field_constant_as, + get_hint_field_constant_a_as_string, get_hint_field_constant_as_field, get_hint_field_constant_as_string, + get_row_field_value, print_debug_info, update_debug_data, update_debug_data_fast, AirComponent, DebugData, + DebugDataFast, SharedDataFast, }; pub struct StdSum { - pctx: Arc>, stage_wc: Option>, - debug_data: Option>, + _phantom: std::marker::PhantomData, } impl AirComponent for StdSum { const MY_NAME: &'static str = "STD Sum "; fn new( - pctx: Arc>, + _pctx: Arc>, sctx: Arc, _airgroup_id: Option, _air_id: Option, @@ -39,7 +42,6 @@ impl AirComponent for StdSum { // Initialize std_sum with the extracted data Arc::new(Self { - pctx: pctx.clone(), stage_wc: match std_sum_users_id.is_empty() { true => None, false => { @@ -49,24 +51,25 @@ impl AirComponent for StdSum { Some(Mutex::new(stage_wc)) } }, - debug_data: if pctx.options.debug_info.std_mode.name == ModeName::Debug { - Some(Mutex::new(HashMap::new())) - } else { - None - }, + _phantom: std::marker::PhantomData, }) } +} +impl StdSum { + #[allow(clippy::too_many_arguments)] fn debug_mode( &self, pctx: &ProofCtx, sctx: &SetupCtx, - air_instance: &mut AirInstance, + air_instance: &AirInstance, air_instance_id: usize, num_rows: usize, debug_data_hints: Vec, + debug_data: &mut DebugData, + debug_data_fast: &mut DebugDataFast, + fast_mode: bool, ) { - let debug_data = self.debug_data.as_ref().expect("Debug data missing"); let airgroup_id = air_instance.airgroup_id; let air_id = air_instance.air_id; @@ -165,7 +168,9 @@ impl AirComponent for StdSum { &expressions, 0, debug_data, + debug_data_fast, is_global.is_one(), + fast_mode, ); } // Otherwise, update the bus for each row @@ -200,7 +205,9 @@ impl AirComponent for StdSum { &expressions, j, debug_data, + debug_data_fast, false, + fast_mode, ); } } @@ -218,8 +225,10 @@ impl AirComponent for StdSum { mul: &HintFieldValue, expressions: &HintFieldValuesVec, row: usize, - debug_data: &DebugData, + debug_data: &mut DebugData, + debug_data_fast: &mut DebugDataFast, is_global: bool, + fast_mode: bool, ) { let mut mul = get_row_field_value(mul, row, "mul"); if mul.is_zero() { @@ -238,20 +247,24 @@ impl AirComponent for StdSum { _ => panic!("Proves hint must be either 0, 1, or -1"), }; - update_debug_data( - debug_data, - name_piop, - name_expr, - opid, - expressions.get(row), - airgroup_id, - air_id, - instance_id, - row, - proves, - mul, - is_global, - ); + if fast_mode { + update_debug_data_fast(debug_data_fast, opid, expressions.get(row), proves, mul, is_global); + } else { + update_debug_data( + debug_data, + name_piop, + name_expr, + opid, + expressions.get(row), + airgroup_id, + air_id, + instance_id, + row, + proves, + mul, + is_global, + ); + } } } } @@ -297,24 +310,8 @@ impl WitnessComponent for StdSum { log::debug!("{}: ··· Computing witness for AIR '{}' at stage {}", Self::MY_NAME, air_name, stage); - let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; - let im_hints = get_hint_ids_by_name(p_expressions_bin, "im_col"); let gsum_hints = get_hint_ids_by_name(p_expressions_bin, "gsum_col"); - let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gsum_debug_data"); - - // Debugging, if enabled - if pctx.options.debug_info.std_mode.name == ModeName::Debug { - let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); - self.debug_mode( - &pctx, - &sctx, - air_instance, - air_instance_id, - num_rows, - debug_data_hints.clone(), - ); - } // Populate the im columns for hint in im_hints { @@ -374,15 +371,139 @@ impl WitnessComponent for StdSum { } } - fn end_proof(&self) { - // Print debug info if in debug mode - if self.pctx.options.debug_info.std_mode.name == ModeName::Debug { - let pctx = &self.pctx; - let name = Self::MY_NAME; - let max_values_to_print = pctx.options.debug_info.std_mode.n_vals; - let print_to_file = pctx.options.debug_info.std_mode.print_to_file; - let debug_data = self.debug_data.as_ref().expect("Debug data missing"); - print_debug_info(pctx, name, max_values_to_print, print_to_file, debug_data); + fn debug(&self, pctx: Arc>, sctx: Arc) { + timer_start_info!(DEBUG_MODE_SUM); + let std_sum_users_vec = get_hint_ids_by_name(sctx.get_global_bin(), "std_sum_users"); + + if !std_sum_users_vec.is_empty() { + let std_sum_users = std_sum_users_vec[0]; + + let num_users = get_global_hint_field_constant_as::(sctx.clone(), std_sum_users, "num_users"); + let airgroup_ids = get_hint_field_gc_constant_a::(sctx.clone(), std_sum_users, "airgroup_ids", false); + let air_ids = get_hint_field_gc_constant_a::(sctx.clone(), std_sum_users, "air_ids", false); + + let fast_mode = pctx.options.debug_info.std_mode.fast_mode; + + let mut debug_data = HashMap::new(); + + let mut debugs_data_fasts: Vec> = Vec::new(); + + let mut global_instance_ids = Vec::new(); + + for i in 0..num_users { + let airgroup_id = extract_field_element_as_usize(&airgroup_ids.values[i], "airgroup_id"); + let air_id = extract_field_element_as_usize(&air_ids.values[i], "air_id"); + + // Get all air instances ids for this airgroup and air_id + let global_ids = pctx.air_instance_repo.find_air_instances(airgroup_id, air_id); + + for global_instance_id in global_ids { + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&global_instance_id).unwrap(); + + if air_instance.prover_initialized { + global_instance_ids.push(global_instance_id); + } + } + } + + if fast_mode { + // Process each sum check user + debugs_data_fasts = global_instance_ids + .par_iter() + .map(|&global_instance_id| { + let mut local_debug_data_fast = HashMap::new(); + + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + + let air_instance = air_instances.get(&global_instance_id).unwrap(); + let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); + let air_name = &pctx.global_info.airs[air_instance.airgroup_id][air_instance.air_id].name; + + log::debug!( + "{}: ··· Checking debug mode fast for instance_id {} of {}", + Self::MY_NAME, + air_instance_id, + air_name + ); + + // Get the air associated with the air_instance + let airgroup_id = air_instance.airgroup_id; + let air_id = air_instance.air_id; + + let setup = sctx.get_setup(airgroup_id, air_id); + let p_expressions_bin = setup.p_setup.p_expressions_bin; + + let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; + + let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gsum_debug_data"); + + self.debug_mode( + &pctx, + &sctx, + air_instance, + air_instance_id, + num_rows, + debug_data_hints.clone(), + &mut HashMap::new(), + &mut local_debug_data_fast, + true, + ); + + local_debug_data_fast + }) + .collect(); + } else { + // Process each sum check user + for global_instance_id in global_instance_ids { + // Retrieve all air instances + let air_instances = &mut pctx.air_instance_repo.air_instances.read().unwrap(); + let air_instance = air_instances.get(&global_instance_id).unwrap(); + let air_instance_id = pctx.dctx_find_air_instance_id(global_instance_id); + let air_name = &pctx.global_info.airs[air_instance.airgroup_id][air_instance.air_id].name; + + log::debug!( + "{}: ··· Checking debug mode for instance_id {} of {}", + Self::MY_NAME, + air_instance_id, + air_name + ); + + // Get the air associated with the air_instance + let airgroup_id = air_instance.airgroup_id; + let air_id = air_instance.air_id; + + let setup = sctx.get_setup(airgroup_id, air_id); + let p_expressions_bin = setup.p_setup.p_expressions_bin; + + let num_rows = pctx.global_info.airs[airgroup_id][air_id].num_rows; + + let debug_data_hints = get_hint_ids_by_name(p_expressions_bin, "gsum_debug_data"); + + self.debug_mode( + &pctx, + &sctx, + air_instance, + air_instance_id, + num_rows, + debug_data_hints.clone(), + &mut debug_data, + &mut HashMap::new(), + false, + ); + } + } + + if fast_mode { + check_invalid_opids(&pctx, Self::MY_NAME, &mut debugs_data_fasts); + } else { + let max_values_to_print = pctx.options.debug_info.std_mode.n_vals; + let print_to_file = pctx.options.debug_info.std_mode.print_to_file; + print_debug_info(&pctx, Self::MY_NAME, max_values_to_print, print_to_file, &mut debug_data); + } } + timer_stop_and_log_info!(DEBUG_MODE_SUM); } } diff --git a/proofman/src/proofman.rs b/proofman/src/proofman.rs index 1520ea39..94ab02d8 100644 --- a/proofman/src/proofman.rs +++ b/proofman/src/proofman.rs @@ -162,10 +162,9 @@ impl ProofMan { } } - wcm.end_proof(); + wcm.debug(); if pctx.options.verify_constraints { - wcm.debug(); return verify_constraints_proof(pctx.clone(), sctx.clone(), &mut provers); } diff --git a/witness/src/witness_component.rs b/witness/src/witness_component.rs index d587fb03..269925c0 100644 --- a/witness/src/witness_component.rs +++ b/witness/src/witness_component.rs @@ -7,9 +7,7 @@ pub trait WitnessComponent: Send + Sync { fn execute(&self, _pctx: Arc>) {} - fn debug(&self, _pctx: Arc>) {} + fn debug(&self, _pctx: Arc>, _sctx: Arc) {} fn calculate_witness(&self, _stage: u32, _pctx: Arc>, _sctx: Arc) {} - - fn end_proof(&self) {} } diff --git a/witness/src/witness_manager.rs b/witness/src/witness_manager.rs index 4713baf9..b8cc653e 100644 --- a/witness/src/witness_manager.rs +++ b/witness/src/witness_manager.rs @@ -1,7 +1,7 @@ use std::sync::{Arc, RwLock}; use std::path::PathBuf; -use proofman_common::{ProofCtx, SetupCtx}; +use proofman_common::{ModeName, ProofCtx, SetupCtx}; use proofman_util::{timer_start_info, timer_stop_and_log_info}; use crate::WitnessComponent; @@ -46,14 +46,10 @@ impl WitnessManager { } pub fn debug(&self) { - for component in self.components.read().unwrap().iter() { - component.debug(self.pctx.clone()); - } - } - - pub fn end_proof(&self) { - for component in self.components.read().unwrap().iter() { - component.end_proof(); + if self.pctx.options.debug_info.std_mode.name == ModeName::Debug { + for component in self.components.read().unwrap().iter() { + component.debug(self.pctx.clone(), self.sctx.clone()); + } } }