From a9a83b8d8cfa14f7391a3ed840908a09e14e931e Mon Sep 17 00:00:00 2001 From: Kata Choi Date: Tue, 21 Jan 2025 17:45:57 +0700 Subject: [PATCH] =?UTF-8?q?fix:=20instead=20of=20folding=20expr=20for=20co?= =?UTF-8?q?nstant=20values,=20we=20keep=20the=20expr=20as=E2=80=A6=20(#252?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: instead of folding expr for constant values, we keep the expr as it is * fix: avoid re-monomorphized a function name --- examples/functions.no | 2 ++ src/error.rs | 3 +++ src/mast/mod.rs | 40 +++++++++++++++++++++++++++++++--------- src/parser/types.rs | 6 ++++++ 4 files changed, 42 insertions(+), 9 deletions(-) diff --git a/examples/functions.no b/examples/functions.no index d69bf196a..2a66c2cc0 100644 --- a/examples/functions.no +++ b/examples/functions.no @@ -10,6 +10,8 @@ fn main(pub one: Field) { let four = add(one, 3); assert_eq(four, 4); + // double() should not be folded to return 8 + // the asm test will catch the missing constraint if it is folded let eight = double(4); assert_eq(eight, double(four)); } diff --git a/src/error.rs b/src/error.rs index 9ac60c9a8..9d086fc2e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -372,6 +372,9 @@ pub enum ErrorKind { #[error("division by zero")] DivisionByZero, + #[error("lhs `{0}` is less than rhs `{1}`")] + NegativeLhsLessThanRhs(String, String), + #[error("Not enough variables provided to fill placeholders in the formatted string")] InsufficientVariables, } diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 0d9421780..e2c6e0dcf 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -827,20 +827,42 @@ fn monomorphize_expr( let ExprMonoInfo { expr: rhs_expr, .. } = rhs_mono; // fold constants - let cst = match (&lhs_expr.kind, &rhs_expr.kind) { - (ExprKind::BigUInt(lhs), ExprKind::BigUInt(rhs)) => match op { - Op2::Addition => Some(lhs + rhs), - Op2::Subtraction => Some(lhs - rhs), - Op2::Multiplication => Some(lhs * rhs), - Op2::Division => Some(lhs / rhs), - _ => None, - }, + let cst = match (&lhs_mono.constant, &rhs_mono.constant) { + (Some(PropagatedConstant::Single(lhs)), Some(PropagatedConstant::Single(rhs))) => { + match op { + Op2::Addition => Some(lhs + rhs), + Op2::Subtraction => { + if lhs < rhs { + // throw error + return Err(error( + ErrorKind::NegativeLhsLessThanRhs( + lhs.to_string(), + rhs.to_string(), + ), + expr.span, + )); + } + Some(lhs - rhs) + } + Op2::Multiplication => Some(lhs * rhs), + Op2::Division => Some(lhs / rhs), + _ => None, + } + } _ => None, }; match cst { Some(v) => { - let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(v.clone())); + let mexpr = expr.to_mast( + ctx, + &ExprKind::BinaryOp { + op: op.clone(), + protected: *protected, + lhs: Box::new(lhs_expr), + rhs: Box::new(rhs_expr), + }, + ); ExprMonoInfo::new(mexpr, typ, Some(PropagatedConstant::from(v))) } diff --git a/src/parser/types.rs b/src/parser/types.rs index 0e7262d60..9a3ee98ef 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -716,6 +716,12 @@ impl FnSig { pub fn monomorphized_name(&self) -> Ident { let mut name = self.name.clone(); + // check if it contains # in the name + if name.value.contains('#') { + // if so, then it is already monomorphized + return name; + } + if self.require_monomorphization() { let mut generics = self.generics.parameters.iter().collect::>(); generics.sort_by(|a, b| a.0.cmp(b.0));