Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into provide_completed_ide…
Browse files Browse the repository at this point in the history
…ntities
  • Loading branch information
chriseth committed Jan 15, 2025
2 parents aedb50d + 7cbf05f commit 4477ea7
Show file tree
Hide file tree
Showing 27 changed files with 610 additions and 209 deletions.
4 changes: 1 addition & 3 deletions backend/src/stwo/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,5 @@ where
MC::H: DeserializeOwned + Serialize,
{
pub stark_proof: StarkProof<MC::H>,
pub constant_col_log_sizes: Vec<u32>,
pub witness_col_log_sizes: Vec<u32>,
pub machine_log_sizes: Vec<u32>,
pub machine_log_sizes: BTreeMap<String, u32>,
}
59 changes: 32 additions & 27 deletions backend/src/stwo/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use powdr_number::FieldElement;
use serde::de::DeserializeOwned;
use serde::ser::Serialize;
use std::collections::BTreeMap;
use std::iter::repeat;
use std::marker::PhantomData;
use std::sync::Arc;
use std::{fmt, io};
Expand All @@ -23,7 +24,7 @@ use stwo_prover::constraint_framework::{
};

use stwo_prover::core::air::{Component, ComponentProver};
use stwo_prover::core::backend::{Backend, BackendForChannel, Column};
use stwo_prover::core::backend::{Backend, BackendForChannel};
use stwo_prover::core::channel::{Channel, MerkleChannel};
use stwo_prover::core::fields::m31::{BaseField, M31};
use stwo_prover::core::fields::qm31::SecureField;
Expand Down Expand Up @@ -245,7 +246,7 @@ where

//The preprocessed columns needs to be indexed in the whole execution instead of each machine, so we need to keep track of the offset
let mut constant_cols_offset_acc = 0;
let mut machine_log_sizes = Vec::new();
let mut machine_log_sizes = BTreeMap::new();

let mut constant_cols = Vec::new();

Expand Down Expand Up @@ -293,7 +294,7 @@ where
);
components.push(component);

machine_log_sizes.push(machine_length.ilog2());
machine_log_sizes.insert(machine.clone(), machine_length.ilog2());

constant_cols_offset_acc +=
pil.constant_count() + get_constant_with_next_list(pil).len();
Expand All @@ -316,16 +317,6 @@ where
.flatten()
.collect();

let constant_col_log_sizes = constant_cols
.iter()
.map(|eval| eval.len().ilog2())
.collect::<Vec<_>>();

let witness_col_log_sizes = witness_by_machine
.iter()
.map(|eval| eval.len().ilog2())
.collect::<Vec<_>>();

let twiddles_max_degree = B::precompute_twiddles(
CanonicCoset::new(domain_degree_range.max.ilog2() + 1 + FRI_LOG_BLOWUP as u32)
.circle_domain()
Expand Down Expand Up @@ -366,8 +357,6 @@ where

let proof: Proof<MC> = Proof {
stark_proof,
constant_col_log_sizes,
witness_col_log_sizes,
machine_log_sizes,
};
Ok(bincode::serialize(&proof).unwrap())
Expand All @@ -394,20 +383,35 @@ where

let mut constant_cols_offset_acc = 0;

let mut constant_col_log_sizes = vec![];
let mut witness_col_log_sizes = vec![];

let mut components = self
.split
.iter()
.zip_eq(proof.machine_log_sizes.iter())
.map(|((_, pil), &machine_size)| {
constant_cols_offset_acc += pil.constant_count();

constant_cols_offset_acc += get_constant_with_next_list(pil).len();
PowdrComponent::new(
tree_span_provider,
PowdrEval::new((*pil).clone(), constant_cols_offset_acc, machine_size),
(SecureField::zero(), None),
)
})
.map(
|((machine_name, pil), (proof_machine_name, &machine_log_size))| {
assert_eq!(machine_name, proof_machine_name);
let machine_component = PowdrComponent::new(
tree_span_provider,
PowdrEval::new((*pil).clone(), constant_cols_offset_acc, machine_log_size),
(SecureField::zero(), None),
);

constant_cols_offset_acc += pil.constant_count();

constant_cols_offset_acc += get_constant_with_next_list(pil).len();

constant_col_log_sizes.extend(
repeat(machine_log_size)
.take(pil.constant_count() + get_constant_with_next_list(pil).len()),
);
witness_col_log_sizes
.extend(repeat(machine_log_size).take(pil.commitment_count()));
machine_component
},
)
.collect::<Vec<_>>();

let mut components_slice: Vec<&dyn Component> = components
Expand All @@ -419,12 +423,13 @@ where

commitment_scheme.commit(
proof.stark_proof.commitments[PREPROCESSED_TRACE_IDX],
&proof.constant_col_log_sizes,
&constant_col_log_sizes,
verifier_channel,
);

commitment_scheme.commit(
proof.stark_proof.commitments[ORIGINAL_TRACE_IDX],
&proof.witness_col_log_sizes,
&witness_col_log_sizes,
verifier_channel,
);

Expand Down
16 changes: 16 additions & 0 deletions executor/src/witgen/bus_accumulator/extension_field.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use std::{
collections::BTreeMap,
iter::Sum,
ops::{Add, Mul, Sub},
};

use num_traits::{One, Zero};

pub trait ExtensionField<T>:
Add<Output = Self> + Sub<Output = Self> + Mul<T, Output = Self> + Zero + One + Copy + Sum
{
fn get_challenge(challenges: &BTreeMap<u64, T>, index: u64) -> Self;
fn size() -> usize;
fn inverse(self) -> Self;
fn to_vec(self) -> Vec<T>;
}
18 changes: 18 additions & 0 deletions executor/src/witgen/bus_accumulator/fp2.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::{
collections::BTreeMap,
iter::Sum,
ops::{Add, Div, Mul, Sub},
};

use num_traits::{One, Zero};
use powdr_number::FieldElement;

use super::extension_field::ExtensionField;

/// An implementation of Fp2, analogous to `std/math/fp2.asm`.
/// An Fp2 element. The tuple (a, b) represents the polynomial a + b * X.
/// All computations are done modulo the irreducible polynomial X^2 - 11.
Expand Down Expand Up @@ -101,6 +104,21 @@ impl<T: FieldElement> Div for Fp2<T> {
}
}

impl<T: FieldElement> ExtensionField<T> for Fp2<T> {
fn get_challenge(challenges: &BTreeMap<u64, T>, index: u64) -> Self {
Fp2::new(challenges[&(index * 2 + 1)], challenges[&(index * 2 + 2)])
}
fn size() -> usize {
2
}
fn inverse(self) -> Self {
self.inverse()
}
fn to_vec(self) -> Vec<T> {
vec![self.0, self.1]
}
}

#[cfg(test)]
mod tests {
use powdr_number::GoldilocksField;
Expand Down
199 changes: 199 additions & 0 deletions executor/src/witgen/bus_accumulator/fp4.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
use std::{
collections::BTreeMap,
iter::Sum,
ops::{Add, Div, Mul, Sub},
};

use num_traits::{One, Zero};
use powdr_number::FieldElement;

use super::extension_field::ExtensionField;

/// An implementation of Fp4, analogous to `std/math/fp4.asm`.
/// An Fp4 element. The tuple (a, b, c, d) represents the polynomial a + b * X + c * X^2 + d * X^3.
/// All computations are done modulo the irreducible polynomial X^4 - 11.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Fp4<T>(pub T, pub T, pub T, pub T);

impl<T: FieldElement> Fp4<T> {
pub fn new(a: T, b: T, c: T, d: T) -> Self {
Fp4(a, b, c, d)
}
}

impl<T: FieldElement> Zero for Fp4<T> {
fn zero() -> Self {
Fp4(T::zero(), T::zero(), T::zero(), T::zero())
}

fn is_zero(&self) -> bool {
self.0.is_zero() && self.1.is_zero() && self.2.is_zero() && self.3.is_zero()
}
}

impl<T: FieldElement> One for Fp4<T> {
fn one() -> Self {
Fp4(T::one(), T::zero(), T::zero(), T::zero())
}

fn is_one(&self) -> bool {
self.0.is_one() && self.1.is_zero() && self.2.is_zero() && self.3.is_zero()
}
}

impl<T: FieldElement> From<T> for Fp4<T> {
fn from(a: T) -> Self {
Fp4(a, T::zero(), T::zero(), T::zero())
}
}

impl<T: FieldElement> Add for Fp4<T> {
type Output = Self;

fn add(self, other: Self) -> Self {
Fp4(
self.0 + other.0,
self.1 + other.1,
self.2 + other.2,
self.3 + other.3,
)
}
}

impl<T: FieldElement> Sum for Fp4<T> {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Self::zero(), Add::add)
}
}

impl<T: FieldElement> Sub for Fp4<T> {
type Output = Self;

fn sub(self, other: Self) -> Self {
Fp4(
self.0 - other.0,
self.1 - other.1,
self.2 - other.2,
self.3 - other.3,
)
}
}

impl<T: FieldElement> Mul for Fp4<T> {
type Output = Self;

fn mul(self, other: Self) -> Self {
Fp4(
self.0 * other.0
+ T::from(11) * (self.1 * other.3 + self.2 * other.2 + self.3 * other.1),
self.0 * other.1
+ self.1 * other.0
+ T::from(11) * (self.2 * other.3 + self.3 * other.2),
self.0 * other.2 + self.1 * other.1 + self.2 * other.0 + T::from(11) * self.3 * other.3,
self.0 * other.3 + self.1 * other.2 + self.2 * other.1 + self.3 * other.0,
)
}
}

impl<T: FieldElement> Mul<T> for Fp4<T> {
type Output = Self;

fn mul(self, other: T) -> Self {
Fp4(
self.0 * other,
self.1 * other,
self.2 * other,
self.3 * other,
)
}
}

impl<T: FieldElement> Fp4<T> {
pub fn inverse(self) -> Self {
let b0 = self.0 * self.0 - T::from(11) * (self.1 * T::from(2) * self.3 - self.2 * self.2);
let b2 = T::from(2) * self.0 * self.2 - self.1 * self.1 - T::from(11) * self.3 * self.3;
let c = b0 * b0 - T::from(11) * b2 * b2;
let ic = T::from(1) / c;

let b0_ic = b0 * ic;
let b2_ic = b2 * ic;

Fp4(
self.0 * b0_ic - T::from(11) * self.2 * b2_ic,
-self.1 * b0_ic + T::from(11) * self.3 * b2_ic,
-self.0 * b2_ic + self.2 * b0_ic,
self.1 * b2_ic - self.3 * b0_ic,
)
}
}

impl<T: FieldElement> Div for Fp4<T> {
type Output = Self;

#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, other: Self) -> Self {
self * other.inverse()
}
}

impl<T: FieldElement> ExtensionField<T> for Fp4<T> {
fn get_challenge(challenges: &BTreeMap<u64, T>, index: u64) -> Self {
Fp4::new(
challenges[&(index * 4 + 1)],
challenges[&(index * 4 + 2)],
challenges[&(index * 4 + 3)],
challenges[&(index * 4 + 4)],
)
}
fn size() -> usize {
4
}
fn inverse(self) -> Self {
self.inverse()
}
fn to_vec(self) -> Vec<T> {
vec![self.0, self.1, self.2, self.3]
}
}

#[cfg(test)]
mod tests {
use powdr_number::GoldilocksField;

use super::*;

fn new(a: i64, b: i64, c: i64, d: i64) -> Fp4<GoldilocksField> {
Fp4(
GoldilocksField::from(a),
GoldilocksField::from(b),
GoldilocksField::from(c),
GoldilocksField::from(d),
)
}

fn from_base(x: i64) -> Fp4<GoldilocksField> {
Fp4::from(GoldilocksField::from(x))
}

#[test]
fn test_add() {
assert_eq!(from_base(0) + from_base(0), from_base(0));
assert_eq!(new(1, 2, 3, 4) + new(5, 6, 7, 8), new(6, 8, 10, 12));
}

#[test]
fn test_sub() {
assert_eq!(new(1, 2, 3, 4) - new(5, 6, 7, 8), new(-4, -4, -4, -4));
}

#[test]
fn test_mul() {
assert_eq!(new(1, 2, 3, 4) * new(5, 6, 7, 8), new(676, 588, 386, 60));
}

#[test]
fn test_inverse() {
let x = new(1, 2, 3, 4);
assert_eq!(x * x.inverse(), Fp4::one());
}
}
Loading

0 comments on commit 4477ea7

Please sign in to comment.