Skip to content

Commit

Permalink
atomic mults in verify
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Jan 9, 2025
1 parent d272452 commit 1ae8480
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 9 deletions.
3 changes: 2 additions & 1 deletion stwo_cairo_prover/crates/prover/src/cairo_air/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ impl CairoClaimGenerator {
let initial_state = input.state_transitions.initial_state;
let final_state = input.state_transitions.final_state;
let opcodes = OpcodesClaimGenerator::new(input.state_transitions);
let verify_instruction_trace_generator = verify_instruction::ClaimGenerator::default();
let verify_instruction_trace_generator =
verify_instruction::ClaimGenerator::new(input.instruction_by_pc);
let memory_address_to_id_trace_generator =
memory_address_to_id::ClaimGenerator::new(&input.memory);
let memory_id_to_value_trace_generator =
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#![allow(unused_parens)]
#![allow(unused_imports)]
use std::collections::{BTreeMap, BTreeSet};
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::iter::zip;
use std::sync::atomic::{AtomicU32, Ordering};
use std::vec;

use air_structs_derive::SubComponentInputs;
Expand Down Expand Up @@ -35,6 +36,7 @@ use super::component::{Claim, InteractionClaim};
use crate::components::{
memory_address_to_id, memory_id_to_big, pack_values, range_check_4_3, range_check_7_2_5,
};
use crate::input::decode::{deconstruct_instruction, Instruction};
use crate::relations;

pub type InputType = (M31, [M31; 3], [M31; 15]);
Expand All @@ -44,10 +46,24 @@ const N_TRACE_COLUMNS: usize = 28 + N_MULTIPLICITY_COLUMNS;

#[derive(Default)]
pub struct ClaimGenerator {
/// A map from input to multiplicity.
inputs: BTreeMap<InputType, u32>,
/// pc -> encoded instruction.
instructions: HashMap<M31, u64>,

/// pc -> multiplicity.
multiplicities: HashMap<M31, AtomicU32>,
}
impl ClaimGenerator {
pub fn new(instructions: HashMap<M31, u64>) -> Self {
let keys = instructions.keys().copied();
let mut multiplicities = HashMap::with_capacity(keys.len());
multiplicities.extend(keys.zip(std::iter::repeat_with(|| AtomicU32::new(0))));

Self {
multiplicities,
instructions,
}
}

pub fn write_trace<MC: MerkleChannel>(
self,
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, MC>,
Expand All @@ -60,9 +76,14 @@ impl ClaimGenerator {
SimdBackend: BackendForChannel<MC>,
{
let (mut inputs, mut mults) = self
.inputs
.multiplicities
.into_iter()
.map(|(input, mult)| (input, M31(mult)))
.map(|(pc, multiplicity)| {
let (offsets, flags) =
deconstruct_instruction(*self.instructions.get(&pc).unwrap());
let multiplicity = M31(multiplicity.into_inner());
((pc, offsets, flags), multiplicity)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
let n_calls = inputs.len();
assert_ne!(n_calls, 0);
Expand Down Expand Up @@ -116,11 +137,19 @@ impl ClaimGenerator {
)
}

pub fn add_inputs(&mut self, inputs: &[InputType]) {
pub fn add_inputs(&self, inputs: &[InputType]) {
for input in inputs {
*self.inputs.entry(*input).or_default() += 1;
self.add_input(input);
}
}

// Instruction is determined by PC.
fn add_input(&self, (pc, ..): &InputType) {
self.multiplicities
.get(pc)
.unwrap()
.fetch_add(1, Ordering::Relaxed);
}
}

#[derive(SubComponentInputs, ParIterMut, IterMut, Uninitialized)]
Expand Down
48 changes: 48 additions & 0 deletions stwo_cairo_prover/crates/prover/src/input/decode.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use stwo_prover::core::fields::m31::M31;

#[derive(Clone, Debug)]
pub struct Instruction {
pub offset0: i16,
Expand Down Expand Up @@ -58,3 +60,49 @@ impl Instruction {
}
}
}

/// Constructs the input for the DecodeInstruction routine.
///
/// # Arguments
///
/// - `encoded_instr`: The encoded instruction.
///
/// # Returns
///
/// The Deconstructed instruction in the form of (offsets, flags): ([M31;3], [M31;15]).
pub fn deconstruct_instruction(mut encoded_instr: u64) -> ([M31; 3], [M31; 15]) {
let mut next_offset = || {
let offset = (encoded_instr & 0xffff) as u16;
encoded_instr >>= 16;
offset
};
let offsets = std::array::from_fn(|_| M31(next_offset() as u32));

let mut next_bit = || {
let bit = encoded_instr & 1;
encoded_instr >>= 1;
bit
};
let flags = std::array::from_fn(|_| M31(next_bit() as u32));

(offsets, flags)
}

#[cfg(test)]
mod tests {
use stwo_prover::core::fields::m31::M31;

use crate::input::decode::deconstruct_instruction;

#[test]
fn test_deconstruct_instruction() {
let encoded_instr = 0b0010101010101010000000000000000100000000000000110000000000000111;
let expected_flags = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0].map(M31);
let expected_offsets = [7, 3, 1].map(M31);

let (offsets, flags) = deconstruct_instruction(encoded_instr);

assert_eq!(offsets, expected_offsets);
assert_eq!(flags, expected_flags);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ pub struct StateTransitions {
pub initial_state: CasmState,
pub final_state: CasmState,
pub casm_states_by_opcode: CasmStatesByOpcode,
pub instruction_by_pc: HashMap<M31, u64>,
}

impl StateTransitions {
Expand Down

0 comments on commit 1ae8480

Please sign in to comment.