Skip to content

Commit

Permalink
Merge pull request #776 from powdr-labs/remove_mapping
Browse files Browse the repository at this point in the history
Replace Mapping enum alternative by Expression.
  • Loading branch information
Leo authored Nov 27, 2023
2 parents 66cf236 + 86cf0df commit f3800c2
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 151 deletions.
3 changes: 1 addition & 2 deletions analysis/src/macro_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ where
if let PilStatement::PolynomialConstantDefinition(_, _, f)
| PilStatement::PolynomialCommitDeclaration(_, _, Some(f)) = &statement
{
if let FunctionDefinition::Mapping(params, _) | FunctionDefinition::Query(params, _) = f
{
if let FunctionDefinition::Query(params, _) = f {
assert!(self.shadowing_locals.is_empty());
self.shadowing_locals.extend(params.iter().cloned());
added_locals = true;
Expand Down
24 changes: 21 additions & 3 deletions ast/src/analyzed/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,34 @@ impl<T: Display> Display for Analyzed<T> {
impl<T: Display> Display for FunctionValueDefinition<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {
FunctionValueDefinition::Mapping(e) => write!(f, "(i) {{ {e} }}"),
FunctionValueDefinition::Array(items) => {
write!(f, " = {}", items.iter().format(" + "))
}
FunctionValueDefinition::Query(e) => write!(f, "(i) query {e}"),
FunctionValueDefinition::Expression(e) => write!(f, " = {e}"),
FunctionValueDefinition::Query(e) => format_outer_function(e, Some("query"), f),
FunctionValueDefinition::Expression(e) => format_outer_function(e, None, f),
}
}
}

fn format_outer_function<T: Display>(
e: &Expression<T>,
qualifier: Option<&str>,
f: &mut Formatter<'_>,
) -> Result {
let q = qualifier.map(|s| format!(" {s}")).unwrap_or_default();
match e {
parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1 => {
let body = if q.is_empty() {
format!("{{ {} }}", lambda.body)
} else {
format!("{}", lambda.body)
};
write!(f, "({}){q} {body}", lambda.params.iter().format(", "),)
}
_ => write!(f, " ={q} {e}"),
}
}

impl<T: Display> Display for RepeatedArray<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
if self.is_empty() {
Expand Down
4 changes: 1 addition & 3 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ impl<T> Analyzed<T> {
self.definitions
.values_mut()
.for_each(|(_poly, definition)| match definition {
Some(FunctionValueDefinition::Mapping(e))
| Some(FunctionValueDefinition::Query(e)) => e.post_visit_expressions_mut(f),
Some(FunctionValueDefinition::Query(e)) => e.post_visit_expressions_mut(f),
Some(FunctionValueDefinition::Array(elements)) => elements
.iter_mut()
.flat_map(|e| e.pattern.iter_mut())
Expand Down Expand Up @@ -365,7 +364,6 @@ pub enum SymbolKind {

#[derive(Debug, Clone)]
pub enum FunctionValueDefinition<T> {
Mapping(Expression<T>),
Array(Vec<RepeatedArray<T>>),
Query(Expression<T>),
Expression(Expression<T>),
Expand Down
12 changes: 6 additions & 6 deletions ast/src/analyzed/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionValueDefinition<T> {
F: FnMut(&mut Expression<T>) -> ControlFlow<B>,
{
match self {
FunctionValueDefinition::Mapping(e)
| FunctionValueDefinition::Query(e)
| FunctionValueDefinition::Expression(e) => e.visit_expressions_mut(f, o),
FunctionValueDefinition::Query(e) | FunctionValueDefinition::Expression(e) => {
e.visit_expressions_mut(f, o)
}
FunctionValueDefinition::Array(array) => array
.iter_mut()
.flat_map(|a| a.pattern.iter_mut())
Expand All @@ -101,9 +101,9 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionValueDefinition<T> {
F: FnMut(&Expression<T>) -> ControlFlow<B>,
{
match self {
FunctionValueDefinition::Mapping(e)
| FunctionValueDefinition::Query(e)
| FunctionValueDefinition::Expression(e) => e.visit_expressions(f, o),
FunctionValueDefinition::Query(e) | FunctionValueDefinition::Expression(e) => {
e.visit_expressions(f, o)
}
FunctionValueDefinition::Array(array) => array
.iter()
.flat_map(|a| a.pattern().iter())
Expand Down
17 changes: 11 additions & 6 deletions ast/src/parsed/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,18 +419,23 @@ impl<T: Display> Display for ArrayExpression<T> {
impl<T: Display> Display for FunctionDefinition<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {
FunctionDefinition::Mapping(params, body) => {
write!(f, "({}) {{ {body} }}", params.join(", "))
}
FunctionDefinition::Array(array_expression) => {
write!(f, " = {array_expression}")
}
FunctionDefinition::Query(params, value) => {
write!(f, "({}) query {value}", params.join(", "),)
}
FunctionDefinition::Expression(e) => {
write!(f, " = {e}")
}
FunctionDefinition::Expression(e) => match e {
Expression::LambdaExpression(lambda) if lambda.params.len() == 1 => {
write!(
f,
"({}) {{ {} }}",
lambda.params.iter().format(", "),
lambda.body
)
}
_ => write!(f, " = {e}"),
},
}
}
}
Expand Down
4 changes: 1 addition & 3 deletions ast/src/parsed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,11 @@ pub enum MatchPattern<T, Ref = NamespacedPolynomialReference> {
/// The definition of a function (excluding its name):
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum FunctionDefinition<T> {
/// Parameter-value-mapping.
Mapping(Vec<String>, Expression<T>),
/// Array expression.
Array(ArrayExpression<T>),
/// Prover query.
Query(Vec<String>, Expression<T>),
/// Expression, for intermediate polynomials
/// Generic expression
Expression(Expression<T>),
}

Expand Down
8 changes: 2 additions & 6 deletions ast/src/parsed/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,7 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionDefinition<T> {
F: FnMut(&mut Expression<T>) -> ControlFlow<B>,
{
match self {
FunctionDefinition::Query(_, e) | FunctionDefinition::Mapping(_, e) => {
e.visit_expressions_mut(f, o)
}
FunctionDefinition::Query(_, e) => e.visit_expressions_mut(f, o),
FunctionDefinition::Array(ae) => ae.visit_expressions_mut(f, o),
FunctionDefinition::Expression(e) => e.visit_expressions_mut(f, o),
}
Expand All @@ -304,9 +302,7 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionDefinition<T> {
F: FnMut(&Expression<T>) -> ControlFlow<B>,
{
match self {
FunctionDefinition::Query(_, e) | FunctionDefinition::Mapping(_, e) => {
e.visit_expressions(f, o)
}
FunctionDefinition::Query(_, e) => e.visit_expressions(f, o),
FunctionDefinition::Array(ae) => ae.visit_expressions(f, o),
FunctionDefinition::Expression(e) => e.visit_expressions(f, o),
}
Expand Down
18 changes: 6 additions & 12 deletions executor/src/constant_evaluator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{collections::HashMap, fmt::Display, rc::Rc};
use ast::analyzed::{Analyzed, FunctionValueDefinition};
use itertools::Itertools;
use number::{DegreeType, FieldElement};
use pil_analyzer::evaluator::{self, Closure, Custom, EvalError, SymbolLookup, Value};
use pil_analyzer::evaluator::{self, Custom, EvalError, SymbolLookup, Value};
use rayon::prelude::{IntoParallelIterator, ParallelIterator};

/// Generates the constant polynomial values for all constant polynomials
Expand Down Expand Up @@ -36,10 +36,13 @@ fn generate_values<T: FieldElement>(
};
// TODO we should maybe pre-compute some symbols here.
match body {
FunctionValueDefinition::Mapping(body) => (0..degree)
FunctionValueDefinition::Expression(e) => (0..degree)
.into_par_iter()
.map(|i| {
evaluator::evaluate_function_call(body, vec![i.into()], &symbols)
// We could try to avoid the first evaluation to be run for each iteration,
// but the data is not thread-safe.
let fun = evaluator::evaluate(e, &symbols).unwrap();
evaluator::evaluate_function_call(fun, vec![Rc::new(T::from(i).into())], &symbols)
.unwrap()
.try_to_number()
.unwrap()
Expand Down Expand Up @@ -71,9 +74,6 @@ fn generate_values<T: FieldElement>(
values
}
FunctionValueDefinition::Query(_) => panic!("Query used for fixed column."),
FunctionValueDefinition::Expression(_) => {
panic!("Expression used for fixed column, only expected for intermediate polynomials")
}
}
}

Expand All @@ -92,12 +92,6 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, FixedColumnRef<'a>> for Symbols<'a
Some(FunctionValueDefinition::Expression(value)) => {
evaluator::evaluate(value, self)?
}
Some(FunctionValueDefinition::Mapping(body)) => (Closure {
parameter_count: 1,
body,
environment: vec![],
})
.into(),
Some(_) => Err(EvalError::Unsupported(
"Cannot evaluate arrays and queries.".to_string(),
))?,
Expand Down
17 changes: 7 additions & 10 deletions executor/src/witgen/query_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,13 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback<T>>
query: &'a Expression<T>,
rows: &RowPair<T>,
) -> Result<String, EvalError> {
let arguments = vec![rows.current_row_index.into()];
evaluator::evaluate_function_call(
query,
arguments,
&Symbols {
fixed_data: self.fixed_data,
rows,
},
)
.map(|v| v.to_string())
let arguments = vec![Rc::new(T::from(rows.current_row_index).into())];
let symbols = Symbols {
fixed_data: self.fixed_data,
rows,
};
let fun = evaluator::evaluate(query, &symbols)?;
evaluator::evaluate_function_call(fun, arguments, &symbols).map(|v| v.to_string())
}
}

Expand Down
2 changes: 1 addition & 1 deletion parser/src/powdr.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ PolynomialConstantDefinition: PilStatement<T> = {
}

FunctionDefinition: FunctionDefinition<T> = {
"(" <ParameterList> ")" "{" <Expression> "}" => FunctionDefinition::Mapping(<>),
"(" <params:ParameterList> ")" "{" <body:BoxedExpression> "}" => FunctionDefinition::Expression(Expression::LambdaExpression(LambdaExpression{params, body})),
"=" <ArrayLiteralExpression> => FunctionDefinition::Array(<>),
}

Expand Down
88 changes: 38 additions & 50 deletions pil_analyzer/src/evaluator.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::{collections::HashMap, fmt::Display, iter::repeat, rc::Rc};
use std::{collections::HashMap, fmt::Display, rc::Rc};

use ast::{
analyzed::{Expression, FunctionValueDefinition, Reference, Symbol},
evaluate_binary_operation, evaluate_unary_operation,
parsed::{display::quote, FunctionCall, MatchArm, MatchPattern},
parsed::{display::quote, FunctionCall, LambdaExpression, MatchArm, MatchPattern},
};
use itertools::Itertools;
use number::FieldElement;
Expand All @@ -16,28 +16,42 @@ pub fn evaluate_expression<'a, T: FieldElement>(
evaluate(expr, &Definitions(definitions))
}

/// Evaluates an expression given a symbol lookup implementation
pub fn evaluate<'a, T: FieldElement, C: Custom>(
expr: &'a Expression<T>,
symbols: &impl SymbolLookup<'a, T, C>,
) -> Result<Value<'a, T, C>, EvalError> {
internal::evaluate(expr, &[], symbols)
}

// TODO this function should be removed in the future, or at least it should
// check that `expr` actually evaluates to a Closure.
/// Evaluates a function call.
pub fn evaluate_function_call<'a, T: FieldElement, C: Custom>(
expr: &'a Expression<T>,
arguments: Vec<T>,
function: Value<'a, T, C>,
arguments: Vec<Rc<Value<'a, T, C>>>,
symbols: &impl SymbolLookup<'a, T, C>,
) -> Result<Value<'a, T, C>, EvalError> {
internal::evaluate(
expr,
&arguments
.into_iter()
.map(|x| Rc::new(Value::Number(x)))
.collect::<Vec<_>>(),
symbols,
)
match function {
Value::Closure(Closure {
lambda,
environment,
}) => {
if lambda.params.len() != arguments.len() {
Err(EvalError::TypeError(format!(
"Invalid function call: Supplied {} arguments to function that takes {} parameters.",
arguments.len(),
lambda.params.len())
))?
}

let local_vars = arguments.into_iter().chain(environment).collect::<Vec<_>>();

internal::evaluate(&lambda.body, &local_vars, symbols)
}
Value::Custom(value) => symbols.eval_function_application(value, &arguments),
e => Err(EvalError::TypeError(format!(
"Expected function but got {e}"
))),
}
}

/// Evaluation errors.
Expand Down Expand Up @@ -69,6 +83,12 @@ pub enum Value<'a, T, C> {
Custom(C),
}

impl<'a, T, C> From<T> for Value<'a, T, C> {
fn from(value: T) -> Self {
Value::Number(value)
}
}

// TODO somehow, implementing TryFrom or TryInto did not work.

impl<'a, T: FieldElement, C: Custom> Value<'a, T, C> {
Expand Down Expand Up @@ -108,10 +128,7 @@ impl Display for NoCustom {

#[derive(Clone)]
pub struct Closure<'a, T, C> {
// TODO we could also store the names of the parameters (for printing)
// In order to do this, we would need to add a parameter name to Mapping, which might be a good idea anyway.
pub parameter_count: usize,
pub body: &'a Expression<T>,
pub lambda: &'a LambdaExpression<T, Reference>,
pub environment: Vec<Rc<Value<'a, T, C>>>,
}

Expand All @@ -125,12 +142,7 @@ impl<'a, T, C> PartialEq for Closure<'a, T, C> {

impl<'a, T: Display, C> Display for Closure<'a, T, C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"|{}| {}",
repeat('_').take(self.parameter_count).format(", "),
self.body
)
write!(f, "{}", self.lambda)
}
}

Expand All @@ -149,12 +161,6 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Definitions<'a, T> {
Ok(match self.0.get(&name.to_string()) {
Some((_, value)) => match value {
Some(FunctionValueDefinition::Expression(value)) => evaluate(value, self)?,
Some(FunctionValueDefinition::Mapping(body)) => (Closure {
parameter_count: 1,
body,
environment: vec![],
})
.into(),
_ => Err(EvalError::Unsupported(
"Cannot evaluate arrays and queries.".to_string(),
))?,
Expand Down Expand Up @@ -235,8 +241,7 @@ mod internal {
Expression::LambdaExpression(lambda) => {
// TODO only copy the part of the environment that is actually referenced?
(Closure {
parameter_count: lambda.params.len(),
body: lambda.body.as_ref(),
lambda,
environment: locals.to_vec(),
})
.into()
Expand Down Expand Up @@ -266,24 +271,7 @@ mod internal {
.iter()
.map(|a| evaluate(a, locals, symbols).map(Rc::new))
.collect::<Result<Vec<_>, _>>()?;
match function {
Value::Closure(Closure {
parameter_count,
body,
environment,
}) => {
assert_eq!(parameter_count, arguments.len());

let local_vars =
arguments.into_iter().chain(environment).collect::<Vec<_>>();

evaluate(body, &local_vars, symbols)?
}
Value::Custom(value) => symbols.eval_function_application(value, &arguments)?,
e => Err(EvalError::TypeError(format!(
"Expected function but got {e}"
)))?,
}
evaluate_function_call(function, arguments, symbols)?
}
Expression::MatchExpression(scrutinee, arms) => {
let v = evaluate(scrutinee, locals, symbols)?;
Expand Down
Loading

0 comments on commit f3800c2

Please sign in to comment.