diff --git a/stwo_cairo_prover/src/components/memory/component.rs b/stwo_cairo_prover/src/components/memory/component.rs index d892ee3c..9230c968 100644 --- a/stwo_cairo_prover/src/components/memory/component.rs +++ b/stwo_cairo_prover/src/components/memory/component.rs @@ -7,15 +7,23 @@ use stwo_prover::core::backend::CpuBackend; use stwo_prover::core::circle::CirclePoint; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fields::qm31::SecureField; -use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use stwo_prover::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE}; +use stwo_prover::core::fields::FieldExpOps; use stwo_prover::core::pcs::TreeVec; use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::utils::{ + bit_reverse_index, coset_order_to_circle_domain_order_index, shifted_secure_combination, +}; use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues}; use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator}; +pub const MEMORY_ALPHA: &str = "MEMORY_ALPHA"; +pub const MEMORY_Z: &str = "MEMORY_Z"; + const N_M31_IN_FELT252: usize = 21; +const MULTIPLICITY_COLUMN: usize = 22; const LOG_MEMORY_ADDRESS_BOUND: u32 = 20; const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND; @@ -81,7 +89,7 @@ impl ComponentTraceGenerator for MemoryTraceGenerator { for (j, value) in values.iter().enumerate() { trace[j + 1][i] = BaseField::from_u32_unchecked(value.0); } - trace[22][i] = BaseField::from_u32_unchecked(*multiplicity); + trace[MULTIPLICITY_COLUMN][i] = BaseField::from_u32_unchecked(*multiplicity); } let domain = CanonicCoset::new(LOG_MEMORY_ADDRESS_BOUND).circle_domain(); @@ -93,10 +101,39 @@ impl ComponentTraceGenerator for MemoryTraceGenerator { fn write_interaction_trace( &self, - _trace: &ColumnVec<&CircleEvaluation>, - _elements: &InteractionElements, + trace: &ColumnVec<&CircleEvaluation>, + elements: &InteractionElements, ) -> ColumnVec> { - todo!() + let interaction_trace_domain = trace[0].domain; + let (alpha, z) = (elements[MEMORY_ALPHA], elements[MEMORY_Z]); + + let addresses_and_values: Vec<[BaseField; N_M31_IN_FELT252 + 1]> = (0 + ..MEMORY_ADDRESS_BOUND) + .map(|i| std::array::from_fn(|j| trace[j].values[i])) + .collect_vec(); + let denoms = addresses_and_values + .iter() + .map(|address_and_value| shifted_secure_combination(address_and_value, alpha, z)) + .collect_vec(); + let mut denom_inverses = vec![SecureField::zero(); denoms.len()]; + SecureField::batch_inverse(&denoms, &mut denom_inverses); + let mut logup_values = vec![SecureField::zero(); trace[MULTIPLICITY_COLUMN].values.len()]; + let mut last = SecureField::zero(); + let log_size = interaction_trace_domain.log_size(); + for i in 0..trace[MULTIPLICITY_COLUMN].values.len() { + let index = coset_order_to_circle_domain_order_index(i, log_size); + let index = bit_reverse_index(index, log_size); + let interaction_value = + last + (denom_inverses[index] * trace[MULTIPLICITY_COLUMN].values[index]); + logup_values[index] = interaction_value; + last = interaction_value; + } + let secure_column: SecureColumn = logup_values.into_iter().collect(); + secure_column + .columns + .into_iter() + .map(|eval| CircleEvaluation::new(interaction_trace_domain, eval)) + .collect_vec() } fn component(&self) -> Self::Component {