Skip to content

Commit

Permalink
Add last lookup boundary constraint.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Jul 14, 2024
1 parent 7af1f0d commit b7dfa16
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 10 deletions.
23 changes: 21 additions & 2 deletions stwo_cairo_prover/src/components/range_check_unit/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator};

pub const RC_Z: &str = "RangeCheckUnit_Z";
pub const RC_COMPONENT_ID: &str = "RC_UNIT";
pub const RC_LOOKUP_VALUE_0: &str = "RC_UNIT_LOOKUP_0";
pub const RC_LOOKUP_VALUE_1: &str = "RC_UNIT_LOOKUP_1";
pub const RC_LOOKUP_VALUE_2: &str = "RC_UNIT_LOOKUP_2";
pub const RC_LOOKUP_VALUE_3: &str = "RC_UNIT_LOOKUP_3";

#[derive(Clone)]
pub struct RangeCheckUnitComponent {
Expand All @@ -39,19 +43,33 @@ impl RangeCheckUnitComponent {
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
constraint_zero_domain: Coset,
lookup_values: &LookupValues,
) {
let z = interaction_elements[RC_Z];
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let numerator = value * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0];
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);

let lookup_value = SecureField::from_m31(
lookup_values[RC_LOOKUP_VALUE_0],
lookup_values[RC_LOOKUP_VALUE_1],
lookup_values[RC_LOOKUP_VALUE_2],
lookup_values[RC_LOOKUP_VALUE_3],
);
let numerator = value - lookup_value;
let denom = point_vanishing(
constraint_zero_domain.at(constraint_zero_domain.size() - 1),
point,
);
evaluation_accumulator.accumulate(numerator / denom);
}
}

impl Component for RangeCheckUnitComponent {
fn n_constraints(&self) -> usize {
1
2
}

fn max_constraint_log_degree_bound(&self) -> u32 {
Expand Down Expand Up @@ -81,7 +99,7 @@ impl Component for RangeCheckUnitComponent {
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
lookup_values: &LookupValues,
) {
let constraint_zero_domain = CanonicCoset::new(self.log_n_instances).coset;
self.evaluate_lookup_boundary_constraints_at_point(
Expand All @@ -90,6 +108,7 @@ impl Component for RangeCheckUnitComponent {
evaluation_accumulator,
interaction_elements,
constraint_zero_domain,
lookup_values,
);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::collections::BTreeMap;

use itertools::zip_eq;
use stwo_prover::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator};
use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace};
use stwo_prover::core::backend::CpuBackend;
Expand All @@ -11,15 +14,18 @@ use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE};
use stwo_prover::core::utils::point_vanish_denominator_inverses;
use stwo_prover::core::{InteractionElements, LookupValues};

use super::component::{RangeCheckUnitComponent, RC_Z};
use super::component::{
RangeCheckUnitComponent, RC_LOOKUP_VALUE_0, RC_LOOKUP_VALUE_1, RC_LOOKUP_VALUE_2,
RC_LOOKUP_VALUE_3, RC_Z,
};

impl ComponentProver<CpuBackend> for RangeCheckUnitComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &ComponentTrace<'_, CpuBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<CpuBackend>,
interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
lookup_values: &LookupValues,
) {
let max_constraint_degree = self.max_constraint_log_degree_bound();
let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain();
Expand All @@ -34,11 +40,44 @@ impl ComponentProver<CpuBackend> for RangeCheckUnitComponent {
zero_domain,
&mut accum,
interaction_elements,
lookup_values,
);
}

fn lookup_values(&self, _trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
LookupValues::default()
fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
let domain = CanonicCoset::new(self.log_n_instances);
let trace_poly = &trace.polys[INTERACTION_TRACE];
let values = BTreeMap::from_iter([
(
RC_LOOKUP_VALUE_0.to_string(),
trace_poly[0]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.try_into()
.unwrap(),
),
(
RC_LOOKUP_VALUE_1.to_string(),
trace_poly[1]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.try_into()
.unwrap(),
),
(
RC_LOOKUP_VALUE_2.to_string(),
trace_poly[2]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.try_into()
.unwrap(),
),
(
RC_LOOKUP_VALUE_3.to_string(),
trace_poly[3]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.try_into()
.unwrap(),
),
]);
LookupValues::new(values)
}
}

Expand All @@ -48,15 +87,34 @@ fn evaluate_lookup_boundary_constraints(
zero_domain: Coset,
accum: &mut ColumnAccumulator<'_, CpuBackend>,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
let denom_inverses = point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses = point_vanish_denominator_inverses(
trace_eval_domain,
zero_domain.at(zero_domain.size() - 1),
);
let z = interaction_elements[RC_Z];
for (i, denom_inverse) in denom_inverses.iter().enumerate() {
let lookup_value = SecureField::from_m31(
lookup_values[RC_LOOKUP_VALUE_0],
lookup_values[RC_LOOKUP_VALUE_1],
lookup_values[RC_LOOKUP_VALUE_2],
lookup_values[RC_LOOKUP_VALUE_3],
);
for (i, (first_point_denom_inverse, last_point_denom_inverse)) in
zip_eq(first_point_denom_inverses, last_point_denom_inverses).enumerate()
{
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let numerator = accum.random_coeff_powers[0]
let first_point_numerator = accum.random_coeff_powers[1]
* (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]);
accum.accumulate(i, numerator * *denom_inverse);
let last_point_numerator = accum.random_coeff_powers[0] * (value - lookup_value);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse,
);
}
}

0 comments on commit b7dfa16

Please sign in to comment.