Skip to content

Commit

Permalink
Debug mode fast implemented (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerTaule authored Jan 28, 2025
1 parent db1fd26 commit ea5b989
Show file tree
Hide file tree
Showing 16 changed files with 532 additions and 164 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 7 additions & 6 deletions common/src/std_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,36 @@ pub struct StdMode {
pub opids: Vec<u64>,
pub n_vals: usize,
pub print_to_file: bool,
pub fast_mode: bool,
}

impl StdMode {
pub const fn new(name: ModeName, opids: Vec<u64>, n_vals: usize, print_to_file: bool) -> Self {
pub const fn new(name: ModeName, opids: Vec<u64>, n_vals: usize, print_to_file: bool, fast_mode: bool) -> Self {
if name.as_usize() != ModeName::Standard.as_usize() && n_vals == 0 {
panic!("n_vals must be greater than 0");
}

Self { name, opids, n_vals, print_to_file }
Self { name, opids, n_vals, print_to_file, fast_mode }
}

pub fn new_debug() -> Self {
Self::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false)
Self::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false, true)
}
}

impl From<u8> for StdMode {
fn from(v: u8) -> Self {
match v {
0 => StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false),
1 => StdMode::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false),
0 => StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false, false),
1 => StdMode::new(ModeName::Debug, Vec::new(), DEFAULT_PRINT_VALS, false, true),
_ => panic!("Invalid mode"),
}
}
}

impl Default for StdMode {
fn default() -> Self {
StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false)
StdMode::new(ModeName::Standard, Vec::new(), DEFAULT_PRINT_VALS, false, false)
}
}

Expand Down
11 changes: 10 additions & 1 deletion common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ pub fn skip_prover_instance(
(true, Vec::new())
}

fn default_fast_mode() -> bool {
true
}
#[derive(Debug, Default, Deserialize)]
struct StdDebugMode {
#[serde(default)]
Expand All @@ -78,6 +81,8 @@ struct StdDebugMode {
n_print: Option<usize>,
#[serde(default)]
print_to_file: bool,
#[serde(default = "default_fast_mode")]
fast_mode: bool,
}

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -200,14 +205,18 @@ pub fn json_to_debug_instances_map(proving_key_path: PathBuf, json_path: String)
let global_constraints = json.global_constraints.unwrap_or_default();

let std_mode = if !airgroup_map.is_empty() {
StdMode::new(ModeName::Standard, Vec::new(), 0, false)
StdMode::new(ModeName::Standard, Vec::new(), 0, false, false)
} else {
let mode = json.std_mode.unwrap_or_default();
let fast_mode =
if mode.opids.is_some() && !mode.opids.as_ref().unwrap().is_empty() { false } else { mode.fast_mode };

StdMode::new(
ModeName::Debug,
mode.opids.unwrap_or_default(),
mode.n_print.unwrap_or(DEFAULT_PRINT_VALS),
mode.print_to_file,
fast_mode,
)
};

Expand Down
4 changes: 2 additions & 2 deletions examples/fibonacci-square/src/fibonacci.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::sync::Arc;

use proofman_common::{add_air_instance, AirInstance, FromTrace, ProofCtx};
use proofman_common::{add_air_instance, AirInstance, FromTrace, ProofCtx, SetupCtx};
use witness::WitnessComponent;

use p3_field::PrimeField64;
Expand Down Expand Up @@ -70,7 +70,7 @@ impl<F: PrimeField64 + Copy> WitnessComponent<F> for FibonacciSquare<F> {
add_air_instance::<F>(air_instance, pctx.clone());
}

fn debug(&self, _pctx: Arc<ProofCtx<F>>) {
fn debug(&self, _pctx: Arc<ProofCtx<F>>, _sctx: Arc<SetupCtx>) {
// let trace = FibonacciSquareTrace::from_vec(pctx.get_air_instance_trace(0, 0, 0));
// let air_values = FibonacciSquareAirValues::from_vec(pctx.get_air_instance_air_values(0, 0, 0));
// let airgroup_values = FibonacciSquareAirGroupValues::from_vec(pctx.get_air_instance_airgroup_values(0, 0, 0));
Expand Down
8 changes: 4 additions & 4 deletions hints/src/hints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ fn get_hint_f<F: Field>(
pctx: Option<&ProofCtx<F>>,
airgroup_id: usize,
air_id: usize,
air_instance: Option<&mut AirInstance<F>>,
air_instance: Option<&AirInstance<F>>,
hint_id: usize,
hint_field_name: &str,
options: HintFieldOptions,
Expand Down Expand Up @@ -1001,7 +1001,7 @@ fn get_hint_f<F: Field>(
pub fn get_hint_field<F: Field>(
sctx: &SetupCtx,
pctx: &ProofCtx<F>,
air_instance: &mut AirInstance<F>,
air_instance: &AirInstance<F>,
hint_id: usize,
hint_field_name: &str,
options: HintFieldOptions,
Expand Down Expand Up @@ -1031,7 +1031,7 @@ pub fn get_hint_field<F: Field>(
pub fn get_hint_field_a<F: Field>(
sctx: &SetupCtx,
pctx: &ProofCtx<F>,
air_instance: &mut AirInstance<F>,
air_instance: &AirInstance<F>,
hint_id: usize,
hint_field_name: &str,
options: HintFieldOptions,
Expand Down Expand Up @@ -1065,7 +1065,7 @@ pub fn get_hint_field_a<F: Field>(
pub fn get_hint_field_m<F: Field>(
sctx: &SetupCtx,
pctx: &ProofCtx<F>,
air_instance: &mut AirInstance<F>,
air_instance: &AirInstance<F>,
hint_id: usize,
hint_field_name: &str,
options: HintFieldOptions,
Expand Down
1 change: 1 addition & 0 deletions pil2-components/lib/std/rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ p3-goldilocks.workspace = true
p3-field.workspace = true
rayon.workspace = true
witness.workspace = true
colored.workspace = true
13 changes: 1 addition & 12 deletions pil2-components/lib/std/rs/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use p3_field::PrimeField;
use num_traits::ToPrimitive;
use proofman_common::{AirInstance, ProofCtx, SetupCtx};
use proofman_common::{ProofCtx, SetupCtx};
use proofman_hints::{
get_hint_field_constant_gc, get_hint_field_constant, get_hint_field_constant_a, HintFieldOptions, HintFieldOutput,
HintFieldValue,
Expand All @@ -13,17 +13,6 @@ pub trait AirComponent<F> {

fn new(pctx: Arc<ProofCtx<F>>, sctx: Arc<SetupCtx>, airgroup_id: Option<usize>, air_id: Option<usize>)
-> Arc<Self>;

fn debug_mode(
&self,
_pctx: &ProofCtx<F>,
_sctx: &SetupCtx,
_air_instance: &mut AirInstance<F>,
_air_instance_id: usize,
_num_rows: usize,
_debug_data_hints: Vec<u64>,
) {
}
}

// Helper to extract hint fields
Expand Down
148 changes: 139 additions & 9 deletions pil2-components/lib/std/rs/src/debug.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
use std::{
collections::HashMap,
fs::{self, File},
hash::{DefaultHasher, Hasher, Hash},
io::{self, Write},
path::{Path, PathBuf},
sync::Mutex,
};

use p3_field::PrimeField;
use proofman_common::ProofCtx;
use proofman_hints::{format_vec, HintFieldOutput};

pub type DebugData<F> = Mutex<HashMap<F, HashMap<Vec<HintFieldOutput<F>>, BusValue<F>>>>; // opid -> val -> BusValue
use num_bigint::BigUint;
use num_traits::Zero;

use colored::*;

pub type DebugData<F> = HashMap<F, HashMap<Vec<HintFieldOutput<F>>, BusValue<F>>>; // opid -> val -> BusValue

pub type DebugDataFast<F> = HashMap<F, SharedDataFast>; // opid -> sharedDataFast

#[derive(Debug)]
pub struct BusValue<F> {
Expand All @@ -25,6 +32,14 @@ struct SharedData<F> {
num_assumes: F,
}

#[derive(Clone, Debug)]
pub struct SharedDataFast {
pub num_proves: BigUint,
pub num_assumes: BigUint,
pub num_proves_global: Vec<BigUint>,
pub num_assumes_global: Vec<BigUint>,
}

type AirGroupMap = HashMap<usize, AirMap>;
type AirMap = HashMap<usize, AirData>;

Expand All @@ -43,9 +58,63 @@ struct InstanceData {
row_assumes: Vec<usize>,
}

#[allow(clippy::too_many_arguments)]
pub fn update_debug_data_fast<F: PrimeField>(
debug_data_fast: &mut DebugDataFast<F>,
opid: F,
val: Vec<HintFieldOutput<F>>,
proves: bool,
times: F,
is_global: bool,
) {
let bus_opid_times = debug_data_fast.entry(opid).or_insert_with(|| SharedDataFast {
num_assumes_global: Vec::new(),
num_proves_global: Vec::new(),
num_proves: BigUint::zero(),
num_assumes: BigUint::zero(),
});

let mut values = Vec::new();
for value in val.iter() {
match value {
HintFieldOutput::Field(f) => values.push(*f),
HintFieldOutput::FieldExtended(ef) => {
values.push(ef.value[0]);
values.push(ef.value[1]);
values.push(ef.value[2]);
}
}
}

let mut hasher = DefaultHasher::new();
values.hash(&mut hasher);

let hash_value = BigUint::from(hasher.finish());

if is_global {
if proves {
// Check if bus op id times num proves global contains value
if bus_opid_times.num_proves_global.contains(&hash_value) {
return;
}
bus_opid_times.num_proves_global.push(hash_value * times.as_canonical_biguint());
} else {
if bus_opid_times.num_assumes_global.contains(&hash_value) {
return;
}
bus_opid_times.num_assumes_global.push(hash_value);
}
} else if proves {
bus_opid_times.num_proves += hash_value * times.as_canonical_biguint();
} else {
assert!(times.is_one(), "The selector value is invalid: expected 1, but received {:?}.", times);
bus_opid_times.num_assumes += hash_value;
}
}

#[allow(clippy::too_many_arguments)]
pub fn update_debug_data<F: PrimeField>(
debug_data: &DebugData<F>,
debug_data: &mut DebugData<F>,
name_piop: &str,
name_expr: &[String],
opid: F,
Expand All @@ -58,9 +127,7 @@ pub fn update_debug_data<F: PrimeField>(
times: F,
is_global: bool,
) {
let mut bus = debug_data.lock().expect("Bus values missing");

let bus_opid = bus.entry(opid).or_default();
let bus_opid = debug_data.entry(opid).or_default();

let bus_val = bus_opid.entry(val).or_insert_with(|| BusValue {
shared_data: SharedData { direct_was_called: false, num_proves: F::zero(), num_assumes: F::zero() },
Expand Down Expand Up @@ -99,18 +166,77 @@ pub fn update_debug_data<F: PrimeField>(
}
}

pub fn check_invalid_opids<F: PrimeField>(
_pctx: &ProofCtx<F>,
name: &str,
debugs_data_fasts: &mut [DebugDataFast<F>],
) -> Vec<F> {
let mut debug_data_fast = HashMap::new();

let mut global_assumes = Vec::new();
let mut global_proves = Vec::new();
for map in debugs_data_fasts {
for (opid, bus) in map.iter() {
if debug_data_fast.contains_key(opid) {
let bus_fast: &mut SharedDataFast = debug_data_fast.get_mut(opid).unwrap();
for assume_global in bus.num_assumes_global.iter() {
if global_assumes.contains(assume_global) {
continue;
}
global_assumes.push(assume_global.clone());
bus_fast.num_assumes += assume_global;
}
for prove_global in bus.num_proves_global.iter() {
if global_proves.contains(prove_global) {
continue;
}
global_proves.push(prove_global.clone());
bus_fast.num_proves += prove_global;
}

bus_fast.num_proves += bus.num_proves.clone();
bus_fast.num_assumes += bus.num_assumes.clone();
} else {
debug_data_fast.insert(*opid, bus.clone());
}
}
}

// TODO: SINCRONIZATION IN DISTRIBUTED MODE

let mut invalid_opids = Vec::new();

// Check if there are any invalid opids

for (opid, bus) in debug_data_fast.iter_mut() {
if bus.num_proves != bus.num_assumes {
invalid_opids.push(*opid);
}
}

if !invalid_opids.is_empty() {
log::error!(
"{}: ··· {}",
name,
format!("\u{2717} The following opids does not match {:?}", invalid_opids).bright_red().bold()
);
} else {
log::info!("{}: ··· {}", name, "\u{2713} All bus values match.".bright_green().bold());
}

invalid_opids
}
pub fn print_debug_info<F: PrimeField>(
pctx: &ProofCtx<F>,
name: &str,
max_values_to_print: usize,
print_to_file: bool,
debug_data: &DebugData<F>,
debug_data: &mut DebugData<F>,
) {
let mut file_path = PathBuf::new();
let mut output: Box<dyn Write> = Box::new(io::stdout());
let mut there_are_errors = false;
let mut bus_vals = debug_data.lock().expect("Bus values missing");
for (opid, bus) in bus_vals.iter_mut() {
for (opid, bus) in debug_data.iter_mut() {
if bus.iter().any(|(_, v)| v.shared_data.num_proves != v.shared_data.num_assumes) {
if !there_are_errors {
// Print to a file if requested
Expand Down Expand Up @@ -204,6 +330,10 @@ pub fn print_debug_info<F: PrimeField>(
}
}

if !there_are_errors {
log::info!("{}: ··· {}", name, "\u{2713} All bus values match.".bright_green().bold());
}

fn print_diffs<F: PrimeField>(
pctx: &ProofCtx<F>,
val: &[HintFieldOutput<F>],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,13 @@ impl<F: PrimeField> WitnessComponent<F> for SpecifiedRanges<F> {
let buffer = create_buffer_fast(buffer_size as usize);

// Add a new air instance. Since Specified Ranges is a table, only this air instance is needed
let mut air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer));
let air_instance = AirInstance::new(TraceInfo::new(self.airgroup_id, self.air_id, buffer));
let mut mul_columns_guard = self.mul_columns.lock().unwrap();
for hint in hints_guard[1..].iter() {
mul_columns_guard.push(get_hint_field::<F>(
&sctx,
&pctx,
&mut air_instance,
&air_instance,
hint.to_usize().unwrap(),
"reference",
HintFieldOptions::dest_with_zeros(),
Expand Down
Loading

0 comments on commit ea5b989

Please sign in to comment.