From 32a74728ddab34f9796f53b09e34cebf887b0df0 Mon Sep 17 00:00:00 2001 From: ohad nir Date: Thu, 9 Jan 2025 11:24:46 +0200 Subject: [PATCH] fill memory segments of builtins to the next power of 2 instances --- .../prover/src/input/builtin_segments.rs | 164 ++++++++++- .../crates/prover/src/input/memory.rs | 24 +- .../crates/prover/src/input/vm_import/mod.rs | 278 +++++++++++------- 3 files changed, 351 insertions(+), 115 deletions(-) diff --git a/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs b/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs index 488cd72e..8a2f755c 100644 --- a/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs +++ b/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs @@ -3,6 +3,8 @@ use cairo_vm::stdlib::collections::HashMap; use cairo_vm::types::builtin_name::BuiltinName; use serde::{Deserialize, Serialize}; +use super::memory::MemoryBuilder; + /// This struct holds the builtins used in a Cairo program. #[derive(Debug, Default, Serialize, Deserialize)] pub struct BuiltinSegments { @@ -19,11 +21,9 @@ pub struct BuiltinSegments { } impl BuiltinSegments { - pub fn add_segment( - &mut self, - builtin_name: BuiltinName, - segment: Option, - ) { + /// Sets a segment in the builtin segments. + /// If a segment already exists for the given builtin name, it will be overwritten. + fn set_segment(&mut self, builtin_name: BuiltinName, segment: Option) { match builtin_name { BuiltinName::range_check => self.range_check_bits_128 = segment, BuiltinName::pedersen => self.pedersen = segment, @@ -40,6 +40,86 @@ impl BuiltinSegments { } } + // TODO(ohadn): change return type to non reference once MemorySegmentAddresses implements + // clone. + /// Returns the segment for a given builtin name. + fn get_segment(&self, builtin_name: BuiltinName) -> &Option { + match builtin_name { + BuiltinName::range_check => &self.range_check_bits_128, + BuiltinName::pedersen => &self.pedersen, + BuiltinName::ecdsa => &self.ecdsa, + BuiltinName::keccak => &self.keccak, + BuiltinName::bitwise => &self.bitwise, + BuiltinName::ec_op => &self.ec_op, + BuiltinName::poseidon => &self.poseidon, + BuiltinName::range_check96 => &self.range_check_bits_96, + BuiltinName::add_mod => &self.add_mod, + BuiltinName::mul_mod => &self.mul_mod, + // Not builtins. + BuiltinName::output | BuiltinName::segment_arena => &None, + } + } + + /// Returns the number of memory cells per instance for a given builtin name. + pub fn builtin_memory_cells_per_instance(builtin_name: BuiltinName) -> usize { + match builtin_name { + BuiltinName::range_check => 1, + BuiltinName::pedersen => 3, + BuiltinName::ecdsa => 2, + BuiltinName::keccak => 16, + BuiltinName::bitwise => 5, + BuiltinName::ec_op => 7, + BuiltinName::poseidon => 6, + BuiltinName::range_check96 => 1, + BuiltinName::add_mod => 7, + BuiltinName::mul_mod => 7, + // Not builtins. + BuiltinName::output | BuiltinName::segment_arena => 0, + } + } + + /// Pads a builtin segment with copies of its last instance if that segment isn't None, in + /// which case at least one instance is guaranteed to exist. + /// The segment is padded to the next power of 2 number of instances. + // TODO (ohadn): address the cases of add_mod and mul_mod. + // TODO (ohadn): relocate this function if a more appropriate place is found. + pub fn fill_builtin_segment( + &mut self, + mut memory: MemoryBuilder, + builtin_name: BuiltinName, + ) -> MemoryBuilder { + let &Some(MemorySegmentAddresses { + begin_addr, + stop_ptr, + }) = self.get_segment(builtin_name) + else { + return memory; + }; + let initial_length = stop_ptr - begin_addr; + assert!(initial_length > 0); + let cells_per_instance = Self::builtin_memory_cells_per_instance(builtin_name); + assert!(initial_length % cells_per_instance == 0); + let num_instances = initial_length / cells_per_instance; + let next_power_of_two = num_instances.next_power_of_two(); + let mut instance_to_fill_start = stop_ptr as u32; + for _ in num_instances..next_power_of_two { + memory.copy_segment( + (stop_ptr - cells_per_instance) as u32, + instance_to_fill_start, + cells_per_instance as u32, + ); + instance_to_fill_start += cells_per_instance as u32; + } + self.set_segment( + builtin_name, + Some(MemorySegmentAddresses { + begin_addr, + stop_ptr: begin_addr + cells_per_instance * next_power_of_two, + }), + ); + memory + } + /// Creates a new `BuiltinSegments` struct from a map of memory segment names to addresses. pub fn from_memory_segments(memory_segments: &HashMap<&str, MemorySegmentAddresses>) -> Self { let mut res = BuiltinSegments::default(); @@ -57,7 +137,7 @@ impl BuiltinSegments { ); Some((value.begin_addr, value.stop_ptr).into()) }; - res.add_segment(builtin_name, segment); + res.set_segment(builtin_name, segment); }; } res @@ -69,8 +149,10 @@ impl BuiltinSegments { mod test_builtin_segments { use std::path::PathBuf; - use cairo_vm::air_public_input::PublicInput; + use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput}; + use cairo_vm::types::builtin_name::BuiltinName; + use crate::input::memory::{u128_to_4_limbs, MemoryBuilder, MemoryConfig, MemoryValue}; use crate::input::BuiltinSegments; #[test] @@ -97,4 +179,72 @@ mod test_builtin_segments { Some((7069, 7187).into()) ); } + + /// Initializes a memory builder with the given u128 values. + /// Places the value instance_example[i] at the address (stop_ptr - instance_example.len() + i). + fn initialize_memory(stop_ptr: u64, instance_example: &[u128]) -> MemoryBuilder { + let memory_config = MemoryConfig::default(); + let mut memory_builder = MemoryBuilder::new(memory_config.clone()); + for (i, &value) in instance_example.iter().enumerate() { + let memory_value = if value <= memory_config.small_max { + MemoryValue::Small(value) + } else { + let x = u128_to_4_limbs(value); + MemoryValue::F252([x[0], x[1], x[2], x[3], 0, 0, 0, 0]) + }; + memory_builder.set(stop_ptr - (instance_example.len() - i) as u64, memory_value); + } + memory_builder + } + + #[test] + fn test_fill_builtin_segment() { + let builtin_name = BuiltinName::bitwise; + let instance_example = [ + 123456789, + 4385067362534966725237889432551, + 50448645, + 4385067362534966725237911992050, + 4385067362534966725237962440695, + ]; + let mut builtin_segments = BuiltinSegments::default(); + let cells_per_instance = BuiltinSegments::builtin_memory_cells_per_instance(builtin_name); + assert_eq!(cells_per_instance, instance_example.len()); + let num_instances = 71; + let begin_addr = 23581; + let stop_ptr = begin_addr + cells_per_instance * num_instances; + builtin_segments.set_segment( + builtin_name, + Some(MemorySegmentAddresses { + begin_addr, + stop_ptr, + }), + ); + let mut memory_builder = initialize_memory(stop_ptr as u64, &instance_example); + + memory_builder = builtin_segments.fill_builtin_segment(memory_builder, builtin_name); + + let &MemorySegmentAddresses { + begin_addr: new_begin_addr, + stop_ptr: new_stop_ptr, + } = builtin_segments.get_segment(builtin_name).as_ref().unwrap(); + assert_eq!(new_begin_addr, begin_addr); + let segment_length = new_stop_ptr - new_begin_addr; + assert_eq!(segment_length % cells_per_instance, 0); + let new_num_instances = segment_length / cells_per_instance; + assert_eq!(new_num_instances, 128); + + let memory = memory_builder.build(); + assert_eq!(memory.address_to_id.len(), new_stop_ptr); + + let mut instance_to_verify_start = stop_ptr as u32; + for _ in num_instances..new_num_instances { + memory.assert_identical_segments( + (stop_ptr - cells_per_instance) as u32, + instance_to_verify_start, + cells_per_instance as u32, + ); + instance_to_verify_start += cells_per_instance as u32; + } + } } diff --git a/stwo_cairo_prover/crates/prover/src/input/memory.rs b/stwo_cairo_prover/crates/prover/src/input/memory.rs index 5d814b23..86c924c3 100644 --- a/stwo_cairo_prover/crates/prover/src/input/memory.rs +++ b/stwo_cairo_prover/crates/prover/src/input/memory.rs @@ -30,7 +30,7 @@ pub const P_MIN_2: [u32; 8] = [ 0x0800_0000, ]; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct MemoryConfig { pub small_max: u128, } @@ -97,8 +97,20 @@ impl Memory { values.resize(size, MemoryValue::F252([0; 8])); values.into_iter() } + + pub fn assert_identical_segments( + &self, + src_start_addr: u32, + dst_start_addr: u32, + segment_length: u32, + ) { + for i in 0..segment_length { + assert_eq!(self.get(dst_start_addr + i), self.get(src_start_addr + i)); + } + } } +// TODO(ohadn): derive or impl a default for MemoryBuilder. pub struct MemoryBuilder { memory: Memory, felt252_id_cache: HashMap<[u32; 8], usize>, @@ -168,6 +180,16 @@ impl MemoryBuilder { }); self.address_to_id[addr as usize] = res; } + + pub fn copy_segment(&mut self, src_start_addr: u32, dst_start_addr: u32, segment_length: u32) { + for i in 0..segment_length { + self.set( + (dst_start_addr + i) as u64, + self.memory.get(src_start_addr + i), + ); + } + } + pub fn build(self) -> Memory { self.memory } diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs index 25ad6379..508dedd5 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs @@ -6,6 +6,7 @@ use std::path::Path; use bytemuck::{bytes_of_mut, Pod, Zeroable}; use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput}; use cairo_vm::stdlib::collections::HashMap; +use cairo_vm::types::builtin_name::BuiltinName; use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry; use json::PrivateInput; use thiserror::Error; @@ -88,12 +89,15 @@ pub fn adapt_to_stwo_input( ) -> Result { let (state_transitions, instruction_by_pc) = StateTransitions::from_iter(trace_iter, &mut memory, dev_mode); + let mut builtins_segments = BuiltinSegments::from_memory_segments(memory_segments); + // TODO (ohadn): fill in the memory segments of the rest of the builtins. + memory = builtins_segments.fill_builtin_segment(memory, BuiltinName::range_check); Ok(ProverInput { state_transitions, instruction_by_pc, memory: memory.build(), public_memory_addresses, - builtins_segments: BuiltinSegments::from_memory_segments(memory_segments), + builtins_segments, }) } @@ -187,113 +191,173 @@ pub mod tests { ) } - #[test] + #[cfg(test)] #[cfg(feature = "slow-tests")] - fn test_read_from_large_files() { - let input = large_cairo_input(); - - // Test opcode components. - let components = input.state_transitions.casm_states_by_opcode; - assert_eq!(components.generic_opcode.len(), 0); - assert_eq!(components.add_ap_opcode.len(), 0); - assert_eq!(components.add_ap_opcode_imm.len(), 36895); - assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 33); - assert_eq!(components.add_opcode_small_imm.len(), 84732); - assert_eq!(components.add_opcode.len(), 189425); - assert_eq!(components.add_opcode_small.len(), 36623); - assert_eq!(components.add_opcode_imm.len(), 22089); - assert_eq!(components.assert_eq_opcode.len(), 233432); - assert_eq!(components.assert_eq_opcode_double_deref.len(), 811061); - assert_eq!(components.assert_eq_opcode_imm.len(), 43184); - assert_eq!(components.call_opcode.len(), 0); - assert_eq!(components.call_opcode_rel.len(), 49439); - assert_eq!(components.call_opcode_op_1_base_fp.len(), 33); - assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 11235); - assert_eq!(components.jnz_opcode.len(), 27032); - assert_eq!(components.jnz_opcode_taken.len(), 51060); - assert_eq!(components.jnz_opcode_dst_base_fp.len(), 5100); - assert_eq!(components.jump_opcode_rel_imm.len(), 31873865); - assert_eq!(components.jump_opcode_rel.len(), 500); - assert_eq!(components.jump_opcode_double_deref.len(), 32); - assert_eq!(components.jump_opcode.len(), 0); - assert_eq!(components.mul_opcode_small_imm.len(), 7234); - assert_eq!(components.mul_opcode_small.len(), 7203); - assert_eq!(components.mul_opcode.len(), 3943); - assert_eq!(components.mul_opcode_imm.len(), 10809); - assert_eq!(components.ret_opcode.len(), 49472); - - // Test builtins. - let builtins_segments = input.builtins_segments; - assert_eq!(builtins_segments.add_mod, None); - assert_eq!(builtins_segments.bitwise, None); - assert_eq!(builtins_segments.ec_op, Some((16428600, 16428747).into())); - assert_eq!(builtins_segments.ecdsa, None); - assert_eq!(builtins_segments.keccak, None); - assert_eq!(builtins_segments.mul_mod, None); - assert_eq!(builtins_segments.pedersen, Some((1322552, 1337489).into())); - assert_eq!( - builtins_segments.poseidon, - Some((16920120, 17444532).into()) - ); - assert_eq!(builtins_segments.range_check_bits_96, None); - assert_eq!( - builtins_segments.range_check_bits_128, - Some((1715768, 1757348).into()) - ); - } + pub mod slow_tests { + use cairo_vm::types::builtin_name::BuiltinName; - #[cfg(feature = "slow-tests")] - #[test] - fn test_read_from_small_files() { - let input = small_cairo_input(); - - // Test opcode components. - let components = input.state_transitions.casm_states_by_opcode; - assert_eq!(components.generic_opcode.len(), 0); - assert_eq!(components.add_ap_opcode.len(), 0); - assert_eq!(components.add_ap_opcode_imm.len(), 2); - assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 1); - assert_eq!(components.add_opcode_small_imm.len(), 500); - assert_eq!(components.add_opcode.len(), 0); - assert_eq!(components.add_opcode_small.len(), 0); - assert_eq!(components.add_opcode_imm.len(), 450); - assert_eq!(components.assert_eq_opcode.len(), 55); - assert_eq!(components.assert_eq_opcode_double_deref.len(), 2100); - assert_eq!(components.assert_eq_opcode_imm.len(), 1952); - assert_eq!(components.call_opcode.len(), 0); - assert_eq!(components.call_opcode_rel.len(), 462); - assert_eq!(components.call_opcode_op_1_base_fp.len(), 0); - assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 450); - assert_eq!(components.jnz_opcode.len(), 0); - assert_eq!(components.jnz_opcode_taken.len(), 0); - assert_eq!(components.jnz_opcode_dst_base_fp.len(), 11); - assert_eq!(components.jump_opcode_rel_imm.len(), 124626); - assert_eq!(components.jump_opcode_rel.len(), 0); - assert_eq!(components.jump_opcode_double_deref.len(), 0); - assert_eq!(components.jump_opcode.len(), 0); - assert_eq!(components.mul_opcode_small_imm.len(), 0); - assert_eq!(components.mul_opcode_small.len(), 0); - assert_eq!(components.mul_opcode.len(), 0); - assert_eq!(components.mul_opcode_imm.len(), 0); - assert_eq!(components.ret_opcode.len(), 462); - - // Test builtins. - let builtins_segments = input.builtins_segments; - assert_eq!(builtins_segments.add_mod, None); - assert_eq!(builtins_segments.bitwise, Some((22512, 22762).into())); - assert_eq!(builtins_segments.ec_op, Some((63472, 63822).into())); - assert_eq!(builtins_segments.ecdsa, Some((22384, 22484).into())); - assert_eq!(builtins_segments.keccak, Some((64368, 65168).into())); - assert_eq!(builtins_segments.mul_mod, None); - assert_eq!(builtins_segments.pedersen, Some((4464, 4614).into())); - assert_eq!(builtins_segments.poseidon, Some((65392, 65692).into())); - assert_eq!( - builtins_segments.range_check_bits_96, - Some((68464, 68514).into()) - ); - assert_eq!( - builtins_segments.range_check_bits_128, - Some((6000, 6050).into()) - ); + use super::*; + use crate::input::memory::{EncodedMemoryValueId, Memory}; + + /// Verifies that the builtin segment is padded with copies of the first instance to the + /// next power of 2 instances. + fn verify_segment_is_padded( + segment: &Option, + builtin_name: BuiltinName, + memory: &Memory, + original_stop_ptr: usize, + golden_instance: &[u32], + ) { + if let Some(segment) = segment { + let cells_per_instance = + BuiltinSegments::builtin_memory_cells_per_instance(builtin_name); + assert_eq!(golden_instance.len(), cells_per_instance); + let segment_length = segment.stop_ptr - segment.begin_addr; + assert_eq!(segment_length % cells_per_instance, 0); + let num_instances = segment_length / cells_per_instance; + // Assert that num_instances is a power of 2. + assert_eq!((num_instances & (num_instances - 1)), 0); + + let original_segment_length = original_stop_ptr - segment.begin_addr; + assert_eq!(original_segment_length % cells_per_instance, 0); + let original_num_instances = original_segment_length / cells_per_instance; + // Assert that num_instances is the next power of 2 after original_num_instances. + assert!(original_num_instances * 2 > num_instances); + assert!(original_num_instances <= num_instances); + + assert!(segment.stop_ptr < memory.address_to_id.len()); + for instance in original_num_instances..num_instances { + for (j, &golden_value) in golden_instance.iter().enumerate() { + let address = segment.begin_addr + instance * cells_per_instance + j; + assert_eq!( + memory.address_to_id[address], + EncodedMemoryValueId(golden_value) + ); + } + } + } + } + + #[test] + fn test_read_from_large_files() { + let input = large_cairo_input(); + + // Test opcode components. + let components = input.state_transitions.casm_states_by_opcode; + assert_eq!(components.generic_opcode.len(), 0); + assert_eq!(components.add_ap_opcode.len(), 0); + assert_eq!(components.add_ap_opcode_imm.len(), 36895); + assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 33); + assert_eq!(components.add_opcode_small_imm.len(), 84732); + assert_eq!(components.add_opcode.len(), 189425); + assert_eq!(components.add_opcode_small.len(), 36623); + assert_eq!(components.add_opcode_imm.len(), 22089); + assert_eq!(components.assert_eq_opcode.len(), 233432); + assert_eq!(components.assert_eq_opcode_double_deref.len(), 811061); + assert_eq!(components.assert_eq_opcode_imm.len(), 43184); + assert_eq!(components.call_opcode.len(), 0); + assert_eq!(components.call_opcode_rel.len(), 49439); + assert_eq!(components.call_opcode_op_1_base_fp.len(), 33); + assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 11235); + assert_eq!(components.jnz_opcode.len(), 27032); + assert_eq!(components.jnz_opcode_taken.len(), 51060); + assert_eq!(components.jnz_opcode_dst_base_fp.len(), 5100); + assert_eq!(components.jump_opcode_rel_imm.len(), 31873865); + assert_eq!(components.jump_opcode_rel.len(), 500); + assert_eq!(components.jump_opcode_double_deref.len(), 32); + assert_eq!(components.jump_opcode.len(), 0); + assert_eq!(components.mul_opcode_small_imm.len(), 7234); + assert_eq!(components.mul_opcode_small.len(), 7203); + assert_eq!(components.mul_opcode.len(), 3943); + assert_eq!(components.mul_opcode_imm.len(), 10809); + assert_eq!(components.ret_opcode.len(), 49472); + + // Test builtins. + let builtins_segments = input.builtins_segments; + assert_eq!(builtins_segments.add_mod, None); + assert_eq!(builtins_segments.bitwise, None); + assert_eq!(builtins_segments.ec_op, Some((16428600, 16428747).into())); + assert_eq!(builtins_segments.ecdsa, None); + assert_eq!(builtins_segments.keccak, None); + assert_eq!(builtins_segments.mul_mod, None); + assert_eq!(builtins_segments.pedersen, Some((1322552, 1337489).into())); + assert_eq!( + builtins_segments.poseidon, + Some((16920120, 17444532).into()) + ); + assert_eq!(builtins_segments.range_check_bits_96, None); + assert_eq!( + builtins_segments.range_check_bits_128, + Some((1715768, 1781304).into()) + ); + verify_segment_is_padded( + &builtins_segments.range_check_bits_128, + BuiltinName::range_check, + &input.memory, + 1757348, + &[117061], + ); + } + + #[test] + fn test_read_from_small_files() { + let input = small_cairo_input(); + + // Test opcode components. + let components = input.state_transitions.casm_states_by_opcode; + assert_eq!(components.generic_opcode.len(), 0); + assert_eq!(components.add_ap_opcode.len(), 0); + assert_eq!(components.add_ap_opcode_imm.len(), 2); + assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 1); + assert_eq!(components.add_opcode_small_imm.len(), 500); + assert_eq!(components.add_opcode.len(), 0); + assert_eq!(components.add_opcode_small.len(), 0); + assert_eq!(components.add_opcode_imm.len(), 450); + assert_eq!(components.assert_eq_opcode.len(), 55); + assert_eq!(components.assert_eq_opcode_double_deref.len(), 2100); + assert_eq!(components.assert_eq_opcode_imm.len(), 1952); + assert_eq!(components.call_opcode.len(), 0); + assert_eq!(components.call_opcode_rel.len(), 462); + assert_eq!(components.call_opcode_op_1_base_fp.len(), 0); + assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 450); + assert_eq!(components.jnz_opcode.len(), 0); + assert_eq!(components.jnz_opcode_taken.len(), 0); + assert_eq!(components.jnz_opcode_dst_base_fp.len(), 11); + assert_eq!(components.jump_opcode_rel_imm.len(), 124626); + assert_eq!(components.jump_opcode_rel.len(), 0); + assert_eq!(components.jump_opcode_double_deref.len(), 0); + assert_eq!(components.jump_opcode.len(), 0); + assert_eq!(components.mul_opcode_small_imm.len(), 0); + assert_eq!(components.mul_opcode_small.len(), 0); + assert_eq!(components.mul_opcode.len(), 0); + assert_eq!(components.mul_opcode_imm.len(), 0); + assert_eq!(components.ret_opcode.len(), 462); + + // Test builtins. + let builtins_segments = input.builtins_segments; + assert_eq!(builtins_segments.add_mod, None); + assert_eq!(builtins_segments.bitwise, Some((22512, 22762).into())); + assert_eq!(builtins_segments.ec_op, Some((63472, 63822).into())); + assert_eq!(builtins_segments.ecdsa, Some((22384, 22484).into())); + assert_eq!(builtins_segments.keccak, Some((64368, 65168).into())); + assert_eq!(builtins_segments.mul_mod, None); + assert_eq!(builtins_segments.pedersen, Some((4464, 4614).into())); + assert_eq!(builtins_segments.poseidon, Some((65392, 65692).into())); + assert_eq!( + builtins_segments.range_check_bits_96, + Some((68464, 68514).into()) + ); + assert_eq!( + builtins_segments.range_check_bits_128, + Some((6000, 6064).into()) + ); + verify_segment_is_padded( + &builtins_segments.range_check_bits_128, + BuiltinName::range_check, + &input.memory, + 6050, + &[7], + ); + } } }