From 86cf0df559e5247fe8be4bf54a58164582b0de96 Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 20 Nov 2023 13:06:36 +0100 Subject: [PATCH] Replace Mapping enum alternative by Expression. --- analysis/src/macro_expansion.rs | 3 +- ast/src/analyzed/display.rs | 24 ++++++- ast/src/analyzed/mod.rs | 4 +- ast/src/analyzed/visitor.rs | 12 ++-- ast/src/parsed/display.rs | 17 +++-- ast/src/parsed/mod.rs | 4 +- ast/src/parsed/visitor.rs | 8 +-- executor/src/constant_evaluator/mod.rs | 18 ++---- executor/src/witgen/query_processor.rs | 17 ++--- parser/src/powdr.lalrpop | 2 +- pil_analyzer/src/evaluator.rs | 88 +++++++++++--------------- pil_analyzer/src/pil_analyzer.rs | 71 +++++++-------------- pilopt/src/lib.rs | 1 - 13 files changed, 118 insertions(+), 151 deletions(-) diff --git a/analysis/src/macro_expansion.rs b/analysis/src/macro_expansion.rs index 51d6faae29..6f90751889 100644 --- a/analysis/src/macro_expansion.rs +++ b/analysis/src/macro_expansion.rs @@ -97,8 +97,7 @@ where if let PilStatement::PolynomialConstantDefinition(_, _, f) | PilStatement::PolynomialCommitDeclaration(_, _, Some(f)) = &statement { - if let FunctionDefinition::Mapping(params, _) | FunctionDefinition::Query(params, _) = f - { + if let FunctionDefinition::Query(params, _) = f { assert!(self.shadowing_locals.is_empty()); self.shadowing_locals.extend(params.iter().cloned()); added_locals = true; diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index 6b46af5022..1218143de2 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -96,16 +96,34 @@ impl Display for Analyzed { impl Display for FunctionValueDefinition { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self { - FunctionValueDefinition::Mapping(e) => write!(f, "(i) {{ {e} }}"), FunctionValueDefinition::Array(items) => { write!(f, " = {}", items.iter().format(" + ")) } - FunctionValueDefinition::Query(e) => write!(f, "(i) query {e}"), - FunctionValueDefinition::Expression(e) => write!(f, " = {e}"), + FunctionValueDefinition::Query(e) => format_outer_function(e, Some("query"), f), + FunctionValueDefinition::Expression(e) => format_outer_function(e, None, f), } } } +fn format_outer_function( + e: &Expression, + qualifier: Option<&str>, + f: &mut Formatter<'_>, +) -> Result { + let q = qualifier.map(|s| format!(" {s}")).unwrap_or_default(); + match e { + parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1 => { + let body = if q.is_empty() { + format!("{{ {} }}", lambda.body) + } else { + format!("{}", lambda.body) + }; + write!(f, "({}){q} {body}", lambda.params.iter().format(", "),) + } + _ => write!(f, " ={q} {e}"), + } +} + impl Display for RepeatedArray { fn fmt(&self, f: &mut Formatter<'_>) -> Result { if self.is_empty() { diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index a44a68425b..852ccd9401 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -283,8 +283,7 @@ impl Analyzed { self.definitions .values_mut() .for_each(|(_poly, definition)| match definition { - Some(FunctionValueDefinition::Mapping(e)) - | Some(FunctionValueDefinition::Query(e)) => e.post_visit_expressions_mut(f), + Some(FunctionValueDefinition::Query(e)) => e.post_visit_expressions_mut(f), Some(FunctionValueDefinition::Array(elements)) => elements .iter_mut() .flat_map(|e| e.pattern.iter_mut()) @@ -365,7 +364,6 @@ pub enum SymbolKind { #[derive(Debug, Clone)] pub enum FunctionValueDefinition { - Mapping(Expression), Array(Vec>), Query(Expression), Expression(Expression), diff --git a/ast/src/analyzed/visitor.rs b/ast/src/analyzed/visitor.rs index b9c8f79a71..dc17e79f0a 100644 --- a/ast/src/analyzed/visitor.rs +++ b/ast/src/analyzed/visitor.rs @@ -86,9 +86,9 @@ impl ExpressionVisitable> for FunctionValueDefinition { F: FnMut(&mut Expression) -> ControlFlow, { match self { - FunctionValueDefinition::Mapping(e) - | FunctionValueDefinition::Query(e) - | FunctionValueDefinition::Expression(e) => e.visit_expressions_mut(f, o), + FunctionValueDefinition::Query(e) | FunctionValueDefinition::Expression(e) => { + e.visit_expressions_mut(f, o) + } FunctionValueDefinition::Array(array) => array .iter_mut() .flat_map(|a| a.pattern.iter_mut()) @@ -101,9 +101,9 @@ impl ExpressionVisitable> for FunctionValueDefinition { F: FnMut(&Expression) -> ControlFlow, { match self { - FunctionValueDefinition::Mapping(e) - | FunctionValueDefinition::Query(e) - | FunctionValueDefinition::Expression(e) => e.visit_expressions(f, o), + FunctionValueDefinition::Query(e) | FunctionValueDefinition::Expression(e) => { + e.visit_expressions(f, o) + } FunctionValueDefinition::Array(array) => array .iter() .flat_map(|a| a.pattern().iter()) diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index 972b908550..39fcbbe62d 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -419,18 +419,23 @@ impl Display for ArrayExpression { impl Display for FunctionDefinition { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self { - FunctionDefinition::Mapping(params, body) => { - write!(f, "({}) {{ {body} }}", params.join(", ")) - } FunctionDefinition::Array(array_expression) => { write!(f, " = {array_expression}") } FunctionDefinition::Query(params, value) => { write!(f, "({}) query {value}", params.join(", "),) } - FunctionDefinition::Expression(e) => { - write!(f, " = {e}") - } + FunctionDefinition::Expression(e) => match e { + Expression::LambdaExpression(lambda) if lambda.params.len() == 1 => { + write!( + f, + "({}) {{ {} }}", + lambda.params.iter().format(", "), + lambda.body + ) + } + _ => write!(f, " = {e}"), + }, } } } diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 8671d3a691..0fc3628bf9 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -251,13 +251,11 @@ pub enum MatchPattern { /// The definition of a function (excluding its name): #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub enum FunctionDefinition { - /// Parameter-value-mapping. - Mapping(Vec, Expression), /// Array expression. Array(ArrayExpression), /// Prover query. Query(Vec, Expression), - /// Expression, for intermediate polynomials + /// Generic expression Expression(Expression), } diff --git a/ast/src/parsed/visitor.rs b/ast/src/parsed/visitor.rs index c6071e8006..3f7412df84 100644 --- a/ast/src/parsed/visitor.rs +++ b/ast/src/parsed/visitor.rs @@ -291,9 +291,7 @@ impl ExpressionVisitable> for FunctionDefinition { F: FnMut(&mut Expression) -> ControlFlow, { match self { - FunctionDefinition::Query(_, e) | FunctionDefinition::Mapping(_, e) => { - e.visit_expressions_mut(f, o) - } + FunctionDefinition::Query(_, e) => e.visit_expressions_mut(f, o), FunctionDefinition::Array(ae) => ae.visit_expressions_mut(f, o), FunctionDefinition::Expression(e) => e.visit_expressions_mut(f, o), } @@ -304,9 +302,7 @@ impl ExpressionVisitable> for FunctionDefinition { F: FnMut(&Expression) -> ControlFlow, { match self { - FunctionDefinition::Query(_, e) | FunctionDefinition::Mapping(_, e) => { - e.visit_expressions(f, o) - } + FunctionDefinition::Query(_, e) => e.visit_expressions(f, o), FunctionDefinition::Array(ae) => ae.visit_expressions(f, o), FunctionDefinition::Expression(e) => e.visit_expressions(f, o), } diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index a96819421d..5d40cc770f 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, fmt::Display, rc::Rc}; use ast::analyzed::{Analyzed, FunctionValueDefinition}; use itertools::Itertools; use number::{DegreeType, FieldElement}; -use pil_analyzer::evaluator::{self, Closure, Custom, EvalError, SymbolLookup, Value}; +use pil_analyzer::evaluator::{self, Custom, EvalError, SymbolLookup, Value}; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; /// Generates the constant polynomial values for all constant polynomials @@ -36,10 +36,13 @@ fn generate_values( }; // TODO we should maybe pre-compute some symbols here. match body { - FunctionValueDefinition::Mapping(body) => (0..degree) + FunctionValueDefinition::Expression(e) => (0..degree) .into_par_iter() .map(|i| { - evaluator::evaluate_function_call(body, vec![i.into()], &symbols) + // We could try to avoid the first evaluation to be run for each iteration, + // but the data is not thread-safe. + let fun = evaluator::evaluate(e, &symbols).unwrap(); + evaluator::evaluate_function_call(fun, vec![Rc::new(T::from(i).into())], &symbols) .unwrap() .try_to_number() .unwrap() @@ -71,9 +74,6 @@ fn generate_values( values } FunctionValueDefinition::Query(_) => panic!("Query used for fixed column."), - FunctionValueDefinition::Expression(_) => { - panic!("Expression used for fixed column, only expected for intermediate polynomials") - } } } @@ -92,12 +92,6 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, FixedColumnRef<'a>> for Symbols<'a Some(FunctionValueDefinition::Expression(value)) => { evaluator::evaluate(value, self)? } - Some(FunctionValueDefinition::Mapping(body)) => (Closure { - parameter_count: 1, - body, - environment: vec![], - }) - .into(), Some(_) => Err(EvalError::Unsupported( "Cannot evaluate arrays and queries.".to_string(), ))?, diff --git a/executor/src/witgen/query_processor.rs b/executor/src/witgen/query_processor.rs index 25913453e3..c4bfa2aa28 100644 --- a/executor/src/witgen/query_processor.rs +++ b/executor/src/witgen/query_processor.rs @@ -74,16 +74,13 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback> query: &'a Expression, rows: &RowPair, ) -> Result { - let arguments = vec![rows.current_row_index.into()]; - evaluator::evaluate_function_call( - query, - arguments, - &Symbols { - fixed_data: self.fixed_data, - rows, - }, - ) - .map(|v| v.to_string()) + let arguments = vec![Rc::new(T::from(rows.current_row_index).into())]; + let symbols = Symbols { + fixed_data: self.fixed_data, + rows, + }; + let fun = evaluator::evaluate(query, &symbols)?; + evaluator::evaluate_function_call(fun, arguments, &symbols).map(|v| v.to_string()) } } diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 308138ef0e..bdd9bc7ebc 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -102,7 +102,7 @@ PolynomialConstantDefinition: PilStatement = { } FunctionDefinition: FunctionDefinition = { - "(" ")" "{" "}" => FunctionDefinition::Mapping(<>), + "(" ")" "{" "}" => FunctionDefinition::Expression(Expression::LambdaExpression(LambdaExpression{params, body})), "=" => FunctionDefinition::Array(<>), } diff --git a/pil_analyzer/src/evaluator.rs b/pil_analyzer/src/evaluator.rs index 7307456575..7f1ac8b6a0 100644 --- a/pil_analyzer/src/evaluator.rs +++ b/pil_analyzer/src/evaluator.rs @@ -1,9 +1,9 @@ -use std::{collections::HashMap, fmt::Display, iter::repeat, rc::Rc}; +use std::{collections::HashMap, fmt::Display, rc::Rc}; use ast::{ analyzed::{Expression, FunctionValueDefinition, Reference, Symbol}, evaluate_binary_operation, evaluate_unary_operation, - parsed::{display::quote, FunctionCall, MatchArm, MatchPattern}, + parsed::{display::quote, FunctionCall, LambdaExpression, MatchArm, MatchPattern}, }; use itertools::Itertools; use number::FieldElement; @@ -16,6 +16,7 @@ pub fn evaluate_expression<'a, T: FieldElement>( evaluate(expr, &Definitions(definitions)) } +/// Evaluates an expression given a symbol lookup implementation pub fn evaluate<'a, T: FieldElement, C: Custom>( expr: &'a Expression, symbols: &impl SymbolLookup<'a, T, C>, @@ -23,21 +24,34 @@ pub fn evaluate<'a, T: FieldElement, C: Custom>( internal::evaluate(expr, &[], symbols) } -// TODO this function should be removed in the future, or at least it should -// check that `expr` actually evaluates to a Closure. +/// Evaluates a function call. pub fn evaluate_function_call<'a, T: FieldElement, C: Custom>( - expr: &'a Expression, - arguments: Vec, + function: Value<'a, T, C>, + arguments: Vec>>, symbols: &impl SymbolLookup<'a, T, C>, ) -> Result, EvalError> { - internal::evaluate( - expr, - &arguments - .into_iter() - .map(|x| Rc::new(Value::Number(x))) - .collect::>(), - symbols, - ) + match function { + Value::Closure(Closure { + lambda, + environment, + }) => { + if lambda.params.len() != arguments.len() { + Err(EvalError::TypeError(format!( + "Invalid function call: Supplied {} arguments to function that takes {} parameters.", + arguments.len(), + lambda.params.len()) + ))? + } + + let local_vars = arguments.into_iter().chain(environment).collect::>(); + + internal::evaluate(&lambda.body, &local_vars, symbols) + } + Value::Custom(value) => symbols.eval_function_application(value, &arguments), + e => Err(EvalError::TypeError(format!( + "Expected function but got {e}" + ))), + } } /// Evaluation errors. @@ -69,6 +83,12 @@ pub enum Value<'a, T, C> { Custom(C), } +impl<'a, T, C> From for Value<'a, T, C> { + fn from(value: T) -> Self { + Value::Number(value) + } +} + // TODO somehow, implementing TryFrom or TryInto did not work. impl<'a, T: FieldElement, C: Custom> Value<'a, T, C> { @@ -108,10 +128,7 @@ impl Display for NoCustom { #[derive(Clone)] pub struct Closure<'a, T, C> { - // TODO we could also store the names of the parameters (for printing) - // In order to do this, we would need to add a parameter name to Mapping, which might be a good idea anyway. - pub parameter_count: usize, - pub body: &'a Expression, + pub lambda: &'a LambdaExpression, pub environment: Vec>>, } @@ -125,12 +142,7 @@ impl<'a, T, C> PartialEq for Closure<'a, T, C> { impl<'a, T: Display, C> Display for Closure<'a, T, C> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "|{}| {}", - repeat('_').take(self.parameter_count).format(", "), - self.body - ) + write!(f, "{}", self.lambda) } } @@ -149,12 +161,6 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Definitions<'a, T> { Ok(match self.0.get(&name.to_string()) { Some((_, value)) => match value { Some(FunctionValueDefinition::Expression(value)) => evaluate(value, self)?, - Some(FunctionValueDefinition::Mapping(body)) => (Closure { - parameter_count: 1, - body, - environment: vec![], - }) - .into(), _ => Err(EvalError::Unsupported( "Cannot evaluate arrays and queries.".to_string(), ))?, @@ -235,8 +241,7 @@ mod internal { Expression::LambdaExpression(lambda) => { // TODO only copy the part of the environment that is actually referenced? (Closure { - parameter_count: lambda.params.len(), - body: lambda.body.as_ref(), + lambda, environment: locals.to_vec(), }) .into() @@ -266,24 +271,7 @@ mod internal { .iter() .map(|a| evaluate(a, locals, symbols).map(Rc::new)) .collect::, _>>()?; - match function { - Value::Closure(Closure { - parameter_count, - body, - environment, - }) => { - assert_eq!(parameter_count, arguments.len()); - - let local_vars = - arguments.into_iter().chain(environment).collect::>(); - - evaluate(body, &local_vars, symbols)? - } - Value::Custom(value) => symbols.eval_function_application(value, &arguments)?, - e => Err(EvalError::TypeError(format!( - "Expected function but got {e}" - )))?, - } + evaluate_function_call(function, arguments, symbols)? } Expression::MatchExpression(scrutinee, arms) => { let v = evaluate(scrutinee, locals, symbols)?; diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 67a691bdc9..2a0b36e9c8 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -223,37 +223,23 @@ impl PILAnalyzer { ); } Some(value) => { - match value { - parsed::Expression::LambdaExpression(parsed::LambdaExpression { - params, - body, - }) if params.len() == 1 => { - // Assigned value is a lambda expression with a single parameter => treat it as a fixed column. - self.handle_symbol_definition( - self.to_source_ref(start), - name, - None, - SymbolKind::Poly(PolynomialType::Constant), - Some(FunctionDefinition::Mapping(params, *body)), - ); - } - _ => { - let symbol_kind = if self.evaluate_expression(value.clone()).is_ok() { - // Value evaluates to a constant number => treat it as a constant - SymbolKind::Constant() - } else { - // Otherwise, treat it as "generic definition" - SymbolKind::Other() - }; - self.handle_symbol_definition( - self.to_source_ref(start), - name, - None, - symbol_kind, - Some(FunctionDefinition::Expression(value)), - ); - } - } + let symbol_kind = if matches!(&value, parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1) + { + SymbolKind::Poly(PolynomialType::Constant) + } else if self.evaluate_expression(value.clone()).is_ok() { + // Value evaluates to a constant number => treat it as a constant + SymbolKind::Constant() + } else { + // Otherwise, treat it as "generic definition" + SymbolKind::Other() + }; + self.handle_symbol_definition( + self.to_source_ref(start), + name, + None, + symbol_kind, + Some(FunctionDefinition::Expression(value)), + ); } } } @@ -384,26 +370,17 @@ impl PILAnalyzer { let value = value.map(|v| match v { FunctionDefinition::Expression(expr) => { assert!(!have_array_size); - assert!( - symbol_kind == SymbolKind::Other() - || symbol_kind == SymbolKind::Constant() - || symbol_kind == SymbolKind::Poly(PolynomialType::Intermediate) - ); + assert!(symbol_kind != SymbolKind::Poly(PolynomialType::Committed)); FunctionValueDefinition::Expression(self.process_expression(expr)) } - FunctionDefinition::Mapping(params, expr) => { - assert!(!have_array_size); - assert!(symbol_kind == SymbolKind::Poly(PolynomialType::Constant)); - FunctionValueDefinition::Mapping( - ExpressionProcessor::new(self).process_function(¶ms, expr), - ) - } FunctionDefinition::Query(params, expr) => { assert!(!have_array_size); assert_eq!(symbol_kind, SymbolKind::Poly(PolynomialType::Committed)); - FunctionValueDefinition::Query( - ExpressionProcessor::new(self).process_function(¶ms, expr), - ) + let body = Box::new(ExpressionProcessor::new(self).process_function(¶ms, expr)); + FunctionValueDefinition::Query(Expression::LambdaExpression(LambdaExpression { + params, + body, + })) } FunctionDefinition::Array(value) => { let size = value.solve(self.polynomial_degree.unwrap()); @@ -628,8 +605,6 @@ impl<'a, T: FieldElement> ExpressionProcessor<'a, T> { .collect(); // Re-add the outer local variables if we do not overwrite them // and increase their index by the number of parameters. - // TODO re-evaluate if this mechanism makes sense as soon as we properly - // support nested functions and closures. for (name, index) in &previous_local_vars { self.local_variables .entry(name.clone()) diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 7da299a228..47434fe3f9 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -60,7 +60,6 @@ fn remove_constant_fixed_columns(pil_file: &mut Analyzed) { /// value and returns it in that case. fn constant_value(function: &FunctionValueDefinition) -> Option { match function { - FunctionValueDefinition::Mapping(_) => None, // TODO we could also analyze this case. FunctionValueDefinition::Array(expressions) => { // TODO use a proper evaluator at some point, // combine with constant_evalutaor