Skip to content

Commit

Permalink
Optimize for 1s and 0s in CompactPolynomial::bind
Browse files Browse the repository at this point in the history
  • Loading branch information
moodlezoup committed Jan 16, 2025
1 parent 9a99cf0 commit 883dba6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 10 deletions.
4 changes: 3 additions & 1 deletion jolt-core/src/msm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use rayon::prelude::*;

pub(crate) mod icicle;
use crate::field::JoltField;
use crate::poly::{dense_mlpoly::DensePolynomial, multilinear_polynomial::MultilinearPolynomial};
#[cfg(feature = "icicle")]
use crate::poly::dense_mlpoly::DensePolynomial;
use crate::poly::multilinear_polynomial::MultilinearPolynomial;
use crate::utils::errors::ProofVerifyError;
use crate::utils::math::Math;
pub use icicle::*;
Expand Down
56 changes: 47 additions & 9 deletions jolt-core/src/poly/compact_polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,49 @@ pub trait SmallScalar: Copy + Integer + Sync {
/// the multiplication, so the other operand should probably have an additional R^2
/// factor (see `JoltField::montgomery_r2`).
fn field_mul<F: JoltField>(&self, n: F) -> F;
/// Converts a small scalar into a (potentially Montgomery form) `JoltField` type
fn to_field<F: JoltField>(self) -> F;
}

impl SmallScalar for u8 {
#[inline]
fn field_mul<F: JoltField>(&self, n: F) -> F {
n.mul_u64_unchecked(*self as u64)
}
#[inline]
fn to_field<F: JoltField>(self) -> F {
F::from_u8(self)
}
}
impl SmallScalar for u16 {
#[inline]
fn field_mul<F: JoltField>(&self, n: F) -> F {
n.mul_u64_unchecked(*self as u64)
}
#[inline]
fn to_field<F: JoltField>(self) -> F {
F::from_u16(self)
}
}
impl SmallScalar for u32 {
#[inline]
fn field_mul<F: JoltField>(&self, n: F) -> F {
n.mul_u64_unchecked(*self as u64)
}
#[inline]
fn to_field<F: JoltField>(self) -> F {
F::from_u32(self)
}
}
impl SmallScalar for u64 {
#[inline]
fn field_mul<F: JoltField>(&self, n: F) -> F {
n.mul_u64_unchecked(*self)
}
#[inline]
fn to_field<F: JoltField>(self) -> F {
F::from_u64(self)
}
}
impl SmallScalar for i64 {
#[inline]
Expand All @@ -50,6 +68,10 @@ impl SmallScalar for i64 {
n.mul_u64_unchecked(*self as u64)
}
}
#[inline]
fn to_field<F: JoltField>(self) -> F {
F::from_i64(self)
}
}

#[derive(Default, Debug, PartialEq)]
Expand Down Expand Up @@ -128,23 +150,29 @@ impl<T: SmallScalar, F: JoltField> PolynomialBinding<F> for CompactPolynomial<T,
!self.bound_coeffs.is_empty()
}

// TODO(moodlezoup): Optimize for 0s and 1s
#[tracing::instrument(skip_all, name = "CompactPolynomial::bind")]
fn bind(&mut self, r: F, order: BindingOrder) {
let n = self.len() / 2;
if self.is_bound() {
match order {
BindingOrder::LowToHigh => {
for i in 0..n {
self.bound_coeffs[i] = self.bound_coeffs[2 * i]
+ r * (self.bound_coeffs[2 * i + 1] - self.bound_coeffs[2 * i]);
if self.bound_coeffs[2 * i + 1] == self.bound_coeffs[2 * i] {
self.bound_coeffs[i] = self.bound_coeffs[2 * i];
} else {
self.bound_coeffs[i] = self.bound_coeffs[2 * i]
+ r * (self.bound_coeffs[2 * i + 1] - self.bound_coeffs[2 * i]);
}
}
}
BindingOrder::HighToLow => {
let (left, right) = self.bound_coeffs.split_at_mut(n);
left.iter_mut().zip(right.iter()).for_each(|(a, b)| {
*a += r * (*b - *a);
});
left.iter_mut()
.zip(right.iter())
.filter(|(a, b)| a != b)
.for_each(|(a, b)| {
*a += r * (*b - *a);
});
}
}
} else {
Expand All @@ -154,8 +182,12 @@ impl<T: SmallScalar, F: JoltField> PolynomialBinding<F> for CompactPolynomial<T,
BindingOrder::LowToHigh => {
self.bound_coeffs = (0..n)
.map(|i| {
self.coeffs[2 * i].field_mul(one_minus_r_r2)
+ self.coeffs[2 * i + 1].field_mul(r_r2)
if self.coeffs[2 * i] == self.coeffs[2 * i + 1] {
self.coeffs[2 * i].to_field()
} else {
self.coeffs[2 * i].field_mul(one_minus_r_r2)
+ self.coeffs[2 * i + 1].field_mul(r_r2)
}
})
.collect();
}
Expand All @@ -164,7 +196,13 @@ impl<T: SmallScalar, F: JoltField> PolynomialBinding<F> for CompactPolynomial<T,
self.bound_coeffs = left
.iter()
.zip(right.iter())
.map(|(&a, &b)| a.field_mul(one_minus_r_r2) + b.field_mul(r_r2))
.map(|(&a, &b)| {
if a == b {
a.to_field()
} else {
a.field_mul(one_minus_r_r2) + b.field_mul(r_r2)
}
})
.collect();
}
}
Expand Down

0 comments on commit 883dba6

Please sign in to comment.