Skip to content

Commit

Permalink
Implement ComponentProver for RangeCheckUnitComponent.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Jul 14, 2024
1 parent d3cfe07 commit 7af1f0d
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 47 deletions.
42 changes: 21 additions & 21 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion stwo_cairo_prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ edition = "2021"
itertools = "0.12.0"
num-traits = "0.2.17"
# TODO(ShaharS): take stwo version from the source repository.
stwo-prover = { git = "https://github.com/starkware-libs/stwo", branch = "dev" }
stwo-prover = { git = "https://github.com/starkware-libs/stwo", rev = "2501444" }
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ use stwo_prover::core::fields::FieldExpOps;
use stwo_prover::core::pcs::TreeVec;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE};
use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues};
use stwo_prover::trace_generation::registry::ComponentGenerationRegistry;
use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator};

pub const RC_Z: &str = "RangeCheckUnit_Z";
pub const RC_COMPONENT_ID: &str = "RC_UNIT";

#[derive(Clone)]
pub struct RangeCheckUnitComponent {
pub log_n_instances: u32,
}
Expand All @@ -39,8 +41,9 @@ impl RangeCheckUnitComponent {
constraint_zero_domain: Coset,
) {
let z = interaction_elements[RC_Z];
let value = SecureField::from_partial_evals(std::array::from_fn(|i| mask[1][i][0]));
let numerator = value * (z - mask[0][0][0]) - mask[0][1][0];
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let numerator = value * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0];
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}
Expand All @@ -55,10 +58,6 @@ impl Component for RangeCheckUnitComponent {
self.log_n_instances + 1
}

fn n_interaction_phases(&self) -> u32 {
2
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(vec![
vec![self.log_n_instances; 2],
Expand All @@ -70,7 +69,10 @@ impl Component for RangeCheckUnitComponent {
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
TreeVec::new(vec![fixed_mask_points(&vec![vec![0_usize]], point)])
TreeVec::new(vec![
fixed_mask_points(&vec![vec![0_usize]; 2], point),
vec![vec![point]; SECURE_EXTENSION_DEGREE],
])
}

fn evaluate_constraint_quotients_at_point(
Expand Down Expand Up @@ -177,28 +179,12 @@ impl ComponentTraceGenerator<CpuBackend> for RangeCheckUnitTraceGenerator {
#[cfg(test)]
mod tests {
use super::*;
use crate::components::range_check_unit::tests::register_test_rc;

#[test]
fn test_rc_unit_trace() {
let mut registry = ComponentGenerationRegistry::default();
registry.register(RC_COMPONENT_ID, RangeCheckUnitTraceGenerator::new(8));
let inputs = vec![
vec![BaseField::from_u32_unchecked(0); 3],
vec![BaseField::from_u32_unchecked(1); 1],
vec![BaseField::from_u32_unchecked(2); 2],
vec![BaseField::from_u32_unchecked(3); 5],
vec![BaseField::from_u32_unchecked(4); 10],
vec![BaseField::from_u32_unchecked(5); 1],
vec![BaseField::from_u32_unchecked(6); 0],
vec![BaseField::from_u32_unchecked(7); 1],
]
.into_iter()
.flatten()
.collect_vec();
registry
.get_generator_mut::<RangeCheckUnitTraceGenerator>(RC_COMPONENT_ID)
.add_inputs(&inputs);

register_test_rc(&mut registry);
let trace = RangeCheckUnitTraceGenerator::write_trace(RC_COMPONENT_ID, &mut registry);
let random_value = SecureField::from_u32_unchecked(1, 2, 3, 117);
let interaction_elements =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use stwo_prover::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator};
use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace};
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::circle::Coset;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::pcs::TreeVec;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE};
use stwo_prover::core::utils::point_vanish_denominator_inverses;
use stwo_prover::core::{InteractionElements, LookupValues};

use super::component::{RangeCheckUnitComponent, RC_Z};

impl ComponentProver<CpuBackend> for RangeCheckUnitComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &ComponentTrace<'_, CpuBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<CpuBackend>,
interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
let max_constraint_degree = self.max_constraint_log_degree_bound();
let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain();
let trace_evals = &trace.evals;
let zero_domain = CanonicCoset::new(self.log_n_instances).coset;
let [mut accum] =
evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]);

evaluate_lookup_boundary_constraints(
trace_evals,
trace_eval_domain,
zero_domain,
&mut accum,
interaction_elements,
);
}

fn lookup_values(&self, _trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
LookupValues::default()
}
}

fn evaluate_lookup_boundary_constraints(
trace_evals: &TreeVec<Vec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
trace_eval_domain: CircleDomain,
zero_domain: Coset,
accum: &mut ColumnAccumulator<'_, CpuBackend>,
interaction_elements: &InteractionElements,
) {
let denom_inverses = point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let z = interaction_elements[RC_Z];
for (i, denom_inverse) in denom_inverses.iter().enumerate() {
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let numerator = accum.random_coeff_powers[0]
* (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]);
accum.accumulate(i, numerator * *denom_inverse);
}
}
Loading

0 comments on commit 7af1f0d

Please sign in to comment.