Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/jit_solving' into witgen_inference
Browse files Browse the repository at this point in the history
  • Loading branch information
chriseth committed Dec 10, 2024
2 parents e3e1fdb + ff6a2f7 commit aa91ce3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 29 deletions.
7 changes: 2 additions & 5 deletions executor/src/witgen/jit/affine_symbolic_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,16 +332,13 @@ impl<T: FieldElement, V: Clone + Ord> Add for &AffineSymbolicExpression<T, V> {

fn add(self, rhs: Self) -> Self::Output {
let mut coefficients = self.coefficients.clone();
for (var, coeff) in rhs.coefficients.iter() {
for (var, coeff) in &rhs.coefficients {
coefficients
.entry(var.clone())
.and_modify(|f| *f = &*f + coeff)
.or_insert_with(|| coeff.clone());
}
let coefficients = coefficients
.into_iter()
.filter(|(_, f)| !f.is_known_zero())
.collect();
coefficients.retain(|_, f| !f.is_known_zero());
let offset = &self.offset + &rhs.offset;
AffineSymbolicExpression {
coefficients,
Expand Down
53 changes: 29 additions & 24 deletions executor/src/witgen/jit/symbolic_expression.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
fmt::{self, Display, Formatter},
ops::{Add, BitAnd, BitOr, Mul, Neg},
rc::Rc,
};

use powdr_number::FieldElement;
Expand All @@ -17,12 +18,12 @@ pub enum SymbolicExpression<T: FieldElement, V> {
/// A symbolic value known at run-time, referencing either a cell or a local variable.
Variable(V, Option<RangeConstraint<T>>),
BinaryOperation(
Box<Self>,
Rc<Self>,
BinaryOperator,
Box<Self>,
Rc<Self>,
Option<RangeConstraint<T>>,
),
UnaryOperation(UnaryOperator, Box<Self>, Option<RangeConstraint<T>>),
UnaryOperation(UnaryOperator, Rc<Self>, Option<RangeConstraint<T>>),
}

#[derive(Debug, Clone)]
Expand All @@ -44,6 +45,10 @@ pub enum UnaryOperator {
}

impl<T: FieldElement, V> SymbolicExpression<T, V> {
pub fn from_var(name: V) -> Self {
SymbolicExpression::Variable(name, None)
}

pub fn is_known_zero(&self) -> bool {
self.try_to_number().map_or(false, |n| n.is_zero())
}
Expand Down Expand Up @@ -116,12 +121,6 @@ impl Display for UnaryOperator {
}
}

impl<T: FieldElement, V> SymbolicExpression<T, V> {
pub fn from_var(name: V) -> Self {
SymbolicExpression::Variable(name, None)
}
}

impl<T: FieldElement, V> From<T> for SymbolicExpression<T, V> {
fn from(n: T) -> Self {
SymbolicExpression::Concrete(n)
Expand All @@ -143,9 +142,9 @@ impl<T: FieldElement, V: Clone> Add for &SymbolicExpression<T, V> {
SymbolicExpression::Concrete(*a + *b)
}
_ => SymbolicExpression::BinaryOperation(
Box::new(self.clone()),
Rc::new(self.clone()),
BinaryOperator::Add,
Box::new(rhs.clone()),
Rc::new(rhs.clone()),
self.range_constraint()
.zip(rhs.range_constraint())
.map(|(a, b)| a.combine_sum(&b)),
Expand All @@ -167,10 +166,12 @@ impl<T: FieldElement, V: Clone> Neg for &SymbolicExpression<T, V> {
fn neg(self) -> Self::Output {
match self {
SymbolicExpression::Concrete(n) => SymbolicExpression::Concrete(-*n),
SymbolicExpression::UnaryOperation(UnaryOperator::Neg, expr, _) => *expr.clone(),
SymbolicExpression::UnaryOperation(UnaryOperator::Neg, expr, _) => {
expr.as_ref().clone()
}
_ => SymbolicExpression::UnaryOperation(
UnaryOperator::Neg,
Box::new(self.clone()),
Rc::new(self.clone()),
self.range_constraint().map(|rc| rc.multiple(-T::from(1))),
),
}
Expand Down Expand Up @@ -202,9 +203,9 @@ impl<T: FieldElement, V: Clone> Mul for &SymbolicExpression<T, V> {
-self
} else {
SymbolicExpression::BinaryOperation(
Box::new(self.clone()),
Rc::new(self.clone()),
BinaryOperator::Mul,
Box::new(rhs.clone()),
Rc::new(rhs.clone()),
None,
)
}
Expand Down Expand Up @@ -234,9 +235,9 @@ impl<T: FieldElement, V: Clone> SymbolicExpression<T, V> {
} else {
// TODO other simplifications like `-x / -y => x / y`, `-x / concrete => x / -concrete`, etc.
SymbolicExpression::BinaryOperation(
Box::new(self.clone()),
Rc::new(self.clone()),
BinaryOperator::Div,
Box::new(rhs.clone()),
Rc::new(rhs.clone()),
None,
)
}
Expand All @@ -248,9 +249,9 @@ impl<T: FieldElement, V: Clone> SymbolicExpression<T, V> {
self.clone()
} else {
SymbolicExpression::BinaryOperation(
Box::new(self.clone()),
Rc::new(self.clone()),
BinaryOperator::IntegerDiv,
Box::new(rhs.clone()),
Rc::new(rhs.clone()),
None,
)
}
Expand All @@ -267,9 +268,9 @@ impl<T: FieldElement, V: Clone> BitAnd for &SymbolicExpression<T, V> {
SymbolicExpression::Concrete(T::from(0))
} else {
SymbolicExpression::BinaryOperation(
Box::new(self.clone()),
Rc::new(self.clone()),
BinaryOperator::BitAnd,
Box::new(rhs.clone()),
Rc::new(rhs.clone()),
self.range_constraint()
.zip(rhs.range_constraint())
.map(|(a, b)| a.conjunction(&b)),
Expand All @@ -292,13 +293,17 @@ impl<T: FieldElement, V: Clone> BitOr for &SymbolicExpression<T, V> {
fn bitor(self, rhs: Self) -> Self::Output {
if let (SymbolicExpression::Concrete(a), SymbolicExpression::Concrete(b)) = (self, rhs) {
let v = a.to_integer() | b.to_integer();
assert!(v <= T::modulus());
assert!(v < T::modulus());
SymbolicExpression::Concrete(T::from(v))
} else if self.is_known_zero() {
rhs.clone()
} else if rhs.is_known_zero() {
self.clone()
} else {
SymbolicExpression::BinaryOperation(
Box::new(self.clone()),
Rc::new(self.clone()),
BinaryOperator::BitOr,
Box::new(rhs.clone()),
Rc::new(rhs.clone()),
None,
)
}
Expand Down

0 comments on commit aa91ce3

Please sign in to comment.