Skip to content

Commit

Permalink
Restrict operators on AlgebraicExpression.
Browse files Browse the repository at this point in the history
  • Loading branch information
chriseth committed Nov 1, 2023
1 parent a148b0e commit 6eb5982
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 118 deletions.
19 changes: 17 additions & 2 deletions ast/src/analyzed/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,11 @@ impl<T: Display> Display for Identity<AlgebraicExpression<T>> {
match self.kind {
IdentityKind::Polynomial => {
let expression = self.expression_for_poly_id();
if let AlgebraicExpression::BinaryOperation(left, BinaryOperator::Sub, right) =
expression
if let AlgebraicExpression::BinaryOperation(
left,
AlgebraicBinaryOperator::Sub,
right,
) = expression
{
write!(f, "{left} = {right};")
} else {
Expand Down Expand Up @@ -189,6 +192,18 @@ impl<T: Display> Display for AlgebraicExpression<T> {
}
}

impl Display for AlgebraicUnaryOperator {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
UnaryOperator::from(*self).fmt(f)
}
}

impl Display for AlgebraicBinaryOperator {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
BinaryOperator::from(*self).fmt(f)
}
}

impl Display for AlgebraicReference {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
Expand Down
78 changes: 72 additions & 6 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -487,14 +487,80 @@ pub enum AlgebraicExpression<T> {
Number(T),
BinaryOperation(
Box<AlgebraicExpression<T>>,
BinaryOperator,
AlgebraicBinaryOperator,
Box<AlgebraicExpression<T>>,
),
UnaryOperation(UnaryOperator, Box<AlgebraicExpression<T>>),

UnaryOperation(AlgebraicUnaryOperator, Box<AlgebraicExpression<T>>),
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum AlgebraicBinaryOperator {
Add,
Sub,
Mul,
/// Exponentiation, but only by constants.
Pow,
}

impl From<AlgebraicBinaryOperator> for BinaryOperator {
fn from(op: AlgebraicBinaryOperator) -> BinaryOperator {
match op {
AlgebraicBinaryOperator::Add => BinaryOperator::Add,
AlgebraicBinaryOperator::Sub => BinaryOperator::Sub,
AlgebraicBinaryOperator::Mul => BinaryOperator::Mul,
AlgebraicBinaryOperator::Pow => BinaryOperator::Pow,
}
}
}

impl TryFrom<BinaryOperator> for AlgebraicBinaryOperator {
type Error = String;

fn try_from(op: BinaryOperator) -> Result<Self, Self::Error> {
match op {
BinaryOperator::Add => Ok(AlgebraicBinaryOperator::Add),
BinaryOperator::Sub => Ok(AlgebraicBinaryOperator::Sub),
BinaryOperator::Mul => Ok(AlgebraicBinaryOperator::Mul),
BinaryOperator::Pow => Ok(AlgebraicBinaryOperator::Pow),
_ => Err(format!(
"Binary operator {op} not allowed in algebraic expression."
)),
}
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum AlgebraicUnaryOperator {
Plus,
Minus,
}

impl From<AlgebraicUnaryOperator> for UnaryOperator {
fn from(op: AlgebraicUnaryOperator) -> UnaryOperator {
match op {
AlgebraicUnaryOperator::Plus => UnaryOperator::Plus,
AlgebraicUnaryOperator::Minus => UnaryOperator::Minus,
}
}
}

impl TryFrom<UnaryOperator> for AlgebraicUnaryOperator {
type Error = String;

fn try_from(op: UnaryOperator) -> Result<Self, Self::Error> {
match op {
UnaryOperator::Plus => Ok(AlgebraicUnaryOperator::Plus),
UnaryOperator::Minus => Ok(AlgebraicUnaryOperator::Minus),
_ => Err(format!(
"Unary operator {op} not allowed in algebraic expression."
)),
}
}
}

impl<T> AlgebraicExpression<T> {
pub fn new_binary(left: Self, op: BinaryOperator, right: Self) -> Self {
pub fn new_binary(left: Self, op: AlgebraicBinaryOperator, right: Self) -> Self {
AlgebraicExpression::BinaryOperation(Box::new(left), op, Box::new(right))
}

Expand Down Expand Up @@ -528,23 +594,23 @@ impl<T> ops::Add for AlgebraicExpression<T> {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
Self::new_binary(self, BinaryOperator::Add, rhs)
Self::new_binary(self, AlgebraicBinaryOperator::Add, rhs)
}
}

impl<T> ops::Sub for AlgebraicExpression<T> {
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
Self::new_binary(self, BinaryOperator::Sub, rhs)
Self::new_binary(self, AlgebraicBinaryOperator::Sub, rhs)
}
}

impl<T> ops::Mul for AlgebraicExpression<T> {
type Output = Self;

fn mul(self, rhs: Self) -> Self::Output {
Self::new_binary(self, BinaryOperator::Mul, rhs)
Self::new_binary(self, AlgebraicBinaryOperator::Mul, rhs)
}
}

Expand Down
35 changes: 9 additions & 26 deletions backend/src/pilstark/json_exporter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use std::cmp;
use std::collections::HashMap;

use ast::analyzed::{
AlgebraicExpression as Expression, AlgebraicReference, Analyzed, BinaryOperator, IdentityKind,
PolyID, PolynomialType, StatementIdentifier, SymbolKind, UnaryOperator,
AlgebraicBinaryOperator, AlgebraicExpression as Expression, AlgebraicReference,
AlgebraicUnaryOperator, Analyzed, IdentityKind, PolyID, PolynomialType, StatementIdentifier,
SymbolKind,
};
use starky::types::{
ConnectionIdentity, Expression as StarkyExpr, PermutationIdentity, PlookupIdentity,
Expand Down Expand Up @@ -284,34 +285,17 @@ impl<'a, T: FieldElement> Exporter<'a, T> {
let (deg_left, left) = self.expression_to_json(left);
let (deg_right, right) = self.expression_to_json(right);
let (op, degree) = match op {
BinaryOperator::Add => ("add", cmp::max(deg_left, deg_right)),
BinaryOperator::Sub => ("sub", cmp::max(deg_left, deg_right)),
BinaryOperator::Mul => ("mul", deg_left + deg_right),
BinaryOperator::Div => panic!("Div is not really allowed"),
BinaryOperator::Pow => {
AlgebraicBinaryOperator::Add => ("add", cmp::max(deg_left, deg_right)),
AlgebraicBinaryOperator::Sub => ("sub", cmp::max(deg_left, deg_right)),
AlgebraicBinaryOperator::Mul => ("mul", deg_left + deg_right),
AlgebraicBinaryOperator::Pow => {
assert_eq!(
deg_left + deg_right,
0,
"Exponentiation can only be used on constants."
);
("pow", deg_left + deg_right)
}
BinaryOperator::Mod
| BinaryOperator::BinaryAnd
| BinaryOperator::BinaryOr
| BinaryOperator::BinaryXor
| BinaryOperator::ShiftLeft
| BinaryOperator::ShiftRight
| BinaryOperator::LogicalOr
| BinaryOperator::LogicalAnd
| BinaryOperator::Less
| BinaryOperator::LessEqual
| BinaryOperator::Equal
| BinaryOperator::NotEqual
| BinaryOperator::GreaterEqual
| BinaryOperator::Greater => {
panic!("Operator {op:?} not supported on polynomials.")
}
};
(
degree,
Expand All @@ -326,8 +310,8 @@ impl<'a, T: FieldElement> Exporter<'a, T> {
Expression::UnaryOperation(op, value) => {
let (deg, value) = self.expression_to_json(value);
match op {
UnaryOperator::Plus => (deg, value),
UnaryOperator::Minus => (
AlgebraicUnaryOperator::Plus => (deg, value),
AlgebraicUnaryOperator::Minus => (
deg,
StarkyExpr {
op: "neg".to_string(),
Expand All @@ -336,7 +320,6 @@ impl<'a, T: FieldElement> Exporter<'a, T> {
..DEFAULT_EXPR
},
),
UnaryOperator::LogicalNot => panic!("Operator {op} not allowed here."),
}
}
}
Expand Down
69 changes: 15 additions & 54 deletions executor/src/witgen/expression_evaluator.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::marker::PhantomData;

use ast::analyzed::{AlgebraicExpression as Expression, AlgebraicReference};
use ast::parsed::{BinaryOperator, UnaryOperator};
use ast::analyzed::{
AlgebraicBinaryOperator, AlgebraicExpression as Expression, AlgebraicReference,
AlgebraicUnaryOperator,
};

use number::FieldElement;

use super::{affine_expression::AffineResult, IncompleteCause};
Expand Down Expand Up @@ -49,13 +52,13 @@ where
fn evaluate_binary_operation<'a>(
&self,
left: &'a Expression<T>,
op: &BinaryOperator,
op: &AlgebraicBinaryOperator,
right: &'a Expression<T>,
) -> AffineResult<&'a AlgebraicReference, T> {
let left = self.evaluate(left);

// Short-circuit multiplication by zero.
if *op == BinaryOperator::Mul {
if *op == AlgebraicBinaryOperator::Mul {
if let Ok(zero) = &left {
if zero.constant_value().map(|z| z.is_zero()) == Some(true) {
return Ok(zero.clone());
Expand All @@ -66,15 +69,15 @@ where

match (left, op, right) {
// Short-circuit multiplication by zero for "right".
(_, BinaryOperator::Mul, Ok(zero))
(_, AlgebraicBinaryOperator::Mul, Ok(zero))
if zero.constant_value().map(|z| z.is_zero()) == Some(true) =>
{
Ok(zero)
}
(Ok(left), op, Ok(right)) => match op {
BinaryOperator::Add => Ok(left + right),
BinaryOperator::Sub => Ok(left - right),
BinaryOperator::Mul => {
AlgebraicBinaryOperator::Add => Ok(left + right),
AlgebraicBinaryOperator::Sub => Ok(left - right),
AlgebraicBinaryOperator::Mul => {
if let Some(f) = left.constant_value() {
Ok(right * f)
} else if let Some(f) = right.constant_value() {
Expand All @@ -83,48 +86,13 @@ where
Err(IncompleteCause::QuadraticTerm)
}
}
BinaryOperator::Div => {
if let (Some(l), Some(r)) = (left.constant_value(), right.constant_value()) {
// TODO Maybe warn about division by zero here.
if l.is_zero() {
Ok(T::zero().into())
} else {
// TODO We have to do division in the proper field.
Ok((l / r).into())
}
} else {
Err(IncompleteCause::DivisionTerm)
}
}
BinaryOperator::Pow => {
AlgebraicBinaryOperator::Pow => {
if let (Some(l), Some(r)) = (left.constant_value(), right.constant_value()) {
Ok(l.pow(r.to_integer()).into())
} else {
Err(IncompleteCause::ExponentiationTerm)
}
}
BinaryOperator::Mod
| BinaryOperator::BinaryAnd
| BinaryOperator::BinaryXor
| BinaryOperator::BinaryOr
| BinaryOperator::ShiftLeft
| BinaryOperator::ShiftRight
| BinaryOperator::LogicalOr
| BinaryOperator::LogicalAnd
| BinaryOperator::Less
| BinaryOperator::LessEqual
| BinaryOperator::Equal
| BinaryOperator::NotEqual
| BinaryOperator::GreaterEqual
| BinaryOperator::Greater => {
if let (Some(left), Some(right)) =
(left.constant_value(), right.constant_value())
{
Ok(ast::evaluate_binary_operation(left, *op, right).into())
} else {
panic!()
}
}
},
(Ok(_), _, Err(reason)) | (Err(reason), _, Ok(_)) => Err(reason),
(Err(r1), _, Err(r2)) => Err(r1.combine(r2)),
Expand All @@ -133,19 +101,12 @@ where

fn evaluate_unary_operation<'a>(
&self,
op: &UnaryOperator,
op: &AlgebraicUnaryOperator,
expr: &'a Expression<T>,
) -> AffineResult<&'a AlgebraicReference, T> {
self.evaluate(expr).map(|v| match op {
UnaryOperator::Plus => v,
UnaryOperator::Minus => -v,
UnaryOperator::LogicalNot => {
if let Some(v) = v.constant_value() {
ast::evaluate_unary_operation(*op, v).into()
} else {
panic!()
}
}
AlgebraicUnaryOperator::Plus => v,
AlgebraicUnaryOperator::Minus => -v,
})
}
}
10 changes: 5 additions & 5 deletions executor/src/witgen/global_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use std::collections::{BTreeMap, BTreeSet};
use num_traits::Zero;

use ast::analyzed::{
AlgebraicExpression as Expression, AlgebraicReference, Identity, IdentityKind, PolyID,
PolynomialType,
AlgebraicBinaryOperator, AlgebraicExpression as Expression, AlgebraicReference, Identity,
IdentityKind, PolyID, PolynomialType,
};
use ast::parsed::BinaryOperator;

use number::FieldElement;

use crate::witgen::data_structures::column_map::{FixedColumnMap, WitnessColumnMap};
Expand Down Expand Up @@ -219,13 +219,13 @@ fn propagate_constraints<T: FieldElement>(
/// Tries to find "X * (1 - X) = 0"
fn is_binary_constraint<T: FieldElement>(expr: &Expression<T>) -> Option<PolyID> {
// TODO Write a proper pattern matching engine.
if let Expression::BinaryOperation(left, BinaryOperator::Sub, right) = expr {
if let Expression::BinaryOperation(left, AlgebraicBinaryOperator::Sub, right) = expr {
if let Expression::Number(n) = right.as_ref() {
if n.is_zero() {
return is_binary_constraint(left.as_ref());
}
}
} else if let Expression::BinaryOperation(left, BinaryOperator::Mul, right) = expr {
} else if let Expression::BinaryOperation(left, AlgebraicBinaryOperator::Mul, right) = expr {
let symbolic_ev = SymbolicEvaluator;
let left_root = ExpressionEvaluator::new(symbolic_ev.clone())
.evaluate(left)
Expand Down
Loading

0 comments on commit 6eb5982

Please sign in to comment.