Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Debug mode fast implemented #151

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 7 additions & 6 deletions common/src/std_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,36 @@ pub struct StdMode {
pub opids: Vec<u64>,
pub n_vals: usize,
pub print_to_file: bool,
pub fast_mode: bool,
}

impl StdMode {
pub const fn new(name: ModeName, opids: Vec<u64>, n_vals: usize, print_to_file: bool) -> Self {
pub const fn new(name: ModeName, opids: Vec<u64>, n_vals: usize, print_to_file: bool, fast_mode: bool) -> Self {
if name.as_usize() != ModeName::Standard.as_usize() && n_vals == 0 {
panic!("n_vals must be greater than 0");
}

Self { name, opids, n_vals, print_to_file }
Self { name, opids, n_vals, print_to_file, fast_mode }
}

pub fn new_debug() -> Self {
Self::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false)
Self::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false, true)
}
}

impl From<u8> for StdMode {
fn from(v: u8) -> Self {
match v {
0 => StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false),
1 => StdMode::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false),
0 => StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false, false),
1 => StdMode::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false, true),
_ => panic!("Invalid mode"),
}
}
}

impl Default for StdMode {
fn default() -> Self {
StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false)
StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false, false)
}
}

Expand Down
11 changes: 10 additions & 1 deletion common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ pub fn skip_prover_instance(
(true, Vec::new())
}

fn default_fast_mode() -> bool {
true
}
#[derive(Debug, Default, Deserialize)]
struct StdDebugMode {
#[serde(default)]
Expand All @@ -78,6 +81,8 @@ struct StdDebugMode {
n_print: Option<usize>,
#[serde(default)]
print_to_file: bool,
#[serde(default = "default_fast_mode")]
fast_mode: bool,
}

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -200,14 +205,18 @@ pub fn json_to_debug_instances_map(proving_key_path: PathBuf, json_path: String)
let global_constraints = json.global_constraints.unwrap_or_default();

let std_mode = if !airgroup_map.is_empty() {
StdMode::new(ModeName::Standard, Vec::new(), 0, false)
StdMode::new(ModeName::Standard, Vec::new(), 0, false, false)
} else {
let mode = json.std_mode.unwrap_or_default();
let fast_mode =
if mode.opids.is_some() && !mode.opids.as_ref().unwrap().is_empty() { false } else { mode.fast_mode };

StdMode::new(
ModeName::Debug,
mode.opids.unwrap_or_default(),
mode.n_print.unwrap_or(DEFAULT_PRINT_VALS),
mode.print_to_file,
fast_mode,
)
};

Expand Down
4 changes: 2 additions & 2 deletions examples/fibonacci-square/src/fibonacci.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::sync::Arc;

use proofman_common::{add_air_instance, AirInstance, FromTrace, ProofCtx};
use proofman_common::{add_air_instance, AirInstance, FromTrace, ProofCtx, SetupCtx};
use witness::WitnessComponent;

use p3_field::PrimeField64;
Expand Down Expand Up @@ -70,7 +70,7 @@ impl<F: PrimeField64 + Copy> WitnessComponent<F> for FibonacciSquare<F> {
add_air_instance::<F>(air_instance, pctx.clone());
}

fn debug(&self, _pctx: Arc<ProofCtx<F>>) {
fn debug(&self, _pctx: Arc<ProofCtx<F>>, _sctx: Arc<SetupCtx>) {
// let trace = FibonacciSquareTrace::from_vec(pctx.get_air_instance_trace(0, 0, 0));
// let air_values = FibonacciSquareAirValues::from_vec(pctx.get_air_instance_air_values(0, 0, 0));
// let airgroup_values = FibonacciSquareAirGroupValues::from_vec(pctx.get_air_instance_airgroup_values(0, 0, 0));
Expand Down
2 changes: 1 addition & 1 deletion hints/src/global_hints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub fn aggregate_airgroupvals<F: Field>(pctx: Arc<ProofCtx<F>>) -> Vec<Vec<u64>>
airgroupvalues.push(values);
}

for (_, air_instance) in pctx.air_instance_repo.air_instances.write().unwrap().iter() {
for (_, air_instance) in pctx.air_instance_repo.air_instances.read().unwrap().iter() {
for (idx, agg_type) in pctx.global_info.agg_types[air_instance.airgroup_id].iter().enumerate() {
let mut acc = ExtensionField {
value: [
Expand Down
8 changes: 4 additions & 4 deletions hints/src/hints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ fn get_hint_f<F: Field>(
pctx: Option<&ProofCtx<F>>,
airgroup_id: usize,
air_id: usize,
air_instance: Option<&mut AirInstance<F>>,
air_instance: Option<&AirInstance<F>>,
hint_id: usize,
hint_field_name: &str,
options: HintFieldOptions,
Expand Down Expand Up @@ -1001,7 +1001,7 @@ fn get_hint_f<F: Field>(
pub fn get_hint_field<F: Field>(
sctx: &SetupCtx,
pctx: &ProofCtx<F>,
air_instance: &mut AirInstance<F>,
air_instance: &AirInstance<F>,
hint_id: usize,
hint_field_name: &str,
options: HintFieldOptions,
Expand Down Expand Up @@ -1031,7 +1031,7 @@ pub fn get_hint_field<F: Field>(
pub fn get_hint_field_a<F: Field>(
sctx: &SetupCtx,
pctx: &ProofCtx<F>,
air_instance: &mut AirInstance<F>,
air_instance: &AirInstance<F>,
hint_id: usize,
hint_field_name: &str,
options: HintFieldOptions,
Expand Down Expand Up @@ -1065,7 +1065,7 @@ pub fn get_hint_field_a<F: Field>(
pub fn get_hint_field_m<F: Field>(
sctx: &SetupCtx,
pctx: &ProofCtx<F>,
air_instance: &mut AirInstance<F>,
air_instance: &AirInstance<F>,
hint_id: usize,
hint_field_name: &str,
options: HintFieldOptions,
Expand Down
1 change: 1 addition & 0 deletions pil2-components/lib/std/rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ p3-goldilocks.workspace = true
p3-field.workspace = true
rayon.workspace = true
witness.workspace = true
colored.workspace = true
13 changes: 1 addition & 12 deletions pil2-components/lib/std/rs/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use p3_field::PrimeField;
use num_traits::ToPrimitive;
use proofman_common::{AirInstance, ProofCtx, SetupCtx};
use proofman_common::{ProofCtx, SetupCtx};
use proofman_hints::{
get_hint_field_constant_gc, get_hint_field_constant, get_hint_field_constant_a, HintFieldOptions, HintFieldOutput,
HintFieldValue,
Expand All @@ -13,17 +13,6 @@ pub trait AirComponent<F> {

fn new(pctx: Arc<ProofCtx<F>>, sctx: Arc<SetupCtx>, airgroup_id: Option<usize>, air_id: Option<usize>)
-> Arc<Self>;

fn debug_mode(
&self,
_pctx: &ProofCtx<F>,
_sctx: &SetupCtx,
_air_instance: &mut AirInstance<F>,
_air_instance_id: usize,
_num_rows: usize,
_debug_data_hints: Vec<u64>,
) {
}
}

// Helper to extract hint fields
Expand Down
148 changes: 139 additions & 9 deletions pil2-components/lib/std/rs/src/debug.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
use std::{
collections::HashMap,
fs::{self, File},
hash::{DefaultHasher, Hasher, Hash},
io::{self, Write},
path::{Path, PathBuf},
sync::Mutex,
};

use p3_field::PrimeField;
use proofman_common::ProofCtx;
use proofman_hints::{format_vec, HintFieldOutput};

pub type DebugData<F> = Mutex<HashMap<F, HashMap<Vec<HintFieldOutput<F>>, BusValue<F>>>>; // opid -> val -> BusValue
use num_bigint::BigUint;
use num_traits::Zero;

use colored::*;

pub type DebugData<F> = HashMap<F, HashMap<Vec<HintFieldOutput<F>>, BusValue<F>>>; // opid -> val -> BusValue

pub type DebugDataFast<F> = HashMap<F, SharedDataFast>; // opid -> sharedDataFast

#[derive(Debug)]
pub struct BusValue<F> {
Expand All @@ -25,6 +32,14 @@ struct SharedData<F> {
num_assumes: F,
}

#[derive(Clone, Debug)]
pub struct SharedDataFast {
pub num_proves: BigUint,
pub num_assumes: BigUint,
pub num_proves_global: Vec<BigUint>,
pub num_assumes_global: Vec<BigUint>,
}

type AirGroupMap = HashMap<usize, AirMap>;
type AirMap = HashMap<usize, AirData>;

Expand All @@ -43,9 +58,63 @@ struct InstanceData {
row_assumes: Vec<usize>,
}

#[allow(clippy::too_many_arguments)]
pub fn update_debug_data_fast<F: PrimeField>(
debug_data_fast: &mut DebugDataFast<F>,
opid: F,
val: Vec<HintFieldOutput<F>>,
proves: bool,
times: F,
is_global: bool,
) {
let bus_opid_times = debug_data_fast.entry(opid).or_insert_with(|| SharedDataFast {
num_assumes_global: Vec::new(),
num_proves_global: Vec::new(),
num_proves: BigUint::zero(),
num_assumes: BigUint::zero(),
});

let mut values = Vec::new();
for value in val.iter() {
match value {
HintFieldOutput::Field(f) => values.push(*f),
HintFieldOutput::FieldExtended(ef) => {
values.push(ef.value[0]);
values.push(ef.value[1]);
values.push(ef.value[2]);
}
}
}

let mut hasher = DefaultHasher::new();
values.hash(&mut hasher);

let hash_value = BigUint::from(hasher.finish());

if is_global {
if proves {
// Check if bus op id times num proves global contains value
if bus_opid_times.num_proves_global.contains(&hash_value) {
return;
}
bus_opid_times.num_proves_global.push(hash_value * times.as_canonical_biguint());
} else {
if bus_opid_times.num_assumes_global.contains(&hash_value) {
return;
}
bus_opid_times.num_assumes_global.push(hash_value);
}
} else if proves {
bus_opid_times.num_proves += hash_value * times.as_canonical_biguint();
} else {
assert!(times.is_one(), "The selector value is invalid: expected 1, but received {:?}.", times);
bus_opid_times.num_assumes += hash_value;
}
}

#[allow(clippy::too_many_arguments)]
pub fn update_debug_data<F: PrimeField>(
debug_data: &DebugData<F>,
debug_data: &mut DebugData<F>,
name_piop: &str,
name_expr: &[String],
opid: F,
Expand All @@ -58,9 +127,7 @@ pub fn update_debug_data<F: PrimeField>(
times: F,
is_global: bool,
) {
let mut bus = debug_data.lock().expect("Bus values missing");

let bus_opid = bus.entry(opid).or_default();
let bus_opid = debug_data.entry(opid).or_default();

let bus_val = bus_opid.entry(val).or_insert_with(|| BusValue {
shared_data: SharedData { direct_was_called: false, num_proves: F::zero(), num_assumes: F::zero() },
Expand Down Expand Up @@ -99,18 +166,77 @@ pub fn update_debug_data<F: PrimeField>(
}
}

pub fn check_invalid_opids<F: PrimeField>(
_pctx: &ProofCtx<F>,
name: &str,
debugs_data_fasts: &mut [DebugDataFast<F>],
) -> Vec<F> {
let mut debug_data_fast = HashMap::new();

let mut global_assumes = Vec::new();
let mut global_proves = Vec::new();
for map in debugs_data_fasts {
for (opid, bus) in map.iter() {
if debug_data_fast.contains_key(opid) {
let bus_fast: &mut SharedDataFast = debug_data_fast.get_mut(opid).unwrap();
for assume_global in bus.num_assumes_global.iter() {
if global_assumes.contains(assume_global) {
continue;
}
global_assumes.push(assume_global.clone());
bus_fast.num_assumes += assume_global;
}
for prove_global in bus.num_proves_global.iter() {
if global_proves.contains(prove_global) {
continue;
}
global_proves.push(prove_global.clone());
bus_fast.num_proves += prove_global;
}

bus_fast.num_proves += bus.num_proves.clone();
bus_fast.num_assumes += bus.num_assumes.clone();
} else {
debug_data_fast.insert(*opid, bus.clone());
}
}
}

// TODO: SINCRONIZATION IN DISTRIBUTED MODE

let mut invalid_opids = Vec::new();

// Check if there are any invalid opids

for (opid, bus) in debug_data_fast.iter_mut() {
if bus.num_proves != bus.num_assumes {
invalid_opids.push(*opid);
}
}

if !invalid_opids.is_empty() {
log::error!(
"{}: ··· {}",
name,
format!("\u{2717} The following opids does not match {:?}", invalid_opids).bright_red().bold()
);
} else {
log::info!("{}: ··· {}", name, "\u{2713} All bus values match.".bright_green().bold());
}

invalid_opids
}
pub fn print_debug_info<F: PrimeField>(
pctx: &ProofCtx<F>,
name: &str,
max_values_to_print: usize,
print_to_file: bool,
debug_data: &DebugData<F>,
debug_data: &mut DebugData<F>,
) {
let mut file_path = PathBuf::new();
let mut output: Box<dyn Write> = Box::new(io::stdout());
let mut there_are_errors = false;
let mut bus_vals = debug_data.lock().expect("Bus values missing");
for (opid, bus) in bus_vals.iter_mut() {
for (opid, bus) in debug_data.iter_mut() {
if bus.iter().any(|(_, v)| v.shared_data.num_proves != v.shared_data.num_assumes) {
if !there_are_errors {
// Print to a file if requested
Expand Down Expand Up @@ -204,6 +330,10 @@ pub fn print_debug_info<F: PrimeField>(
}
}

if !there_are_errors {
log::info!("{}: ··· {}", name, "\u{2713} All bus values match.".bright_green().bold());
}

fn print_diffs<F: PrimeField>(
pctx: &ProofCtx<F>,
val: &[HintFieldOutput<F>],
Expand Down
Loading
Loading