diff --git a/executor/src/witgen/jit/cell.rs b/executor/src/witgen/jit/cell.rs new file mode 100644 index 0000000000..090e277048 --- /dev/null +++ b/executor/src/witgen/jit/cell.rs @@ -0,0 +1,59 @@ +use std::{ + fmt::{self, Display, Formatter}, + hash::{Hash, Hasher}, +}; + +use powdr_ast::analyzed::AlgebraicReference; + +/// The identifier of a witness cell in the trace table. +/// The `row_offset` is relative to a certain "zero row" defined +/// by the component that uses this data structure. +#[derive(Debug, Clone, Eq)] +pub struct Cell { + /// Name of the column, used only for display purposes. + pub column_name: String, + pub id: u64, + pub row_offset: i32, +} + +impl Hash for Cell { + fn hash(&self, state: &mut H) { + self.id.hash(state); + self.row_offset.hash(state); + } +} + +impl PartialEq for Cell { + fn eq(&self, other: &Self) -> bool { + self.id == other.id && self.row_offset == other.row_offset + } +} + +impl Ord for Cell { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + (self.id, self.row_offset).cmp(&(other.id, other.row_offset)) + } +} + +impl PartialOrd for Cell { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Cell { + pub fn from_reference(r: &AlgebraicReference, row_offset: i32) -> Self { + assert!(r.is_witness()); + Self { + column_name: r.name.clone(), + id: r.poly_id.id, + row_offset: r.next as i32 + row_offset, + } + } +} + +impl Display for Cell { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}[{}]", self.column_name, self.row_offset) + } +} diff --git a/executor/src/witgen/jit/mod.rs b/executor/src/witgen/jit/mod.rs index a6d9dabf39..c3151f1638 100644 --- a/executor/src/witgen/jit/mod.rs +++ b/executor/src/witgen/jit/mod.rs @@ -1,2 +1,4 @@ -mod affine_symbolic_expression; +pub(crate) mod affine_symbolic_expression; +mod cell; mod symbolic_expression; +mod witgen_inference; diff --git a/executor/src/witgen/jit/witgen_inference.rs 
b/executor/src/witgen/jit/witgen_inference.rs new file mode 100644 index 0000000000..c395654dd4 --- /dev/null +++ b/executor/src/witgen/jit/witgen_inference.rs @@ -0,0 +1,289 @@ +use std::{ + collections::{BTreeSet, HashMap, HashSet}, + iter::once, +}; + +use itertools::Itertools; +use powdr_ast::{ + analyzed::{ + AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression as Expression, + AlgebraicReference, AlgebraicUnaryOperation, AlgebraicUnaryOperator, Identity, + LookupIdentity, PhantomLookupIdentity, PolyID, PolynomialIdentity, PolynomialType, + SelectedExpressions, + }, + indent, + parsed::visitor::AllChildren, +}; +use powdr_number::FieldElement; + +use crate::witgen::global_constraints::RangeConstraintSet; + +use super::{ + super::{machines::MachineParts, range_constraints::RangeConstraint, FixedData}, + affine_symbolic_expression::{AffineSymbolicExpression, Effect}, + cell::Cell, + symbolic_expression::SymbolicExpression, +}; + +/// This component can generate code that solves identities. +/// It needs a driver that tells it which identities to process on which rows. 
+pub struct WitgenInference<'a, T: FieldElement> { + fixed_data: &'a FixedData<'a, T>, + range_constraints: HashMap>, + known_cells: HashSet, + code: Vec, // TODO make this a proper expression +} + +impl<'a, T: FieldElement> WitgenInference<'a, T> { + pub fn new( + fixed_data: &'a FixedData<'a, T>, + known_cells: impl IntoIterator, + ) -> Self { + Self { + fixed_data, + range_constraints: Default::default(), + known_cells: known_cells.into_iter().collect(), + code: Default::default(), + } + } + + fn cell_at_row(&self, id: u64, row_offset: i32) -> Cell { + let poly_id = PolyID { + id, + ptype: PolynomialType::Committed, + }; + Cell { + column_name: self.fixed_data.column_name(&poly_id).to_string(), + id, + row_offset, + } + } + + fn known_cells(&self) -> &HashSet { + &self.known_cells + } + + pub fn process_identity(&mut self, id: &Identity, row_offset: i32) { + let effects = match id { + Identity::Polynomial(PolynomialIdentity { expression, .. }) => { + self.process_polynomial_identity(expression, row_offset) + } + Identity::Lookup(LookupIdentity { + id, + source: _, + left, + right, + }) + | Identity::PhantomLookup(PhantomLookupIdentity { + id, + source: _, + left, + right, + multiplicity: _, + }) => { + // TODO multiplicity? + self.process_lookup(*id, left, right, row_offset) + } + _ => { + // TODO + vec![] + } + }; + self.ingest_effects(effects); + } + + fn process_polynomial_identity( + &self, + expression: &'a Expression, + offset: i32, + ) -> Vec> { + if let Some(r) = self.evaluate(expression, offset) { + // TODO remove unwrap + r.solve(self).unwrap() + } else { + vec![] + } + } + + fn process_lookup( + &self, + lookup_id: u64, + left: &SelectedExpressions, + right: &SelectedExpressions, + offset: i32, + ) -> Vec> { + // TODO: In the future, call the 'mutable state' to check if the + // lookup can always be answered. + + // If the RHS is fully fixed columns... 
+ if right.expressions.iter().all(|e| match e { + Expression::Reference(r) => r.is_fixed(), + Expression::Number(_) => true, + _ => false, + }) { + // and the selector is known to be 1 and all except one expression is known on the LHS. + if self + .evaluate(&left.selector, offset) + .map(|s| s.is_known_one()) + == Some(true) + { + if let Some(inputs) = left + .expressions + .iter() + .map(|e| self.evaluate(e, offset)) + .collect::>>() + { + if inputs.iter().filter(|i| i.is_known()).count() == inputs.len() - 1 { + let mut var_decl = String::new(); + let mut output_var = String::new(); + let query = inputs + .iter() + .enumerate() + .map(|(i, e)| { + if e.is_known() { + format!("LookupCell::Input(&({e}))") + } else { + let var_name = format!("lookup_{lookup_id}_{i}"); + output_var = var_name.clone(); + var_decl.push_str(&format!( + "let mut {var_name}: FieldElement = Default::default();" + )); + format!("LookupCell::Output(&mut {var_name})") + } + }) + .format(", "); + let machine_call = format!( + "assert!(process_lookup(mutable_state, {lookup_id}, &mut [{query}]));" + ); + // TODO range constraints? + let output_expr = inputs.iter().find(|i| !i.is_known()).unwrap(); + return once(Effect::Code(var_decl)) + .chain(once(Effect::Code(machine_call))) + .chain( + (output_expr + - &KnownValue::from_known_local_var(&output_var).into()) + .solve(self), + ) + .collect(); + } + } + } + } + vec![] + } + + fn ingest_effects(&mut self, effects: Vec>) { + for e in effects { + match e { + Effect::Assignment(cell, assignment) => { + // TODO also use range constraint? 
+ self.known_cells.insert(cell.clone()); + self.code.push(assignment); + } + Effect::RangeConstraint(cell, rc) => { + self.add_range_constraint(cell, rc); + } + Effect::Code(code) => { + self.code.push(code); + } + } + } + } + + fn add_range_constraint(&mut self, cell: Cell, rc: RangeConstraint) { + let rc = self + .range_constraint(cell.clone()) + .map_or(rc.clone(), |existing_rc| existing_rc.conjunction(&rc)); + // TODO if the conjunction results in a single value, make the cell known. + self.range_constraints.insert(cell, rc); + } + + fn evaluate( + &self, + expr: &Expression, + offset: i32, + ) -> Option> { + Some(match expr { + Expression::Reference(r) => { + if r.is_fixed() { + todo!() + // let mut row = self.latch_row as i64 + offset as i64; + // while row < 0 { + // row += self.block_size as i64; + // } + // // TODO at some point we should check that all of the fixed columns are periodic. + // // TODO We can only do this for block machines. + // // For dynamic machines, fixed columns are "known but symbolic" + // let v = self.fixed_data.fixed_cols[&r.poly_id].values_max_size()[row as usize]; + // EvalResult::from_number(v) + } else { + let cell = Cell::from_reference(r, offset); + // If a cell is known and has a compile-time constant value, + // that value is stored in the range constraints. 
+ if let Some(v) = self + .range_constraint(cell.clone()) + .and_then(|rc| rc.try_to_single_value()) + { + AffineSymbolicExpression::from_number(v) + } else if self.known_cells.contains(&cell) { + AffineSymbolicExpression::from_known_variable(cell) + } else { + AffineSymbolicExpression::from_unknown_variable(cell) + } + } + } + Expression::PublicReference(_) => return None, // TODO + Expression::Challenge(_) => return None, // TODO + Expression::Number(n) => AffineSymbolicExpression::from_number(*n), + Expression::BinaryOperation(op) => self.evaulate_binary_operation(op, offset)?, + Expression::UnaryOperation(op) => self.evaluate_unary_operation(op, offset)?, + }) + } + + fn evaulate_binary_operation( + &self, + op: &AlgebraicBinaryOperation, + offset: i32, + ) -> Option> { + let left = self.evaluate(&op.left, offset)?; + let right = self.evaluate(&op.right, offset)?; + match op.op { + AlgebraicBinaryOperator::Add => Some(&left + &right), + AlgebraicBinaryOperator::Sub => Some(&left - &right), + AlgebraicBinaryOperator::Mul => left.try_mul(&right), + _ => todo!(), + } + } + + fn evaluate_unary_operation( + &self, + op: &AlgebraicUnaryOperation, + offset: i32, + ) -> Option> { + let expr = self.evaluate(&op.expr, offset)?; + match op.op { + AlgebraicUnaryOperator::Minus => Some(-&expr), + } + } +} + +impl RangeConstraintSet for WitgenInference<'_, T> { + // TODO would be nice to use &Cell, but this leads to lifetime trouble + // in the solve() function. 
+ fn range_constraint(&self, cell: Cell) -> Option> { + self.fixed_data + .global_range_constraints + .range_constraint(&AlgebraicReference { + name: Default::default(), + poly_id: PolyID { + id: cell.id, + ptype: PolynomialType::Committed, + }, + next: false, + }) + .iter() + .chain(self.range_constraints.get(&cell)) + .cloned() + .reduce(|gc, rc| gc.conjunction(&rc)) + } +} diff --git a/executor/src/witgen/range_constraints.rs b/executor/src/witgen/range_constraints.rs index 8e6d9748a8..15b99e9eca 100644 --- a/executor/src/witgen/range_constraints.rs +++ b/executor/src/witgen/range_constraints.rs @@ -157,6 +157,14 @@ impl RangeConstraint { mask: mask.unwrap_or_else(|| Self::from_range(min, max).mask), } } + + pub fn try_to_single_value(&self) -> Option { + if self.min == self.max { + Some(self.min) + } else { + None + } + } } /// The number of elements in an (inclusive) min/max range.