diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs index e21643f9..dedc2f6d 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs @@ -170,7 +170,8 @@ impl CairoClaimGenerator { let initial_state = input.state_transitions.initial_state; let final_state = input.state_transitions.final_state; let opcodes = OpcodesClaimGenerator::new(input.state_transitions); - let verify_instruction_trace_generator = verify_instruction::ClaimGenerator::default(); + let verify_instruction_trace_generator = + verify_instruction::ClaimGenerator::new(input.instruction_by_pc); let memory_address_to_id_trace_generator = memory_address_to_id::ClaimGenerator::new(&input.memory); let memory_id_to_value_trace_generator = diff --git a/stwo_cairo_prover/crates/prover/src/components/verify_instruction/prover.rs b/stwo_cairo_prover/crates/prover/src/components/verify_instruction/prover.rs index 0a0d93ef..aca84a49 100644 --- a/stwo_cairo_prover/crates/prover/src/components/verify_instruction/prover.rs +++ b/stwo_cairo_prover/crates/prover/src/components/verify_instruction/prover.rs @@ -1,7 +1,8 @@ #![allow(unused_parens)] #![allow(unused_imports)] -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::iter::zip; +use std::sync::atomic::{AtomicU32, Ordering}; use std::vec; use air_structs_derive::SubComponentInputs; @@ -35,6 +36,7 @@ use super::component::{Claim, InteractionClaim}; use crate::components::{ memory_address_to_id, memory_id_to_big, pack_values, range_check_4_3, range_check_7_2_5, }; +use crate::input::decode::{deconstruct_instruction_into_felts, Instruction}; use crate::relations; pub type InputType = (M31, [M31; 3], [M31; 15]); @@ -44,10 +46,24 @@ const N_TRACE_COLUMNS: usize = 28 + N_MULTIPLICITY_COLUMNS; #[derive(Default)] pub struct ClaimGenerator { - /// A map from input to multiplicity. - inputs: BTreeMap, + /// pc -> encoded instruction. + instructions: HashMap, + + /// pc -> multiplicity. + multiplicities: HashMap, } impl ClaimGenerator { + pub fn new(instructions: HashMap) -> Self { + let keys = instructions.keys().copied(); + let mut multiplicities = HashMap::with_capacity(keys.len()); + multiplicities.extend(keys.zip(std::iter::repeat_with(|| AtomicU32::new(0)))); + + Self { + multiplicities, + instructions, + } + } + pub fn write_trace( self, tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, MC>, @@ -60,9 +76,14 @@ impl ClaimGenerator { SimdBackend: BackendForChannel, { let (mut inputs, mut mults) = self - .inputs + .multiplicities .into_iter() - .map(|(input, mult)| (input, M31(mult))) + .map(|(pc, multiplicity)| { + let (offsets, flags) = + deconstruct_instruction_into_felts(*self.instructions.get(&pc).unwrap()); + let multiplicity = M31(multiplicity.into_inner()); + ((pc, offsets, flags), multiplicity) + }) .unzip::<_, _, Vec<_>, Vec<_>>(); let n_calls = inputs.len(); assert_ne!(n_calls, 0); @@ -116,11 +137,19 @@ impl ClaimGenerator { ) } - pub fn add_inputs(&mut self, inputs: &[InputType]) { + pub fn add_inputs(&self, inputs: &[InputType]) { for input in inputs { - *self.inputs.entry(*input).or_default() += 1; + self.add_input(input); } } + + // Instruction is determined by PC. + fn add_input(&self, (pc, ..): &InputType) { + self.multiplicities + .get(pc) + .unwrap() + .fetch_add(1, Ordering::Relaxed); + } } #[derive(SubComponentInputs, ParIterMut, IterMut, Uninitialized)] diff --git a/stwo_cairo_prover/crates/prover/src/input/decode.rs b/stwo_cairo_prover/crates/prover/src/input/decode.rs index 0f1f9966..6d52fa34 100644 --- a/stwo_cairo_prover/crates/prover/src/input/decode.rs +++ b/stwo_cairo_prover/crates/prover/src/input/decode.rs @@ -1,3 +1,5 @@ +use stwo_prover::core::fields::m31::M31; + #[derive(Clone, Debug)] pub struct Instruction { pub offset0: i16, @@ -58,3 +60,30 @@ impl Instruction { } } } + +/// Constructs the input for the DecodeInstruction routine. +/// +/// # Arguments +/// +/// - `encoded_instr`: The encoded instruction. +/// +/// # Returns +/// +/// The Deconstructed instruction in the form of (offsets, flags): ([M31;3], [M31;15]). +pub fn deconstruct_instruction_into_felts(mut encoded_instr: u64) -> ([M31; 3], [M31; 15]) { + let mut next_offset = || { + let offset = (encoded_instr & 0xffff) as u16; + encoded_instr >>= 16; + offset + }; + let offsets = std::array::from_fn(|_| M31(next_offset() as u32)); + + let mut next_bit = || { + let bit = encoded_instr & 1; + encoded_instr >>= 1; + bit + }; + let flags = std::array::from_fn(|_| M31(next_bit() as u32)); + + (offsets, flags) +} diff --git a/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs b/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs index 3fee688a..568343c5 100644 --- a/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs +++ b/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs @@ -175,7 +175,6 @@ pub struct StateTransitions { pub initial_state: CasmState, pub final_state: CasmState, pub casm_states_by_opcode: CasmStatesByOpcode, - pub instruction_by_pc: HashMap, } impl StateTransitions {