Skip to content

Commit

Permalink
fill memory segments of builtins to the next power of 2 instances
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-nir-starkware committed Jan 16, 2025
1 parent edffee9 commit 32a7472
Show file tree
Hide file tree
Showing 3 changed files with 351 additions and 115 deletions.
164 changes: 157 additions & 7 deletions stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use cairo_vm::stdlib::collections::HashMap;
use cairo_vm::types::builtin_name::BuiltinName;
use serde::{Deserialize, Serialize};

use super::memory::MemoryBuilder;

/// This struct holds the builtins used in a Cairo program.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct BuiltinSegments {
Expand All @@ -19,11 +21,9 @@ pub struct BuiltinSegments {
}

impl BuiltinSegments {
pub fn add_segment(
&mut self,
builtin_name: BuiltinName,
segment: Option<MemorySegmentAddresses>,
) {
/// Sets a segment in the builtin segments.
/// If a segment already exists for the given builtin name, it will be overwritten.
fn set_segment(&mut self, builtin_name: BuiltinName, segment: Option<MemorySegmentAddresses>) {
match builtin_name {
BuiltinName::range_check => self.range_check_bits_128 = segment,
BuiltinName::pedersen => self.pedersen = segment,
Expand All @@ -40,6 +40,86 @@ impl BuiltinSegments {
}
}

// TODO(ohadn): change return type to non reference once MemorySegmentAddresses implements
// clone.
/// Returns the segment for a given builtin name.
fn get_segment(&self, builtin_name: BuiltinName) -> &Option<MemorySegmentAddresses> {
match builtin_name {
BuiltinName::range_check => &self.range_check_bits_128,
BuiltinName::pedersen => &self.pedersen,
BuiltinName::ecdsa => &self.ecdsa,
BuiltinName::keccak => &self.keccak,
BuiltinName::bitwise => &self.bitwise,
BuiltinName::ec_op => &self.ec_op,
BuiltinName::poseidon => &self.poseidon,
BuiltinName::range_check96 => &self.range_check_bits_96,
BuiltinName::add_mod => &self.add_mod,
BuiltinName::mul_mod => &self.mul_mod,
// Not builtins.
BuiltinName::output | BuiltinName::segment_arena => &None,
}
}

/// Returns the number of memory cells per instance for a given builtin name.
pub fn builtin_memory_cells_per_instance(builtin_name: BuiltinName) -> usize {
match builtin_name {
BuiltinName::range_check => 1,
BuiltinName::pedersen => 3,
BuiltinName::ecdsa => 2,
BuiltinName::keccak => 16,
BuiltinName::bitwise => 5,
BuiltinName::ec_op => 7,
BuiltinName::poseidon => 6,
BuiltinName::range_check96 => 1,
BuiltinName::add_mod => 7,
BuiltinName::mul_mod => 7,
// Not builtins.
BuiltinName::output | BuiltinName::segment_arena => 0,
}
}

/// Pads a builtin segment with copies of its last instance if that segment isn't None, in
/// which case at least one instance is guaranteed to exist.
/// The segment is padded to the next power of 2 number of instances.
// TODO (ohadn): address the cases of add_mod and mul_mod.
// TODO (ohadn): relocate this function if a more appropriate place is found.
pub fn fill_builtin_segment(
&mut self,
mut memory: MemoryBuilder,
builtin_name: BuiltinName,
) -> MemoryBuilder {
let &Some(MemorySegmentAddresses {
begin_addr,
stop_ptr,
}) = self.get_segment(builtin_name)
else {
return memory;
};
let initial_length = stop_ptr - begin_addr;
assert!(initial_length > 0);
let cells_per_instance = Self::builtin_memory_cells_per_instance(builtin_name);
assert!(initial_length % cells_per_instance == 0);
let num_instances = initial_length / cells_per_instance;
let next_power_of_two = num_instances.next_power_of_two();
let mut instance_to_fill_start = stop_ptr as u32;
for _ in num_instances..next_power_of_two {
memory.copy_segment(
(stop_ptr - cells_per_instance) as u32,
instance_to_fill_start,
cells_per_instance as u32,
);
instance_to_fill_start += cells_per_instance as u32;
}
self.set_segment(
builtin_name,
Some(MemorySegmentAddresses {
begin_addr,
stop_ptr: begin_addr + cells_per_instance * next_power_of_two,
}),
);
memory
}

/// Creates a new `BuiltinSegments` struct from a map of memory segment names to addresses.
pub fn from_memory_segments(memory_segments: &HashMap<&str, MemorySegmentAddresses>) -> Self {
let mut res = BuiltinSegments::default();
Expand All @@ -57,7 +137,7 @@ impl BuiltinSegments {
);
Some((value.begin_addr, value.stop_ptr).into())
};
res.add_segment(builtin_name, segment);
res.set_segment(builtin_name, segment);
};
}
res
Expand All @@ -69,8 +149,10 @@ impl BuiltinSegments {
mod test_builtin_segments {
use std::path::PathBuf;

use cairo_vm::air_public_input::PublicInput;
use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput};
use cairo_vm::types::builtin_name::BuiltinName;

use crate::input::memory::{u128_to_4_limbs, MemoryBuilder, MemoryConfig, MemoryValue};
use crate::input::BuiltinSegments;

#[test]
Expand All @@ -97,4 +179,72 @@ mod test_builtin_segments {
Some((7069, 7187).into())
);
}

/// Initializes a memory builder with the given u128 values.
/// Places the value instance_example[i] at the address (stop_ptr - instance_example.len() + i).
fn initialize_memory(stop_ptr: u64, instance_example: &[u128]) -> MemoryBuilder {
let memory_config = MemoryConfig::default();
let mut memory_builder = MemoryBuilder::new(memory_config.clone());
for (i, &value) in instance_example.iter().enumerate() {
let memory_value = if value <= memory_config.small_max {
MemoryValue::Small(value)
} else {
let x = u128_to_4_limbs(value);
MemoryValue::F252([x[0], x[1], x[2], x[3], 0, 0, 0, 0])
};
memory_builder.set(stop_ptr - (instance_example.len() - i) as u64, memory_value);
}
memory_builder
}

#[test]
fn test_fill_builtin_segment() {
let builtin_name = BuiltinName::bitwise;
let instance_example = [
123456789,
4385067362534966725237889432551,
50448645,
4385067362534966725237911992050,
4385067362534966725237962440695,
];
let mut builtin_segments = BuiltinSegments::default();
let cells_per_instance = BuiltinSegments::builtin_memory_cells_per_instance(builtin_name);
assert_eq!(cells_per_instance, instance_example.len());
let num_instances = 71;
let begin_addr = 23581;
let stop_ptr = begin_addr + cells_per_instance * num_instances;
builtin_segments.set_segment(
builtin_name,
Some(MemorySegmentAddresses {
begin_addr,
stop_ptr,
}),
);
let mut memory_builder = initialize_memory(stop_ptr as u64, &instance_example);

memory_builder = builtin_segments.fill_builtin_segment(memory_builder, builtin_name);

let &MemorySegmentAddresses {
begin_addr: new_begin_addr,
stop_ptr: new_stop_ptr,
} = builtin_segments.get_segment(builtin_name).as_ref().unwrap();
assert_eq!(new_begin_addr, begin_addr);
let segment_length = new_stop_ptr - new_begin_addr;
assert_eq!(segment_length % cells_per_instance, 0);
let new_num_instances = segment_length / cells_per_instance;
assert_eq!(new_num_instances, 128);

let memory = memory_builder.build();
assert_eq!(memory.address_to_id.len(), new_stop_ptr);

let mut instance_to_verify_start = stop_ptr as u32;
for _ in num_instances..new_num_instances {
memory.assert_identical_segments(
(stop_ptr - cells_per_instance) as u32,
instance_to_verify_start,
cells_per_instance as u32,
);
instance_to_verify_start += cells_per_instance as u32;
}
}
}
24 changes: 23 additions & 1 deletion stwo_cairo_prover/crates/prover/src/input/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub const P_MIN_2: [u32; 8] = [
0x0800_0000,
];

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MemoryConfig {
pub small_max: u128,
}
Expand Down Expand Up @@ -97,8 +97,20 @@ impl Memory {
values.resize(size, MemoryValue::F252([0; 8]));
values.into_iter()
}

pub fn assert_identical_segments(
&self,
src_start_addr: u32,
dst_start_addr: u32,
segment_length: u32,
) {
for i in 0..segment_length {
assert_eq!(self.get(dst_start_addr + i), self.get(src_start_addr + i));
}
}
}

// TODO(ohadn): derive or impl a default for MemoryBuilder.
pub struct MemoryBuilder {
memory: Memory,
felt252_id_cache: HashMap<[u32; 8], usize>,
Expand Down Expand Up @@ -168,6 +180,16 @@ impl MemoryBuilder {
});
self.address_to_id[addr as usize] = res;
}

pub fn copy_segment(&mut self, src_start_addr: u32, dst_start_addr: u32, segment_length: u32) {
for i in 0..segment_length {
self.set(
(dst_start_addr + i) as u64,
self.memory.get(src_start_addr + i),
);
}
}

pub fn build(self) -> Memory {
self.memory
}
Expand Down
Loading

0 comments on commit 32a7472

Please sign in to comment.