Skip to content

Commit

Permalink
Implement memory component prover. (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 authored Jul 17, 2024
2 parents a7a2183 + 8e528d0 commit 5c6a408
Show file tree
Hide file tree
Showing 3 changed files with 358 additions and 16 deletions.
76 changes: 60 additions & 16 deletions stwo_cairo_prover/src/components/memory/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use stwo_prover::core::air::mask::fixed_mask_points;
use stwo_prover::core::air::Component;
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::circle::CirclePoint;
use stwo_prover::core::fields::m31::{BaseField, M31};
use stwo_prover::core::constraints::{coset_vanishing, point_excluder, point_vanishing};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use stwo_prover::core::fields::FieldExpOps;
Expand All @@ -17,24 +18,33 @@ use stwo_prover::core::utils::{
};
use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues};
use stwo_prover::trace_generation::registry::ComponentGenerationRegistry;
use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator};
use stwo_prover::trace_generation::{
ComponentGen, ComponentTraceGenerator, BASE_TRACE, INTERACTION_TRACE,
};

pub const MEMORY_ALPHA: &str = "MEMORY_ALPHA";
pub const MEMORY_Z: &str = "MEMORY_Z";

const N_M31_IN_FELT252: usize = 21;
const MULTIPLICITY_COLUMN: usize = 22;
const LOG_MEMORY_ADDRESS_BOUND: u32 = 20;
const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND;
pub const MEMORY_COMPONENT_ID: &str = "MEMORY";
pub const MEMORY_LOOKUP_VALUE_0: &str = "MEMORY_LOOKUP_0";
pub const MEMORY_LOOKUP_VALUE_1: &str = "MEMORY_LOOKUP_1";
pub const MEMORY_LOOKUP_VALUE_2: &str = "MEMORY_LOOKUP_2";
pub const MEMORY_LOOKUP_VALUE_3: &str = "MEMORY_LOOKUP_3";

pub const N_M31_IN_FELT252: usize = 21;
pub const MULTIPLICITY_COLUMN: usize = 22;
// TODO(AlonH): Make memory size configurable.
pub const LOG_MEMORY_ADDRESS_BOUND: u32 = 3;
pub const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND;

/// Addresses are continuous and start from 0.
/// Values are Felt252 stored as `N_M31_IN_FELT252` M31 values (each value contain 12 bits).
pub struct MemoryTraceGenerator {
// TODO(AlonH): Consider to change values to be Felt252.
pub values: Vec<[M31; N_M31_IN_FELT252]>,
pub values: Vec<[BaseField; N_M31_IN_FELT252]>,
pub multiplicities: Vec<u32>,
}

#[derive(Clone)]
pub struct MemoryComponent {
pub log_n_rows: u32,
}
Expand All @@ -48,7 +58,7 @@ impl MemoryComponent {
impl MemoryTraceGenerator {
pub fn new(_path: String) -> Self {
// TODO(AlonH): change to read from file.
let values = vec![[M31::zero(); N_M31_IN_FELT252]; MEMORY_ADDRESS_BOUND];
let values = vec![[BaseField::zero(); N_M31_IN_FELT252]; MEMORY_ADDRESS_BOUND];
let multiplicities = vec![0; MEMORY_ADDRESS_BOUND];
Self {
values,
Expand All @@ -61,7 +71,7 @@ impl ComponentGen for MemoryTraceGenerator {}

impl ComponentTraceGenerator<CpuBackend> for MemoryTraceGenerator {
type Component = MemoryComponent;
type Inputs = M31;
type Inputs = BaseField;

fn add_inputs(&mut self, inputs: &Self::Inputs) {
let input = inputs.0 as usize;
Expand Down Expand Up @@ -175,12 +185,46 @@ impl Component for MemoryComponent {

fn evaluate_constraint_quotients_at_point(
&self,
_point: CirclePoint<SecureField>,
_mask: &TreeVec<Vec<Vec<SecureField>>>,
_evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
point: CirclePoint<SecureField>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
todo!()
// First lookup point boundary constraint.
let constraint_zero_domain = CanonicCoset::new(self.log_n_rows).coset;
let (alpha, z) = (
interaction_elements[MEMORY_ALPHA],
interaction_elements[MEMORY_Z],
);
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let address_and_value: [SecureField; N_M31_IN_FELT252 + 1] =
std::array::from_fn(|i| mask[BASE_TRACE][i][0]);
let numerator = value * shifted_secure_combination(&address_and_value, alpha, z)
- mask[BASE_TRACE][MULTIPLICITY_COLUMN][0];
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);

// Last lookup point boundary constraint.
let lookup_value = SecureField::from_m31(
lookup_values[MEMORY_LOOKUP_VALUE_0],
lookup_values[MEMORY_LOOKUP_VALUE_1],
lookup_values[MEMORY_LOOKUP_VALUE_2],
lookup_values[MEMORY_LOOKUP_VALUE_3],
);
let numerator = value - lookup_value;
let denom = point_vanishing(constraint_zero_domain.at(1), point);
evaluation_accumulator.accumulate(numerator / denom);

// Lookup step constraint.
let prev_value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1]));
let numerator = (value - prev_value)
* shifted_secure_combination(&address_and_value, alpha, z)
- mask[BASE_TRACE][22][0];
let denom = coset_vanishing(constraint_zero_domain, point)
/ point_excluder(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}
}
137 changes: 137 additions & 0 deletions stwo_cairo_prover/src/components/memory/component_prover.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use std::collections::BTreeMap;

use itertools::izip;
use num_traits::Zero;
use stwo_prover::core::air::accumulation::DomainEvaluationAccumulator;
use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace};
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::constraints::{coset_vanishing, point_excluder};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::FieldExpOps;
use stwo_prover::core::poly::circle::CanonicCoset;
use stwo_prover::core::utils::{
bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index,
shifted_secure_combination,
};
use stwo_prover::core::{InteractionElements, LookupValues};
use stwo_prover::trace_generation::{BASE_TRACE, INTERACTION_TRACE};

use super::component::{
MemoryComponent, MEMORY_ALPHA, MEMORY_LOOKUP_VALUE_0, MEMORY_LOOKUP_VALUE_1,
MEMORY_LOOKUP_VALUE_2, MEMORY_LOOKUP_VALUE_3, MEMORY_Z, MULTIPLICITY_COLUMN, N_M31_IN_FELT252,
};

impl ComponentProver<CpuBackend> for MemoryComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &ComponentTrace<'_, CpuBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<CpuBackend>,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
let max_constraint_degree = self.max_constraint_log_degree_bound();
let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain();
let trace_evals = &trace.evals;
let zero_domain = CanonicCoset::new(self.log_n_rows).coset;
let [mut accum] =
evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]);

// TODO(AlonH): Get all denominators in one loop and don't perform unnecessary inversions.
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1));
let mut step_denoms = vec![];
for point in trace_eval_domain.iter() {
step_denoms.push(
coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point),
);
}
bit_reverse(&mut step_denoms);
let mut step_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&step_denoms, &mut step_denom_inverses);
let (alpha, z) = (
interaction_elements[MEMORY_ALPHA],
interaction_elements[MEMORY_Z],
);
let lookup_value = SecureField::from_m31(
lookup_values[MEMORY_LOOKUP_VALUE_0],
lookup_values[MEMORY_LOOKUP_VALUE_1],
lookup_values[MEMORY_LOOKUP_VALUE_2],
lookup_values[MEMORY_LOOKUP_VALUE_3],
);
for (i, (first_point_denom_inverse, last_point_denom_inverse, step_denom_inverse)) in izip!(
first_point_denom_inverses,
last_point_denom_inverses,
step_denom_inverses,
)
.enumerate()
{
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let prev_index = previous_bit_reversed_circle_domain_index(
i,
zero_domain.log_size,
trace_eval_domain.log_size(),
);
let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][prev_index]
}));
let address_and_value: [BaseField; N_M31_IN_FELT252 + 1] =
std::array::from_fn(|j| trace_evals[BASE_TRACE][j][i]);

let first_point_numerator = accum.random_coeff_powers[2]
* (value * shifted_secure_combination(&address_and_value, alpha, z)
- trace_evals[BASE_TRACE][MULTIPLICITY_COLUMN][i]);

let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value);
let step_numerator = accum.random_coeff_powers[0]
* ((value - prev_value) * shifted_secure_combination(&address_and_value, alpha, z)
- trace_evals[BASE_TRACE][MULTIPLICITY_COLUMN][i]);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse
+ step_numerator * step_denom_inverse,
);
}
}

fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
let domain = CanonicCoset::new(self.log_n_rows);
let trace_poly = &trace.polys[INTERACTION_TRACE];
let values = BTreeMap::from_iter([
(
MEMORY_LOOKUP_VALUE_0.to_string(),
trace_poly[0]
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
(
MEMORY_LOOKUP_VALUE_1.to_string(),
trace_poly[1]
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
(
MEMORY_LOOKUP_VALUE_2.to_string(),
trace_poly[2]
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
(
MEMORY_LOOKUP_VALUE_3.to_string(),
trace_poly[3]
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
]);
LookupValues::new(values)
}
}
Loading

0 comments on commit 5c6a408

Please sign in to comment.