Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fill memory segments of builtins to the next power of 2 instances #340

Merged
merged 1 commit into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 173 additions & 7 deletions stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use cairo_vm::stdlib::collections::HashMap;
use cairo_vm::types::builtin_name::BuiltinName;
use serde::{Deserialize, Serialize};

use super::memory::MemoryBuilder;

// TODO(ohadn): change field types in MemorySegmentAddresses to match address type.
/// This struct holds the builtins used in a Cairo program.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct BuiltinSegments {
Expand All @@ -19,11 +22,9 @@ pub struct BuiltinSegments {
}

impl BuiltinSegments {
pub fn add_segment(
&mut self,
builtin_name: BuiltinName,
segment: Option<MemorySegmentAddresses>,
) {
/// Sets a segment in the builtin segments.
/// If a segment already exists for the given builtin name, it will be overwritten.
fn set_segment(&mut self, builtin_name: BuiltinName, segment: Option<MemorySegmentAddresses>) {
match builtin_name {
BuiltinName::range_check => self.range_check_bits_128 = segment,
BuiltinName::pedersen => self.pedersen = segment,
Expand All @@ -40,6 +41,85 @@ impl BuiltinSegments {
}
}

// TODO(ohadn): change return type to non reference once MemorySegmentAddresses implements
// clone.
// TODO(ohadn): change output type to match address type.
/// Returns the segment for a given builtin name.
fn get_segment(&self, builtin_name: BuiltinName) -> &Option<MemorySegmentAddresses> {
match builtin_name {
BuiltinName::range_check => &self.range_check_bits_128,
BuiltinName::pedersen => &self.pedersen,
BuiltinName::ecdsa => &self.ecdsa,
BuiltinName::keccak => &self.keccak,
BuiltinName::bitwise => &self.bitwise,
BuiltinName::ec_op => &self.ec_op,
BuiltinName::poseidon => &self.poseidon,
BuiltinName::range_check96 => &self.range_check_bits_96,
BuiltinName::add_mod => &self.add_mod,
BuiltinName::mul_mod => &self.mul_mod,
// Not builtins.
BuiltinName::output | BuiltinName::segment_arena => &None,
}
}

/// Returns the number of memory cells per instance for a given builtin name.
pub fn builtin_memory_cells_per_instance(builtin_name: BuiltinName) -> usize {
match builtin_name {
BuiltinName::range_check => 1,
BuiltinName::pedersen => 3,
BuiltinName::ecdsa => 2,
BuiltinName::keccak => 16,
BuiltinName::bitwise => 5,
BuiltinName::ec_op => 7,
BuiltinName::poseidon => 6,
BuiltinName::range_check96 => 1,
BuiltinName::add_mod => 7,
BuiltinName::mul_mod => 7,
// Not builtins.
BuiltinName::output | BuiltinName::segment_arena => 0,
}
}

/// Pads a builtin segment with copies of its last instance if that segment isn't None, in
/// which case at least one instance is guaranteed to exist.
/// The segment is padded to the next power of 2 number of instances.
/// Note: the last instance was already verified as valid by the VM and in the case of add_mod
/// and mul_mod, security checks have verified that instance has n=1. Thus the padded segment
/// satisfies all the AIR constraints.
// TODO (ohadn): relocate this function if a more appropriate place is found.
pub fn fill_builtin_segment(&mut self, memory: &mut MemoryBuilder, builtin_name: BuiltinName) {
let &Some(MemorySegmentAddresses {
begin_addr,
stop_ptr,
}) = self.get_segment(builtin_name)
else {
return;
};
let initial_length = stop_ptr - begin_addr;
assert!(initial_length > 0);
let cells_per_instance = Self::builtin_memory_cells_per_instance(builtin_name);
assert!(initial_length % cells_per_instance == 0);
let num_instances = initial_length / cells_per_instance;
let next_power_of_two = num_instances.next_power_of_two();
let mut instance_to_fill_start = stop_ptr as u32;
let last_instance_start = (stop_ptr - cells_per_instance) as u32;
for _ in num_instances..next_power_of_two {
memory.copy_block(
last_instance_start,
instance_to_fill_start,
cells_per_instance as u32,
);
instance_to_fill_start += cells_per_instance as u32;
}
self.set_segment(
builtin_name,
Some(MemorySegmentAddresses {
begin_addr,
stop_ptr: begin_addr + cells_per_instance * next_power_of_two,
}),
);
}

/// Creates a new `BuiltinSegments` struct from a map of memory segment names to addresses.
pub fn from_memory_segments(memory_segments: &HashMap<&str, MemorySegmentAddresses>) -> Self {
let mut res = BuiltinSegments::default();
Expand All @@ -57,7 +137,7 @@ impl BuiltinSegments {
);
Some((value.begin_addr, value.stop_ptr).into())
};
res.add_segment(builtin_name, segment);
res.set_segment(builtin_name, segment);
};
}
res
Expand All @@ -69,10 +149,25 @@ impl BuiltinSegments {
mod test_builtin_segments {
use std::path::PathBuf;

use cairo_vm::air_public_input::PublicInput;
use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput};
use cairo_vm::types::builtin_name::BuiltinName;

use crate::input::memory::{u128_to_4_limbs, Memory, MemoryBuilder, MemoryConfig, MemoryValue};
use crate::input::BuiltinSegments;

/// Asserts that the values at addresses start_addr1 to start_addr1 + segment_length - 1
/// are equal to values at the addresses start_addr2 to start_addr2 + segment_length - 1.
pub fn assert_identical_blocks(
memory: &Memory,
start_addr1: u32,
start_addr2: u32,
segment_length: u32,
) {
for i in 0..segment_length {
assert_eq!(memory.get(start_addr1 + i), memory.get(start_addr2 + i));
}
}

#[test]
fn test_builtin_segments() {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
Expand All @@ -97,4 +192,75 @@ mod test_builtin_segments {
Some((7069, 7187).into())
);
}

/// Initializes a memory builder with the given u128 values.
/// Places the value instance_example[i] at the address memory_write_start + i.
fn initialize_memory(memory_write_start: u64, instance_example: &[u128]) -> MemoryBuilder {
let memory_config = MemoryConfig::default();
let mut memory_builder = MemoryBuilder::new(memory_config.clone());
for (i, &value) in instance_example.iter().enumerate() {
let memory_value = if value <= memory_config.small_max {
MemoryValue::Small(value)
} else {
let x = u128_to_4_limbs(value);
MemoryValue::F252([x[0], x[1], x[2], x[3], 0, 0, 0, 0])
};
memory_builder.set(memory_write_start + i as u64, memory_value);
}
memory_builder
}

#[test]
fn test_fill_builtin_segment() {
let builtin_name = BuiltinName::bitwise;
let instance_example = [
123456789,
4385067362534966725237889432551,
50448645,
4385067362534966725237911992050,
4385067362534966725237962440695,
];
let mut builtin_segments = BuiltinSegments::default();
let cells_per_instance = BuiltinSegments::builtin_memory_cells_per_instance(builtin_name);
assert_eq!(cells_per_instance, instance_example.len());
let num_instances = 71;
let begin_addr = 23581;
let stop_ptr = begin_addr + cells_per_instance * num_instances;
builtin_segments.set_segment(
builtin_name,
Some(MemorySegmentAddresses {
begin_addr,
stop_ptr,
}),
);
let memory_write_start = (stop_ptr - cells_per_instance) as u64;
let mut memory_builder = initialize_memory(memory_write_start, &instance_example);

builtin_segments.fill_builtin_segment(&mut memory_builder, builtin_name);

let &MemorySegmentAddresses {
begin_addr: new_begin_addr,
stop_ptr: new_stop_ptr,
} = builtin_segments.get_segment(builtin_name).as_ref().unwrap();
assert_eq!(new_begin_addr, begin_addr);
let segment_length = new_stop_ptr - new_begin_addr;
assert_eq!(segment_length % cells_per_instance, 0);
let new_num_instances = segment_length / cells_per_instance;
assert_eq!(new_num_instances, 128);

let memory = memory_builder.build();
assert_eq!(memory.address_to_id.len(), new_stop_ptr);

let mut instance_to_verify_start = stop_ptr as u32;
let last_instance_start = (stop_ptr - cells_per_instance) as u32;
for _ in num_instances..new_num_instances {
assert_identical_blocks(
&memory,
last_instance_start,
instance_to_verify_start,
cells_per_instance as u32,
);
instance_to_verify_start += cells_per_instance as u32;
}
}
}
17 changes: 16 additions & 1 deletion stwo_cairo_prover/crates/prover/src/input/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub const P_MIN_2: [u32; 8] = [
0x0800_0000,
];

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MemoryConfig {
pub small_max: u128,
}
Expand Down Expand Up @@ -99,6 +99,7 @@ impl Memory {
}
}

// TODO(ohadn): derive or impl a default for MemoryBuilder.
pub struct MemoryBuilder {
memory: Memory,
felt252_id_cache: HashMap<[u32; 8], usize>,
Expand Down Expand Up @@ -143,6 +144,7 @@ impl MemoryBuilder {
res
}

// TODO(ohadn): settle on an address integer type, and use it consistently.
pub fn set(&mut self, addr: u64, value: MemoryValue) {
if addr as usize >= self.address_to_id.len() {
self.address_to_id
Expand All @@ -168,6 +170,19 @@ impl MemoryBuilder {
});
self.address_to_id[addr as usize] = res;
}

/// Copies a block of memory from one location to another.
/// The values at addresses src_start_addr to src_start_addr + segment_length - 1 are copied to
/// the addresses dst_start_addr to dst_start_addr + segment_length - 1.
pub fn copy_block(&mut self, src_start_addr: u32, dst_start_addr: u32, segment_length: u32) {
for i in 0..segment_length {
self.set(
(dst_start_addr + i) as u64,
self.memory.get(src_start_addr + i),
);
}
}

pub fn build(self) -> Memory {
self.memory
}
Expand Down
Loading
Loading