From 7eddb1af422ac425324a14c9df96f12f1e989b40 Mon Sep 17 00:00:00 2001 From: ohad-starkware Date: Wed, 8 Jan 2025 11:09:55 +0200 Subject: [PATCH] instruction cache outside of statetransitions --- .../crates/prover/src/input/mod.rs | 6 +++- .../prover/src/input/state_transitions.rs | 31 +++++++++++++------ .../crates/prover/src/input/vm_import/mod.rs | 5 ++- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/stwo_cairo_prover/crates/prover/src/input/mod.rs b/stwo_cairo_prover/crates/prover/src/input/mod.rs index 04657f6f..1bfb4119 100644 --- a/stwo_cairo_prover/crates/prover/src/input/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/mod.rs @@ -1,9 +1,12 @@ +use std::collections::HashMap; + use builtin_segments::BuiltinSegments; use memory::Memory; +use prover_types::cpu::M31; use state_transitions::StateTransitions; pub mod builtin_segments; -mod decode; +pub mod decode; pub mod memory; pub mod plain; pub mod range_check_unit; @@ -16,6 +19,7 @@ pub const N_REGISTERS: usize = 3; #[derive(Debug)] pub struct ProverInput { pub state_transitions: StateTransitions, + pub instruction_by_pc: HashMap, pub memory: Memory, pub public_memory_addresses: Vec, pub builtins_segments: BuiltinSegments, diff --git a/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs b/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs index 06536d14..3fee688a 100644 --- a/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs +++ b/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs @@ -180,19 +180,26 @@ pub struct StateTransitions { impl StateTransitions { /// Iterates over the casm states and splits them into the appropriate opcode components. + /// + /// # Returns + /// + /// - StateTransitions, used to feed the opcodes' air. + /// - A map from pc to instruction that is used to feed + /// [`crate::components::verify_instruction::ClaimGenerator`]. pub fn from_iter( iter: impl Iterator, memory: &mut MemoryBuilder, dev_mode: bool, - ) -> Self { + ) -> (Self, HashMap) { let mut res = Self::default(); + let mut instruction_by_pc = HashMap::new(); let mut iter = iter.peekable(); let Some(first) = iter.next() else { - return res; + return (res, instruction_by_pc); }; res.initial_state = first.into(); - res.push_instr(memory, first.into(), dev_mode); + res.push_instr(memory, first.into(), dev_mode, &mut instruction_by_pc); while let Some(entry) = iter.next() { // TODO(Ohad): Check if the adapter outputs the final state. @@ -200,18 +207,24 @@ impl StateTransitions { res.final_state = entry.into(); break; }; - res.push_instr(memory, entry.into(), dev_mode); + res.push_instr(memory, entry.into(), dev_mode, &mut instruction_by_pc); } - res + (res, instruction_by_pc) } // TODO(Ohad): remove dev_mode after adding the rest of the instructions. /// Pushes the state transition at pc into the appropriate opcode component. - fn push_instr(&mut self, memory: &mut MemoryBuilder, state: CasmState, dev_mode: bool) { + fn push_instr( + &mut self, + memory: &mut MemoryBuilder, + state: CasmState, + dev_mode: bool, + instruction_by_pc: &mut HashMap, + ) { let CasmState { ap, fp, pc } = state; - let instruction = memory.get_inst(pc.0); - self.instruction_by_pc.entry(pc).or_insert(instruction); - let instruction = Instruction::decode(instruction); + let encoded_instruction = memory.get_inst(pc.0); + instruction_by_pc.entry(pc).or_insert(encoded_instruction); + let instruction = Instruction::decode(encoded_instruction); match instruction { // ret. diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs index 2ab3b36b..f3b187dd 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs @@ -85,8 +85,11 @@ pub fn adapt_to_stwo_input( memory_segments: &HashMap<&str, MemorySegmentAddresses>, dev_mode: bool, ) -> Result { + let (state_transitions, instruction_by_pc) = + StateTransitions::from_iter(trace_iter, &mut memory, dev_mode); Ok(ProverInput { - state_transitions: StateTransitions::from_iter(trace_iter, &mut memory, dev_mode), + state_transitions, + instruction_by_pc, memory: memory.build(), public_memory_addresses, builtins_segments: BuiltinSegments::from_memory_segments(memory_segments),