Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

atomic mults in verify #328

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion stwo_cairo_prover/crates/prover/src/cairo_air/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ impl CairoClaimGenerator {
let initial_state = input.state_transitions.initial_state;
let final_state = input.state_transitions.final_state;
let opcodes = OpcodesClaimGenerator::new(input.state_transitions);
let verify_instruction_trace_generator = verify_instruction::ClaimGenerator::default();
let verify_instruction_trace_generator =
verify_instruction::ClaimGenerator::new(input.instruction_by_pc);
let memory_address_to_id_trace_generator =
memory_address_to_id::ClaimGenerator::new(&input.memory);
let memory_id_to_value_trace_generator =
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#![allow(unused_parens)]
#![allow(unused_imports)]
use std::collections::{BTreeMap, BTreeSet};
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::iter::zip;
use std::sync::atomic::{AtomicU32, Ordering};
use std::vec;

use air_structs_derive::SubComponentInputs;
Expand Down Expand Up @@ -35,6 +36,7 @@ use super::component::{Claim, InteractionClaim};
use crate::components::{
memory_address_to_id, memory_id_to_big, pack_values, range_check_4_3, range_check_7_2_5,
};
use crate::input::decode::{deconstruct_instruction, Instruction};
use crate::relations;

pub type InputType = (M31, [M31; 3], [M31; 15]);
Expand All @@ -44,10 +46,24 @@ const N_TRACE_COLUMNS: usize = 28 + N_MULTIPLICITY_COLUMNS;

#[derive(Default)]
pub struct ClaimGenerator {
/// A map from input to multiplicity.
inputs: BTreeMap<InputType, u32>,
/// pc -> encoded instruction.
instructions: HashMap<M31, u64>,

/// pc -> multiplicity.
multiplicities: HashMap<M31, AtomicU32>,
}
impl ClaimGenerator {
pub fn new(instructions: HashMap<M31, u64>) -> Self {
let keys = instructions.keys().copied();
let mut multiplicities = HashMap::with_capacity(keys.len());
multiplicities.extend(keys.zip(std::iter::repeat_with(|| AtomicU32::new(0))));

Self {
multiplicities,
instructions,
}
}

pub fn write_trace<MC: MerkleChannel>(
self,
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, MC>,
Expand All @@ -60,9 +76,14 @@ impl ClaimGenerator {
SimdBackend: BackendForChannel<MC>,
{
let (mut inputs, mut mults) = self
.inputs
.multiplicities
.into_iter()
.map(|(input, mult)| (input, M31(mult)))
.map(|(pc, multiplicity)| {
let (offsets, flags) =
deconstruct_instruction(*self.instructions.get(&pc).unwrap());
let multiplicity = M31(multiplicity.into_inner());
((pc, offsets, flags), multiplicity)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
let n_calls = inputs.len();
assert_ne!(n_calls, 0);
Expand Down Expand Up @@ -116,11 +137,19 @@ impl ClaimGenerator {
)
}

pub fn add_inputs(&mut self, inputs: &[InputType]) {
pub fn add_inputs(&self, inputs: &[InputType]) {
for input in inputs {
*self.inputs.entry(*input).or_default() += 1;
self.add_input(input);
}
}

// Instruction is determined by PC.
fn add_input(&self, (pc, ..): &InputType) {
self.multiplicities
.get(pc)
.unwrap()
.fetch_add(1, Ordering::Relaxed);
}
}

#[derive(SubComponentInputs, ParIterMut, IterMut, Uninitialized)]
Expand Down
48 changes: 48 additions & 0 deletions stwo_cairo_prover/crates/prover/src/input/decode.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use stwo_prover::core::fields::m31::M31;

#[derive(Clone, Debug)]
pub struct Instruction {
pub offset0: i16,
Expand Down Expand Up @@ -58,3 +60,49 @@ impl Instruction {
}
}
}

/// Constructs the input for the DecodeInstruction routine.
///
/// # Arguments
///
/// - `encoded_instr`: The encoded instruction.
///
/// # Returns
///
/// The Deconstructed instruction in the form of (offsets, flags): ([M31;3], [M31;15]).
pub fn deconstruct_instruction(mut encoded_instr: u64) -> ([M31; 3], [M31; 15]) {
let mut next_offset = || {
let offset = (encoded_instr & 0xffff) as u16;
encoded_instr >>= 16;
offset
};
let offsets = std::array::from_fn(|_| M31(next_offset() as u32));

let mut next_bit = || {
let bit = encoded_instr & 1;
encoded_instr >>= 1;
bit
};
let flags = std::array::from_fn(|_| M31(next_bit() as u32));

(offsets, flags)
}

#[cfg(test)]
mod tests {
use stwo_prover::core::fields::m31::M31;

use crate::input::decode::deconstruct_instruction;

#[test]
fn test_deconstruct_instruction() {
let encoded_instr = 0b0010101010101010000000000000000100000000000000110000000000000111;
let expected_flags = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0].map(M31);
let expected_offsets = [7, 3, 1].map(M31);

let (offsets, flags) = deconstruct_instruction(encoded_instr);

assert_eq!(offsets, expected_offsets);
assert_eq!(flags, expected_flags);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ pub struct StateTransitions {
pub initial_state: CasmState,
pub final_state: CasmState,
pub casm_states_by_opcode: CasmStatesByOpcode,
pub instruction_by_pc: HashMap<M31, u64>,
}

impl StateTransitions {
Expand Down
Loading