Skip to content

Commit

Permalink
Witgen: Handle intermediate polynomials (#2007)
Browse files Browse the repository at this point in the history
Completes a task in #2009.

With this PR, we no longer rely on
`Analyzed::identities_with_inlined_intermediate_polynomials()`, which
can produce exponentially large expressions in some cases. Instead,
intermediate polynomials are evaluated on demand and cached. Thibaut's
example from #1995 speeds up massively with this PR.
  • Loading branch information
georgwiese authored Nov 19, 2024
1 parent a017bf7 commit efbcd1f
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 67 deletions.
53 changes: 40 additions & 13 deletions executor/src/witgen/expression_evaluator.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::marker::PhantomData;
use std::collections::BTreeMap;

use powdr_ast::analyzed::{
AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression as Expression,
AlgebraicUnaryOperation, AlgebraicUnaryOperator, Challenge,
AlgebraicUnaryOperation, AlgebraicUnaryOperator, Challenge, PolyID, PolynomialType,
};

use powdr_number::FieldElement;
Expand All @@ -23,30 +23,57 @@ pub trait SymbolicVariables<T> {
}
}

pub struct ExpressionEvaluator<T, SV> {
pub struct ExpressionEvaluator<'a, T, SV> {
variables: SV,
marker: PhantomData<T>,
intermediate_definitions: &'a BTreeMap<PolyID, &'a Expression<T>>,
/// Maps intermediate polynomial IDs to their evaluation. Updated throughout the lifetime of the
/// ExpressionEvaluator.
intermediates_cache: BTreeMap<PolyID, AffineResult<AlgebraicVariable<'a>, T>>,
}

impl<T, SV> ExpressionEvaluator<T, SV>
impl<'a, T, SV> ExpressionEvaluator<'a, T, SV>
where
SV: SymbolicVariables<T>,
T: FieldElement,
{
pub fn new(variables: SV) -> Self {
pub fn new(
variables: SV,
intermediate_definitions: &'a BTreeMap<PolyID, &'a Expression<T>>,
) -> Self {
Self {
variables,
marker: PhantomData,
intermediate_definitions,
intermediates_cache: Default::default(),
}
}

/// Tries to evaluate the expression to an affine expression in the witness polynomials
/// or publics, taking their current values into account.
/// Might update its cache of evaluations of intermediate polynomials.
/// @returns an expression affine in the witness polynomials or publics.
pub fn evaluate<'a>(&self, expr: &'a Expression<T>) -> AffineResult<AlgebraicVariable<'a>, T> {
pub fn evaluate(&mut self, expr: &'a Expression<T>) -> AffineResult<AlgebraicVariable<'a>, T> {
// @TODO if we iterate on processing the constraints in the same row,
// we could store the simplified values.
match expr {
Expression::Reference(poly) => self.variables.value(AlgebraicVariable::Column(poly)),
Expression::Reference(poly) => match poly.poly_id.ptype {
PolynomialType::Committed | PolynomialType::Constant => {
self.variables.value(AlgebraicVariable::Column(poly))
}
PolynomialType::Intermediate => {
let value = self.intermediates_cache.get(&poly.poly_id).cloned();
match value {
Some(v) => v,
None => {
let definition =
self.intermediate_definitions.get(&poly.poly_id).unwrap();
let result = self.evaluate(definition);
self.intermediates_cache
.insert(poly.poly_id, result.clone());
result
}
}
}
},
Expression::PublicReference(public) => {
self.variables.value(AlgebraicVariable::Public(public))
}
Expand All @@ -61,8 +88,8 @@ where
}
}

fn evaluate_binary_operation<'a>(
&self,
fn evaluate_binary_operation(
&mut self,
left: &'a Expression<T>,
op: &AlgebraicBinaryOperator,
right: &'a Expression<T>,
Expand Down Expand Up @@ -128,8 +155,8 @@ where
}
}

fn evaluate_unary_operation<'a>(
&self,
fn evaluate_unary_operation(
&mut self,
op: &AlgebraicUnaryOperator,
expr: &'a Expression<T>,
) -> AffineResult<AlgebraicVariable<'a>, T> {
Expand Down
37 changes: 25 additions & 12 deletions executor/src/witgen/global_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use super::range_constraints::RangeConstraint;
use super::symbolic_evaluator::SymbolicEvaluator;
use super::util::try_to_simple_poly;
use super::{Constraint, FixedData};
use powdr_ast::analyzed::AlgebraicExpression;

/// Trait that provides a range constraint on a symbolic variable if given by ID.
pub trait RangeConstraintSet<K, T: FieldElement> {
Expand Down Expand Up @@ -154,6 +155,7 @@ pub fn set_global_constraints<'a, T: FieldElement>(
let mut range_constraint_multiplicities = BTreeMap::new();
for identity in identities.into_iter() {
let remove = propagate_constraints(
&fixed_data.intermediate_definitions,
&mut known_constraints,
&mut range_constraint_multiplicities,
identity,
Expand Down Expand Up @@ -244,20 +246,25 @@ fn process_fixed_column<T: FieldElement>(fixed: &[T]) -> Option<(RangeConstraint
/// If the returned flag is true, the identity can be removed, because it contains
/// no further information than the range constraint.
fn propagate_constraints<T: FieldElement>(
intermediate_definitions: &BTreeMap<PolyID, &AlgebraicExpression<T>>,
known_constraints: &mut BTreeMap<PolyID, RangeConstraint<T>>,
range_constraint_multiplicities: &mut BTreeMap<PolyID, PhantomRangeConstraintTarget>,
identity: &Identity<T>,
full_span: &BTreeSet<PolyID>,
) -> bool {
match identity {
Identity::Polynomial(identity) => {
if let Some(p) = is_binary_constraint(&identity.expression) {
if let Some(p) = is_binary_constraint(intermediate_definitions, &identity.expression) {
assert!(known_constraints
.insert(p, RangeConstraint::from_max_bit(0))
.is_none());
true
} else {
for (p, c) in try_transfer_constraints(&identity.expression, known_constraints) {
for (p, c) in try_transfer_constraints(
intermediate_definitions,
&identity.expression,
known_constraints,
) {
known_constraints
.entry(p)
.and_modify(|existing| *existing = existing.conjunction(&c))
Expand Down Expand Up @@ -325,7 +332,10 @@ fn propagate_constraints<T: FieldElement>(
}

/// Tries to find "X * (1 - X) = 0"
fn is_binary_constraint<T: FieldElement>(expr: &Expression<T>) -> Option<PolyID> {
fn is_binary_constraint<T: FieldElement>(
intermediate_definitions: &BTreeMap<PolyID, &AlgebraicExpression<T>>,
expr: &Expression<T>,
) -> Option<PolyID> {
// TODO Write a proper pattern matching engine.
if let Expression::BinaryOperation(AlgebraicBinaryOperation {
left,
Expand All @@ -335,7 +345,7 @@ fn is_binary_constraint<T: FieldElement>(expr: &Expression<T>) -> Option<PolyID>
{
if let Expression::Number(n) = right.as_ref() {
if n.is_zero() {
return is_binary_constraint(left.as_ref());
return is_binary_constraint(intermediate_definitions, left.as_ref());
}
}
} else if let Expression::BinaryOperation(AlgebraicBinaryOperation {
Expand All @@ -344,12 +354,9 @@ fn is_binary_constraint<T: FieldElement>(expr: &Expression<T>) -> Option<PolyID>
right,
}) = expr
{
let symbolic_ev = SymbolicEvaluator;
let left_root = ExpressionEvaluator::new(symbolic_ev.clone())
.evaluate(left)
.ok()
.and_then(|l| l.solve().ok())?;
let right_root = ExpressionEvaluator::new(symbolic_ev)
let mut evaluator = ExpressionEvaluator::new(SymbolicEvaluator, intermediate_definitions);
let left_root = evaluator.evaluate(left).ok().and_then(|l| l.solve().ok())?;
let right_root = evaluator
.evaluate(right)
.ok()
.and_then(|r| r.solve().ok())?;
Expand All @@ -373,15 +380,18 @@ fn is_binary_constraint<T: FieldElement>(expr: &Expression<T>) -> Option<PolyID>

/// Tries to transfer constraints in a linear expression.
fn try_transfer_constraints<T: FieldElement>(
intermediate_definitions: &BTreeMap<PolyID, &AlgebraicExpression<T>>,
expr: &Expression<T>,
known_constraints: &BTreeMap<PolyID, RangeConstraint<T>>,
) -> Vec<(PolyID, RangeConstraint<T>)> {
if expr.contains_next_ref() {
return vec![];
}

let symbolic_ev = SymbolicEvaluator;
let Some(aff_expr) = ExpressionEvaluator::new(symbolic_ev).evaluate(expr).ok() else {
let Some(aff_expr) = ExpressionEvaluator::new(SymbolicEvaluator, intermediate_definitions)
.evaluate(expr)
.ok()
else {
return vec![];
};

Expand Down Expand Up @@ -539,6 +549,7 @@ namespace Global(2**20);
let mut range_constraint_multiplicities = BTreeMap::new();
for identity in &analyzed.identities {
propagate_constraints(
&BTreeMap::new(),
&mut known_constraints,
&mut range_constraint_multiplicities,
identity,
Expand Down Expand Up @@ -630,6 +641,7 @@ namespace Global(2**20);
let mut range_constraint_multiplicities = BTreeMap::new();
for identity in &analyzed.identities {
propagate_constraints(
&BTreeMap::new(),
&mut known_constraints,
&mut range_constraint_multiplicities,
identity,
Expand Down Expand Up @@ -706,6 +718,7 @@ namespace Global(1024);
let mut range_constraint_multiplicities = BTreeMap::new();
assert_eq!(analyzed.identities.len(), 1);
let removed = propagate_constraints(
&BTreeMap::new(),
&mut known_constraints,
&mut range_constraint_multiplicities,
analyzed.identities.first().unwrap(),
Expand Down
4 changes: 2 additions & 2 deletions executor/src/witgen/identity_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'b,
fn process_polynomial_identity(
&self,
identity: &'a PolynomialIdentity<T>,
rows: &RowPair<T>,
rows: &RowPair<'_, 'a, T>,
) -> EvalResult<'a, T> {
match rows.evaluate(&identity.expression) {
Err(incomplete_cause) => Ok(EvalValue::incomplete(incomplete_cause)),
Expand Down Expand Up @@ -242,7 +242,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'b,
fn handle_left_selector(
&self,
left_selector: &'a Expression<T>,
rows: &RowPair<T>,
rows: &RowPair<'_, 'a, T>,
) -> Option<EvalValue<AlgebraicVariable<'a>, T>> {
let value = match rows.evaluate(left_selector) {
Err(incomplete_cause) => return Some(EvalValue::incomplete(incomplete_cause)),
Expand Down
20 changes: 10 additions & 10 deletions executor/src/witgen/machines/fixed_lookup_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ impl<'a, T: FieldElement> FixedLookup<'a, T> {
&mut self,
mutable_state: &'b mut MutableState<'a, 'b, T, Q>,
identity_id: u64,
rows: &RowPair<'_, '_, T>,
rows: &RowPair<'_, 'a, T>,
left: &[AffineExpression<AlgebraicVariable<'a>, T>],
mut right: Peekable<impl Iterator<Item = &'a AlgebraicReference>>,
) -> EvalResult<'a, T> {
Expand Down Expand Up @@ -261,12 +261,12 @@ impl<'a, T: FieldElement> FixedLookup<'a, T> {
Ok(result)
}

fn process_range_check<'b>(
fn process_range_check(
&self,
rows: &RowPair<'_, '_, T>,
lhs: &AffineExpression<AlgebraicVariable<'b>, T>,
rhs: AlgebraicVariable<'b>,
) -> EvalResult<'b, T> {
rows: &RowPair<'_, 'a, T>,
lhs: &AffineExpression<AlgebraicVariable<'a>, T>,
rhs: AlgebraicVariable<'a>,
) -> EvalResult<'a, T> {
// Use AffineExpression::solve_with_range_constraints to transfer range constraints
// from the rhs to the lhs.
let equation = lhs.clone() - AffineExpression::from_variable_id(rhs);
Expand Down Expand Up @@ -438,13 +438,13 @@ impl<'a, T: FieldElement> Machine<'a, T> for FixedLookup<'a, T> {
/// (used for fixed columns).
/// This is useful in order to transfer range constraints from fixed columns to
/// witness columns (see [FixedLookup::process_range_check]).
pub struct UnifiedRangeConstraints<'a, T: FieldElement> {
witness_constraints: &'a RowPair<'a, 'a, T>,
global_constraints: &'a GlobalConstraints<T>,
pub struct UnifiedRangeConstraints<'a, 'b, T: FieldElement> {
witness_constraints: &'b RowPair<'b, 'a, T>,
global_constraints: &'b GlobalConstraints<T>,
}

impl<'a, T: FieldElement> RangeConstraintSet<AlgebraicVariable<'a>, T>
for UnifiedRangeConstraints<'_, T>
for UnifiedRangeConstraints<'_, '_, T>
{
fn range_constraint(&self, var: AlgebraicVariable<'a>) -> Option<RangeConstraint<T>> {
let poly = match var {
Expand Down
18 changes: 13 additions & 5 deletions executor/src/witgen/machines/machine_extractor.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::collections::{BTreeMap, BTreeSet, HashSet};

use itertools::Itertools;
use powdr_ast::analyzed::LookupIdentity;
use powdr_ast::analyzed::PermutationIdentity;
use powdr_ast::analyzed::PhantomLookupIdentity;
use powdr_ast::analyzed::PhantomPermutationIdentity;
use powdr_ast::analyzed::{LookupIdentity, PolynomialType};

use super::block_machine::BlockMachine;
use super::double_sorted_witness_machine_16::DoubleSortedWitnesses16;
Expand Down Expand Up @@ -305,10 +305,18 @@ impl<'a, T: FieldElement> MachineExtractor<'a, T> {
}
}

fn refs_in_expression(&self, expr: &'a Expression<T>) -> impl Iterator<Item = PolyID> + '_ {
Box::new(expr.all_children().filter_map(move |e| match e {
Expression::Reference(p) => Some(p.poly_id),
_ => None,
fn refs_in_expression(&self, expr: &'a Expression<T>) -> Box<dyn Iterator<Item = PolyID> + '_> {
Box::new(expr.all_children().flat_map(move |e| match e {
Expression::Reference(p) => match p.poly_id.ptype {
PolynomialType::Committed | PolynomialType::Constant => {
Box::new(std::iter::once(p.poly_id))
}
// For intermediate polynomials, recursively extract the references in the expression.
PolynomialType::Intermediate => self.refs_in_expression(
self.fixed.intermediate_definitions.get(&p.poly_id).unwrap(),
),
},
_ => Box::new(std::iter::empty()),
}))
}
}
Expand Down
23 changes: 15 additions & 8 deletions executor/src/witgen/machines/sorted_witness_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,17 @@ fn check_identity<T: FieldElement>(
}

// Check for A' - A in the LHS
let key_column = check_constraint(left.expressions.first().unwrap())?;
let key_column = check_constraint(fixed_data, left.expressions.first().unwrap())?;

let not_last = &left.selector;
let positive = right.expressions.first().unwrap();

// TODO this could be rather slow. We should check the code for identity instead
// of evaluating it.
for row in 0..(degree as usize) {
let ev = ExpressionEvaluator::new(FixedEvaluator::new(fixed_data, row, degree));
let fixed_evaluator = FixedEvaluator::new(fixed_data, row, degree);
let mut ev =
ExpressionEvaluator::new(fixed_evaluator, &fixed_data.intermediate_definitions);
let degree = degree as usize;
let nl = ev.evaluate(not_last).ok()?.constant_value()?;
if (row == degree - 1 && !nl.is_zero()) || (row < degree - 1 && !nl.is_one()) {
Expand All @@ -148,12 +150,17 @@ fn check_identity<T: FieldElement>(

/// Checks that the identity has a constraint of the form `a' - a` as the first expression
/// on the left hand side and returns the ID of the witness column.
fn check_constraint<T: FieldElement>(constraint: &Expression<T>) -> Option<PolyID> {
let symbolic_ev = SymbolicEvaluator;
let sort_constraint = match ExpressionEvaluator::new(symbolic_ev).evaluate(constraint) {
Ok(c) => c,
Err(_) => return None,
};
fn check_constraint<T: FieldElement>(
fixed: &FixedData<T>,
constraint: &Expression<T>,
) -> Option<PolyID> {
let sort_constraint =
match ExpressionEvaluator::new(SymbolicEvaluator, &fixed.intermediate_definitions)
.evaluate(constraint)
{
Ok(c) => c,
Err(_) => return None,
};
let mut coeff = sort_constraint.nonzero_coefficients();
let first = coeff
.next()
Expand Down
Loading

0 comments on commit efbcd1f

Please sign in to comment.