From efbcd1ffe216b402891bc1aa780f039295600d1c Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Tue, 19 Nov 2024 11:19:31 +0100 Subject: [PATCH] Witgen: Handle intermediate polynomials (#2007) Completes a task in #2009. With this PR, we no longer rely on `Analyzed::identities_with_inlined_intermediate_polynomials()`, which can produce exponentially large expressions in some cases. Instead, intermediate polynomials are evaluated on demand and cached. Thibaut's example from #1995 speeds up massively with this PR. --- executor/src/witgen/expression_evaluator.rs | 53 ++++++++++++++----- executor/src/witgen/global_constraints.rs | 37 ++++++++----- executor/src/witgen/identity_processor.rs | 4 +- .../witgen/machines/fixed_lookup_machine.rs | 20 +++---- .../src/witgen/machines/machine_extractor.rs | 18 +++++-- .../witgen/machines/sorted_witness_machine.rs | 23 +++++--- executor/src/witgen/mod.rs | 15 +++++- executor/src/witgen/query_processor.rs | 6 +-- executor/src/witgen/rows.rs | 12 +++-- .../src/witgen/symbolic_witness_evaluator.rs | 22 ++++---- 10 files changed, 143 insertions(+), 67 deletions(-) diff --git a/executor/src/witgen/expression_evaluator.rs b/executor/src/witgen/expression_evaluator.rs index 70aba8cc0b..acee2d3052 100644 --- a/executor/src/witgen/expression_evaluator.rs +++ b/executor/src/witgen/expression_evaluator.rs @@ -1,8 +1,8 @@ -use std::marker::PhantomData; +use std::collections::BTreeMap; use powdr_ast::analyzed::{ AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression as Expression, - AlgebraicUnaryOperation, AlgebraicUnaryOperator, Challenge, + AlgebraicUnaryOperation, AlgebraicUnaryOperator, Challenge, PolyID, PolynomialType, }; use powdr_number::FieldElement; @@ -23,30 +23,57 @@ pub trait SymbolicVariables { } } -pub struct ExpressionEvaluator { +pub struct ExpressionEvaluator<'a, T, SV> { variables: SV, - marker: PhantomData, + intermediate_definitions: &'a BTreeMap>, + /// Maps intermediate polynomial IDs to their evaluation. Updated throughout the lifetime of the + /// ExpressionEvaluator. + intermediates_cache: BTreeMap, T>>, } -impl ExpressionEvaluator +impl<'a, T, SV> ExpressionEvaluator<'a, T, SV> where SV: SymbolicVariables, T: FieldElement, { - pub fn new(variables: SV) -> Self { + pub fn new( + variables: SV, + intermediate_definitions: &'a BTreeMap>, + ) -> Self { Self { variables, - marker: PhantomData, + intermediate_definitions, + intermediates_cache: Default::default(), } } + /// Tries to evaluate the expression to an affine expression in the witness polynomials /// or publics, taking their current values into account. + /// Might update its cache of evaluations of intermediate polynomials. /// @returns an expression affine in the witness polynomials or publics. - pub fn evaluate<'a>(&self, expr: &'a Expression) -> AffineResult, T> { + pub fn evaluate(&mut self, expr: &'a Expression) -> AffineResult, T> { // @TODO if we iterate on processing the constraints in the same row, // we could store the simplified values. match expr { - Expression::Reference(poly) => self.variables.value(AlgebraicVariable::Column(poly)), + Expression::Reference(poly) => match poly.poly_id.ptype { + PolynomialType::Committed | PolynomialType::Constant => { + self.variables.value(AlgebraicVariable::Column(poly)) + } + PolynomialType::Intermediate => { + let value = self.intermediates_cache.get(&poly.poly_id).cloned(); + match value { + Some(v) => v, + None => { + let definition = + self.intermediate_definitions.get(&poly.poly_id).unwrap(); + let result = self.evaluate(definition); + self.intermediates_cache + .insert(poly.poly_id, result.clone()); + result + } + } + } + }, Expression::PublicReference(public) => { self.variables.value(AlgebraicVariable::Public(public)) } @@ -61,8 +88,8 @@ where } } - fn evaluate_binary_operation<'a>( - &self, + fn evaluate_binary_operation( + &mut self, left: &'a Expression, op: &AlgebraicBinaryOperator, right: &'a Expression, @@ -128,8 +155,8 @@ where } } - fn evaluate_unary_operation<'a>( - &self, + fn evaluate_unary_operation( + &mut self, op: &AlgebraicUnaryOperator, expr: &'a Expression, ) -> AffineResult, T> { diff --git a/executor/src/witgen/global_constraints.rs b/executor/src/witgen/global_constraints.rs index 78d95bd543..d9d115bf27 100644 --- a/executor/src/witgen/global_constraints.rs +++ b/executor/src/witgen/global_constraints.rs @@ -21,6 +21,7 @@ use super::range_constraints::RangeConstraint; use super::symbolic_evaluator::SymbolicEvaluator; use super::util::try_to_simple_poly; use super::{Constraint, FixedData}; +use powdr_ast::analyzed::AlgebraicExpression; /// Trait that provides a range constraint on a symbolic variable if given by ID. pub trait RangeConstraintSet { @@ -154,6 +155,7 @@ pub fn set_global_constraints<'a, T: FieldElement>( let mut range_constraint_multiplicities = BTreeMap::new(); for identity in identities.into_iter() { let remove = propagate_constraints( + &fixed_data.intermediate_definitions, &mut known_constraints, &mut range_constraint_multiplicities, identity, @@ -244,6 +246,7 @@ fn process_fixed_column(fixed: &[T]) -> Option<(RangeConstraint /// If the returned flag is true, the identity can be removed, because it contains /// no further information than the range constraint. fn propagate_constraints( + intermediate_definitions: &BTreeMap>, known_constraints: &mut BTreeMap>, range_constraint_multiplicities: &mut BTreeMap, identity: &Identity, @@ -251,13 +254,17 @@ fn propagate_constraints( ) -> bool { match identity { Identity::Polynomial(identity) => { - if let Some(p) = is_binary_constraint(&identity.expression) { + if let Some(p) = is_binary_constraint(intermediate_definitions, &identity.expression) { assert!(known_constraints .insert(p, RangeConstraint::from_max_bit(0)) .is_none()); true } else { - for (p, c) in try_transfer_constraints(&identity.expression, known_constraints) { + for (p, c) in try_transfer_constraints( + intermediate_definitions, + &identity.expression, + known_constraints, + ) { known_constraints .entry(p) .and_modify(|existing| *existing = existing.conjunction(&c)) @@ -325,7 +332,10 @@ fn propagate_constraints( } /// Tries to find "X * (1 - X) = 0" -fn is_binary_constraint(expr: &Expression) -> Option { +fn is_binary_constraint( + intermediate_definitions: &BTreeMap>, + expr: &Expression, +) -> Option { // TODO Write a proper pattern matching engine. if let Expression::BinaryOperation(AlgebraicBinaryOperation { left, @@ -335,7 +345,7 @@ fn is_binary_constraint(expr: &Expression) -> Option { if let Expression::Number(n) = right.as_ref() { if n.is_zero() { - return is_binary_constraint(left.as_ref()); + return is_binary_constraint(intermediate_definitions, left.as_ref()); } } } else if let Expression::BinaryOperation(AlgebraicBinaryOperation { @@ -344,12 +354,9 @@ fn is_binary_constraint(expr: &Expression) -> Option right, }) = expr { - let symbolic_ev = SymbolicEvaluator; - let left_root = ExpressionEvaluator::new(symbolic_ev.clone()) - .evaluate(left) - .ok() - .and_then(|l| l.solve().ok())?; - let right_root = ExpressionEvaluator::new(symbolic_ev) + let mut evaluator = ExpressionEvaluator::new(SymbolicEvaluator, intermediate_definitions); + let left_root = evaluator.evaluate(left).ok().and_then(|l| l.solve().ok())?; + let right_root = evaluator .evaluate(right) .ok() .and_then(|r| r.solve().ok())?; @@ -373,6 +380,7 @@ fn is_binary_constraint(expr: &Expression) -> Option /// Tries to transfer constraints in a linear expression. fn try_transfer_constraints( + intermediate_definitions: &BTreeMap>, expr: &Expression, known_constraints: &BTreeMap>, ) -> Vec<(PolyID, RangeConstraint)> { @@ -380,8 +388,10 @@ fn try_transfer_constraints( return vec![]; } - let symbolic_ev = SymbolicEvaluator; - let Some(aff_expr) = ExpressionEvaluator::new(symbolic_ev).evaluate(expr).ok() else { + let Some(aff_expr) = ExpressionEvaluator::new(SymbolicEvaluator, intermediate_definitions) + .evaluate(expr) + .ok() + else { return vec![]; }; @@ -539,6 +549,7 @@ namespace Global(2**20); let mut range_constraint_multiplicities = BTreeMap::new(); for identity in &analyzed.identities { propagate_constraints( + &BTreeMap::new(), &mut known_constraints, &mut range_constraint_multiplicities, identity, @@ -630,6 +641,7 @@ namespace Global(2**20); let mut range_constraint_multiplicities = BTreeMap::new(); for identity in &analyzed.identities { propagate_constraints( + &BTreeMap::new(), &mut known_constraints, &mut range_constraint_multiplicities, identity, @@ -706,6 +718,7 @@ namespace Global(1024); let mut range_constraint_multiplicities = BTreeMap::new(); assert_eq!(analyzed.identities.len(), 1); let removed = propagate_constraints( + &BTreeMap::new(), &mut known_constraints, &mut range_constraint_multiplicities, analyzed.identities.first().unwrap(), diff --git a/executor/src/witgen/identity_processor.rs b/executor/src/witgen/identity_processor.rs index e5a76ac831..4f170f40b2 100644 --- a/executor/src/witgen/identity_processor.rs +++ b/executor/src/witgen/identity_processor.rs @@ -173,7 +173,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> IdentityProcessor<'a, 'b, fn process_polynomial_identity( &self, identity: &'a PolynomialIdentity, - rows: &RowPair, + rows: &RowPair<'_, 'a, T>, ) -> EvalResult<'a, T> { match rows.evaluate(&identity.expression) { Err(incomplete_cause) => Ok(EvalValue::incomplete(incomplete_cause)), @@ -242,7 +242,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> IdentityProcessor<'a, 'b, fn handle_left_selector( &self, left_selector: &'a Expression, - rows: &RowPair, + rows: &RowPair<'_, 'a, T>, ) -> Option, T>> { let value = match rows.evaluate(left_selector) { Err(incomplete_cause) => return Some(EvalValue::incomplete(incomplete_cause)), diff --git a/executor/src/witgen/machines/fixed_lookup_machine.rs b/executor/src/witgen/machines/fixed_lookup_machine.rs index 95520c2602..54b9c440ed 100644 --- a/executor/src/witgen/machines/fixed_lookup_machine.rs +++ b/executor/src/witgen/machines/fixed_lookup_machine.rs @@ -201,7 +201,7 @@ impl<'a, T: FieldElement> FixedLookup<'a, T> { &mut self, mutable_state: &'b mut MutableState<'a, 'b, T, Q>, identity_id: u64, - rows: &RowPair<'_, '_, T>, + rows: &RowPair<'_, 'a, T>, left: &[AffineExpression, T>], mut right: Peekable>, ) -> EvalResult<'a, T> { @@ -261,12 +261,12 @@ impl<'a, T: FieldElement> FixedLookup<'a, T> { Ok(result) } - fn process_range_check<'b>( + fn process_range_check( &self, - rows: &RowPair<'_, '_, T>, - lhs: &AffineExpression, T>, - rhs: AlgebraicVariable<'b>, - ) -> EvalResult<'b, T> { + rows: &RowPair<'_, 'a, T>, + lhs: &AffineExpression, T>, + rhs: AlgebraicVariable<'a>, + ) -> EvalResult<'a, T> { // Use AffineExpression::solve_with_range_constraints to transfer range constraints // from the rhs to the lhs. let equation = lhs.clone() - AffineExpression::from_variable_id(rhs); @@ -438,13 +438,13 @@ impl<'a, T: FieldElement> Machine<'a, T> for FixedLookup<'a, T> { /// (used for fixed columns). /// This is useful in order to transfer range constraints from fixed columns to /// witness columns (see [FixedLookup::process_range_check]). -pub struct UnifiedRangeConstraints<'a, T: FieldElement> { - witness_constraints: &'a RowPair<'a, 'a, T>, - global_constraints: &'a GlobalConstraints, +pub struct UnifiedRangeConstraints<'a, 'b, T: FieldElement> { + witness_constraints: &'b RowPair<'b, 'a, T>, + global_constraints: &'b GlobalConstraints, } impl<'a, T: FieldElement> RangeConstraintSet, T> - for UnifiedRangeConstraints<'_, T> + for UnifiedRangeConstraints<'_, '_, T> { fn range_constraint(&self, var: AlgebraicVariable<'a>) -> Option> { let poly = match var { diff --git a/executor/src/witgen/machines/machine_extractor.rs b/executor/src/witgen/machines/machine_extractor.rs index 165fbc1f02..7a69e10ef5 100644 --- a/executor/src/witgen/machines/machine_extractor.rs +++ b/executor/src/witgen/machines/machine_extractor.rs @@ -1,10 +1,10 @@ use std::collections::{BTreeMap, BTreeSet, HashSet}; use itertools::Itertools; -use powdr_ast::analyzed::LookupIdentity; use powdr_ast::analyzed::PermutationIdentity; use powdr_ast::analyzed::PhantomLookupIdentity; use powdr_ast::analyzed::PhantomPermutationIdentity; +use powdr_ast::analyzed::{LookupIdentity, PolynomialType}; use super::block_machine::BlockMachine; use super::double_sorted_witness_machine_16::DoubleSortedWitnesses16; @@ -305,10 +305,18 @@ impl<'a, T: FieldElement> MachineExtractor<'a, T> { } } - fn refs_in_expression(&self, expr: &'a Expression) -> impl Iterator + '_ { - Box::new(expr.all_children().filter_map(move |e| match e { - Expression::Reference(p) => Some(p.poly_id), - _ => None, + fn refs_in_expression(&self, expr: &'a Expression) -> Box + '_> { + Box::new(expr.all_children().flat_map(move |e| match e { + Expression::Reference(p) => match p.poly_id.ptype { + PolynomialType::Committed | PolynomialType::Constant => { + Box::new(std::iter::once(p.poly_id)) + } + // For intermediate polynomials, recursively extract the references in the expression. + PolynomialType::Intermediate => self.refs_in_expression( + self.fixed.intermediate_definitions.get(&p.poly_id).unwrap(), + ), + }, + _ => Box::new(std::iter::empty()), })) } } diff --git a/executor/src/witgen/machines/sorted_witness_machine.rs b/executor/src/witgen/machines/sorted_witness_machine.rs index ec1eca5021..cd46002089 100644 --- a/executor/src/witgen/machines/sorted_witness_machine.rs +++ b/executor/src/witgen/machines/sorted_witness_machine.rs @@ -124,7 +124,7 @@ fn check_identity( } // Check for A' - A in the LHS - let key_column = check_constraint(left.expressions.first().unwrap())?; + let key_column = check_constraint(fixed_data, left.expressions.first().unwrap())?; let not_last = &left.selector; let positive = right.expressions.first().unwrap(); @@ -132,7 +132,9 @@ fn check_identity( // TODO this could be rather slow. We should check the code for identity instead // of evaluating it. for row in 0..(degree as usize) { - let ev = ExpressionEvaluator::new(FixedEvaluator::new(fixed_data, row, degree)); + let fixed_evaluator = FixedEvaluator::new(fixed_data, row, degree); + let mut ev = + ExpressionEvaluator::new(fixed_evaluator, &fixed_data.intermediate_definitions); let degree = degree as usize; let nl = ev.evaluate(not_last).ok()?.constant_value()?; if (row == degree - 1 && !nl.is_zero()) || (row < degree - 1 && !nl.is_one()) { @@ -148,12 +150,17 @@ fn check_identity( /// Checks that the identity has a constraint of the form `a' - a` as the first expression /// on the left hand side and returns the ID of the witness column. -fn check_constraint(constraint: &Expression) -> Option { - let symbolic_ev = SymbolicEvaluator; - let sort_constraint = match ExpressionEvaluator::new(symbolic_ev).evaluate(constraint) { - Ok(c) => c, - Err(_) => return None, - }; +fn check_constraint( + fixed: &FixedData, + constraint: &Expression, +) -> Option { + let sort_constraint = + match ExpressionEvaluator::new(SymbolicEvaluator, &fixed.intermediate_definitions) + .evaluate(constraint) + { + Ok(c) => c, + Err(_) => return None, + }; let mut coeff = sort_constraint.nonzero_coefficients(); let first = coeff .next() diff --git a/executor/src/witgen/mod.rs b/executor/src/witgen/mod.rs index 3b2448e05b..314656981a 100644 --- a/executor/src/witgen/mod.rs +++ b/executor/src/witgen/mod.rs @@ -194,7 +194,8 @@ impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> { ); let identities = self .analyzed - .identities_with_inlined_intermediate_polynomials() + .identities + .clone() .into_iter() .filter(|identity| { let discard = identity.expr_any(|expr| { @@ -370,6 +371,7 @@ pub struct FixedData<'a, T: FieldElement> { column_by_name: HashMap, challenges: BTreeMap, global_range_constraints: GlobalConstraints, + intermediate_definitions: BTreeMap>, } impl<'a, T: FieldElement> FixedData<'a, T> { @@ -385,6 +387,16 @@ impl<'a, T: FieldElement> FixedData<'a, T> { .map(|(name, values)| (name.clone(), values)) .collect::>(); + let intermediate_definitions = analyzed + .intermediate_polys_in_source_order() + .flat_map(|(symbol, definitions)| { + symbol + .array_elements() + .zip_eq(definitions) + .map(|((_, poly_id), def)| (poly_id, def)) + }) + .collect(); + let witness_cols = WitnessColumnMap::from(analyzed.committed_polys_in_source_order().flat_map( |(poly, value)| { @@ -438,6 +450,7 @@ impl<'a, T: FieldElement> FixedData<'a, T> { .collect(), challenges, global_range_constraints, + intermediate_definitions, } } diff --git a/executor/src/witgen/query_processor.rs b/executor/src/witgen/query_processor.rs index a1c76489b8..2020c158b0 100644 --- a/executor/src/witgen/query_processor.rs +++ b/executor/src/witgen/query_processor.rs @@ -78,7 +78,7 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback> /// @returns None if the value for that column is already known. pub fn process_query( &mut self, - rows: &RowPair, + rows: &RowPair<'_, 'a, T>, poly_id: &PolyID, ) -> Option> { let column = &self.fixed_data.witness_cols[poly_id]; @@ -94,7 +94,7 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback> &mut self, query: &'a Expression, poly: &'a AlgebraicReference, - rows: &RowPair, + rows: &RowPair<'_, 'a, T>, ) -> EvalResult<'a, T> { let query_str = match self.interpolate_query(query, rows) { Ok(query) => query, @@ -131,7 +131,7 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback> fn interpolate_query( &mut self, query: &'a Expression, - rows: &RowPair, + rows: &RowPair<'_, 'a, T>, ) -> Result { let arguments = vec![Arc::new(Value::Integer(BigInt::from(u64::from( rows.current_row_index, diff --git a/executor/src/witgen/rows.rs b/executor/src/witgen/rows.rs index 4473b8967f..e7a5b95f19 100644 --- a/executor/src/witgen/rows.rs +++ b/executor/src/witgen/rows.rs @@ -462,14 +462,18 @@ impl<'row, 'a, T: FieldElement> RowPair<'row, 'a, T> { /// Tries to evaluate the expression to an expression affine in the witness polynomials, /// taking current values of polynomials into account. /// @returns an expression affine in the witness polynomials - pub fn evaluate<'b>(&self, expr: &'b Expression) -> AffineResult, T> { - ExpressionEvaluator::new(SymbolicWitnessEvaluator::new( + pub fn evaluate(&self, expr: &'a Expression) -> AffineResult, T> { + let variables = SymbolicWitnessEvaluator::new( self.fixed_data, self.current_row_index.into(), self, self.size, - )) - .evaluate(expr) + ); + // Note that because we instantiate a fresh evaluator here, we don't benefit from caching + // of intermediate values across calls of `RowPair::evaluate`. In practice, we only call + // it many times for the same RowPair though. + ExpressionEvaluator::new(variables, &self.fixed_data.intermediate_definitions) + .evaluate(expr) } } diff --git a/executor/src/witgen/symbolic_witness_evaluator.rs b/executor/src/witgen/symbolic_witness_evaluator.rs index 097d6bcce3..7724eef122 100644 --- a/executor/src/witgen/symbolic_witness_evaluator.rs +++ b/executor/src/witgen/symbolic_witness_evaluator.rs @@ -1,4 +1,4 @@ -use powdr_ast::analyzed::Challenge; +use powdr_ast::analyzed::{Challenge, PolynomialType}; use powdr_number::{DegreeType, FieldElement}; use super::{ @@ -54,14 +54,18 @@ where match var { AlgebraicVariable::Column(poly) => { // TODO arrays - if poly.is_witness() { - self.witness_access.value(var) - } else { - // Constant polynomial (or something else) - let values = self.fixed_data.fixed_cols[&poly.poly_id].values(self.size); - let row = if poly.next { self.row + 1 } else { self.row } - % (values.len() as DegreeType); - Ok(values[row as usize].into()) + match poly.poly_id.ptype { + PolynomialType::Committed => self.witness_access.value(var), + PolynomialType::Constant => { + // Constant polynomial (or something else) + let values = self.fixed_data.fixed_cols[&poly.poly_id].values(self.size); + let row = if poly.next { self.row + 1 } else { self.row } + % (values.len() as DegreeType); + Ok(values[row as usize].into()) + } + PolynomialType::Intermediate => unreachable!( + "ExpressionEvaluator should have resolved intermediate polynomials" + ), } } AlgebraicVariable::Public(_) => self.witness_access.value(var),