diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index ac5ca3d01..8d87fe537 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -12,10 +12,7 @@ use crate::{ var::{ConstOrCell, Value, Var}, }; -use super::{FnInfoType, Module}; - -const NTH_BIT_FN: &str = "nth_bit(val: Field, const nth: Field) -> Field"; -const CHECK_FIELD_SIZE_FN: &str = "check_field_size(cmp: Field)"; +use super::{builtins::Builtin, FnInfoType, Module}; pub struct BitsLib {} @@ -24,81 +21,95 @@ impl Module for BitsLib { fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { vec![ - (NTH_BIT_FN, nth_bit, false), - (CHECK_FIELD_SIZE_FN, check_field_size, false), + (NthBitFn::SIGNATURE, NthBitFn::builtin, false), + ( + CheckFieldSizeFn::SIGNATURE, + CheckFieldSizeFn::builtin, + false, + ), ] } } -fn nth_bit( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // should be two input vars - assert_eq!(vars.len(), 2); - - // these should be type checked already, unless it is called by other low level functions - // eg. builtins - let var_info = &vars[0]; - let val = &var_info.var; - assert_eq!(val.len(), 1); - - let var_info = &vars[1]; - let nth = &var_info.var; - assert_eq!(nth.len(), 1); - - let nth: usize = match &nth[0] { - ConstOrCell::Cell(_) => unreachable!("nth should be a constant"), - ConstOrCell::Const(cst) => cst.to_u64() as usize, - }; - - let val = match &val[0] { - ConstOrCell::Cell(cvar) => cvar.clone(), - ConstOrCell::Const(cst) => { - // directly return the nth bit without adding symbolic value as it doesn't depend on a cell var - let bit = cst.to_bits(); - return Ok(Some(Var::new_cvar( - ConstOrCell::Const(B::Field::from(bit[nth])), - span, - ))); - } - }; +struct NthBitFn {} +struct CheckFieldSizeFn {} + +impl Builtin for NthBitFn { + const SIGNATURE: &'static str = "nth_bit(val: Field, const nth: Field) -> Field"; + + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + // should be two input vars + assert_eq!(vars.len(), 2); + + // these should be type checked already, unless it is called by other low level functions + // eg. builtins + let var_info = &vars[0]; + let val = &var_info.var; + assert_eq!(val.len(), 1); + + let var_info = &vars[1]; + let nth = &var_info.var; + assert_eq!(nth.len(), 1); + + let nth: usize = match &nth[0] { + ConstOrCell::Cell(_) => unreachable!("nth should be a constant"), + ConstOrCell::Const(cst) => cst.to_u64() as usize, + }; + + let val = match &val[0] { + ConstOrCell::Cell(cvar) => cvar.clone(), + ConstOrCell::Const(cst) => { + // directly return the nth bit without adding symbolic value as it doesn't depend on a cell var + let bit = cst.to_bits(); + return Ok(Some(Var::new_cvar( + ConstOrCell::Const(B::Field::from(bit[nth])), + span, + ))); + } + }; - let bit = compiler - .backend - .new_internal_var(Value::NthBit(val.clone(), nth), span); + let bit = compiler + .backend + .new_internal_var(Value::NthBit(val.clone(), nth), span); - Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) + Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) + } } -// Ensure that the field size is not exceeded -fn check_field_size( - _compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - let var = &vars[0].var[0]; - let bit_len = B::Field::MODULUS_BIT_SIZE as u64; - - match var { - ConstOrCell::Const(cst) => { - let to_cmp = cst.to_u64(); - if to_cmp >= bit_len { - return Err(Error::new( - "constraint-generation", - ErrorKind::AssertionFailed, - span, - )); +impl Builtin for CheckFieldSizeFn { + const SIGNATURE: &'static str = "check_field_size(cmp: Field)"; + + fn builtin( + _compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + let var = &vars[0].var[0]; + let bit_len = B::Field::MODULUS_BIT_SIZE as u64; + + match var { + ConstOrCell::Const(cst) => { + let to_cmp = cst.to_u64(); + if to_cmp >= bit_len { + return Err(Error::new( + "constraint-generation", + ErrorKind::AssertionFailed, + span, + )); + } + Ok(None) } - Ok(None) + ConstOrCell::Cell(_) => Err(Error::new( + "constraint-generation", + ErrorKind::ExpectedConstant, + span, + )), } - ConstOrCell::Cell(_) => Err(Error::new( - "constraint-generation", - ErrorKind::ExpectedConstant, - span, - )), } } diff --git a/src/stdlib/builtins.rs b/src/stdlib/builtins.rs index c487237ac..701d0260d 100644 --- a/src/stdlib/builtins.rs +++ b/src/stdlib/builtins.rs @@ -23,9 +23,6 @@ use super::{FnInfoType, Module}; pub const QUALIFIED_BUILTINS: &str = "std/builtins"; pub const BUILTIN_FN_NAMES: [&str; 3] = ["assert", "assert_eq", "log"]; -const ASSERT_FN: &str = "assert(condition: Bool)"; -const ASSERT_EQ_FN: &str = "assert_eq(lhs: Field, rhs: Field)"; -const LOG_FN: &str = "log(var: Field)"; pub struct BuiltinsLib {} impl Module for BuiltinsLib { @@ -33,10 +30,10 @@ impl Module for BuiltinsLib { fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { vec![ - (ASSERT_FN, assert_fn, false), - (ASSERT_EQ_FN, assert_eq_fn, true), + (AssertFn::SIGNATURE, AssertFn::builtin, false), + (AssertEqFn::SIGNATURE, AssertEqFn::builtin, true), // true -> skip argument type checking for log - (LOG_FN, log_fn, true), + (LogFn::SIGNATURE, LogFn::builtin, true), ] } } @@ -63,7 +60,7 @@ fn assert_eq_values( match typ { // Field and Bool has the same logic - TyKind::Field { .. } | TyKind::Bool => { + TyKind::Field { .. } | TyKind::Bool | TyKind::String(..) => { let lhs_var = &lhs_info.var[0]; let rhs_var = &rhs_info.var[0]; match (lhs_var, rhs_var) { @@ -146,114 +143,143 @@ fn assert_eq_values( comparisons } -/// Asserts that two vars are equal. -fn assert_eq_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // we get two vars - assert_eq!(vars.len(), 2); - let lhs_info = &vars[0]; - let rhs_info = &vars[1]; - - // get types of both arguments - let lhs_type = lhs_info.typ.as_ref().ok_or_else(|| { - Error::new( - "constraint-generation", - ErrorKind::UnexpectedError("No type info for lhs of assertion"), - span, - ) - })?; - - let rhs_type = rhs_info.typ.as_ref().ok_or_else(|| { - Error::new( - "constraint-generation", - ErrorKind::UnexpectedError("No type info for rhs of assertion"), - span, - ) - })?; - - // they have the same type - if !lhs_type.match_expected(rhs_type, false) { - return Err(Error::new( - "constraint-generation", - ErrorKind::AssertEqTypeMismatch(lhs_type.clone(), rhs_type.clone()), - span, - )); - } +pub trait Builtin { + const SIGNATURE: &'static str; - // first collect all comparisons needed - let comparisons = assert_eq_values(compiler, lhs_info, rhs_info, lhs_type, span); + fn builtin( + compiler: &mut CircuitWriter, + generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>>; +} - // then add all the constraints - for comparison in comparisons { - match comparison { - Comparison::Vars(lhs, rhs) => { - compiler.backend.assert_eq_var(&lhs, &rhs, span); - } - Comparison::VarConst(var, constant) => { - compiler.backend.assert_eq_const(&var, constant, span); - } - Comparison::Constants(a, b) => { - if a != b { - return Err(Error::new( - "constraint-generation", - ErrorKind::AssertionFailed, - span, - )); +struct AssertEqFn {} +struct AssertFn {} +struct LogFn {} + +impl Builtin for AssertEqFn { + const SIGNATURE: &'static str = "assert_eq(lhs: Field, rhs: Field)"; + + /// Asserts that two vars are equal. + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + // we get two vars + assert_eq!(vars.len(), 2); + let lhs_info = &vars[0]; + let rhs_info = &vars[1]; + + // get types of both arguments + let lhs_type = lhs_info.typ.as_ref().ok_or_else(|| { + Error::new( + "constraint-generation", + ErrorKind::UnexpectedError("No type info for lhs of assertion"), + span, + ) + })?; + + let rhs_type = rhs_info.typ.as_ref().ok_or_else(|| { + Error::new( + "constraint-generation", + ErrorKind::UnexpectedError("No type info for rhs of assertion"), + span, + ) + })?; + + // they have the same type + if !lhs_type.match_expected(rhs_type, false) { + return Err(Error::new( + "constraint-generation", + ErrorKind::AssertEqTypeMismatch(lhs_type.clone(), rhs_type.clone()), + span, + )); + } + + // first collect all comparisons needed + let comparisons = assert_eq_values(compiler, lhs_info, rhs_info, lhs_type, span); + + // then add all the constraints + for comparison in comparisons { + match comparison { + Comparison::Vars(lhs, rhs) => { + compiler.backend.assert_eq_var(&lhs, &rhs, span); + } + Comparison::VarConst(var, constant) => { + compiler.backend.assert_eq_const(&var, constant, span); + } + Comparison::Constants(a, b) => { + if a != b { + return Err(Error::new( + "constraint-generation", + ErrorKind::AssertionFailed, + span, + )); + } } } } - } - Ok(None) + Ok(None) + } } -/// Asserts that a condition is true. -fn assert_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // we get a single var - assert_eq!(vars.len(), 1); - - // of type bool - let var_info = &vars[0]; - assert!(matches!(var_info.typ, Some(TyKind::Bool))); - - // of only one field element - let var = &var_info.var; - assert_eq!(var.len(), 1); - let cond = &var[0]; - - match cond { - ConstOrCell::Const(cst) => { - assert!(cst.is_one()); - } - ConstOrCell::Cell(cvar) => { - let one = B::Field::one(); - compiler.backend.assert_eq_const(cvar, one, span); +impl Builtin for AssertFn { + const SIGNATURE: &'static str = "assert(condition: Bool)"; + + /// Asserts that a condition is true. + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result::Field, ::Var>>> { + // we get a single var + assert_eq!(vars.len(), 1); + + // of type bool + let var_info = &vars[0]; + assert!(matches!(var_info.typ, Some(TyKind::Bool))); + + // of only one field element + let var = &var_info.var; + assert_eq!(var.len(), 1); + let cond = &var[0]; + + match cond { + ConstOrCell::Const(cst) => { + assert!(cst.is_one()); + } + ConstOrCell::Cell(cvar) => { + let one = B::Field::one(); + compiler.backend.assert_eq_const(cvar, one, span); + } } - } - Ok(None) + Ok(None) + } } -/// Logging -fn log_fn( - compiler: &mut CircuitWriter, - generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - for var in vars { - // todo: will need to support string argument in order to customize msg - compiler.backend.log_var(var, span); - } +impl Builtin for LogFn { + // todo: currently only supports a single field var + // to support all the types, we can bypass the type check for this log function for now + const SIGNATURE: &'static str = "log(var: Field)"; + + /// Logging + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + for var in vars { + // todo: will need to support string argument in order to customize msg + compiler.backend.log_var(var, span); + } - Ok(None) + Ok(None) + } } diff --git a/src/stdlib/crypto.rs b/src/stdlib/crypto.rs index 66113cddd..13ff91a86 100644 --- a/src/stdlib/crypto.rs +++ b/src/stdlib/crypto.rs @@ -1,14 +1,27 @@ -use super::{FnInfoType, Module}; +use super::{builtins::Builtin, FnInfoType, Module}; use crate::backends::Backend; -const POSEIDON_FN: &str = "poseidon(input: [Field; 2]) -> [Field; 3]"; - pub struct CryptoLib {} impl Module for CryptoLib { const MODULE: &'static str = "crypto"; fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { - vec![(POSEIDON_FN, B::poseidon(), false)] + vec![(PoseidonFn::SIGNATURE, PoseidonFn::builtin, false)] + } +} + +struct PoseidonFn {} + +impl Builtin for PoseidonFn { + const SIGNATURE: &'static str = "poseidon(input: [Field; 2]) -> [Field; 3]"; + + fn builtin( + compiler: &mut crate::circuit_writer::CircuitWriter, + generics: &crate::parser::types::GenericParameters, + vars: &[crate::circuit_writer::VarInfo], + span: crate::constants::Span, + ) -> crate::error::Result>> { + B::poseidon()(compiler, generics, vars, span) } } diff --git a/src/stdlib/int.rs b/src/stdlib/int.rs index 03c574890..94f40ff89 100644 --- a/src/stdlib/int.rs +++ b/src/stdlib/int.rs @@ -11,9 +11,7 @@ use crate::{ var::{ConstOrCell, Value, Var}, }; -use super::{FnInfoType, Module}; - -const DIVMOD_FN: &str = "divmod(dividend: Field, divisor: Field) -> [Field; 2]"; +use super::{builtins::Builtin, FnInfoType, Module}; pub struct IntLib {} @@ -21,57 +19,63 @@ impl Module for IntLib { const MODULE: &'static str = "int"; fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { - vec![(DIVMOD_FN, divmod_fn, false)] + vec![(DivmodFn::SIGNATURE, DivmodFn::builtin, false)] } } /// Divides two field elements and returns the quotient and remainder. -fn divmod_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // we get two vars - let dividend_info = &vars[0]; - let divisor_info = &vars[1]; - - // retrieve the values - let dividend_var = ÷nd_info.var[0]; - let divisor_var = &divisor_info.var[0]; - - match (dividend_var, divisor_var) { - // two constants - (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { - // convert to bigints - let a = a.to_biguint(); - let b = b.to_biguint(); - - let quotient = a.clone() / b.clone(); - let remainder = a % b; - - // convert back to fields - let quotient = B::Field::from_biguint("ient).unwrap(); - let remainder = B::Field::from_biguint(&remainder).unwrap(); - - Ok(Some(Var::new( - vec![ConstOrCell::Const(quotient), ConstOrCell::Const(remainder)], - span, - ))) - } +struct DivmodFn {} + +impl Builtin for DivmodFn { + const SIGNATURE: &'static str = "divmod(dividend: Field, divisor: Field) -> [Field; 2]"; + + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + // we get two vars + let dividend_info = &vars[0]; + let divisor_info = &vars[1]; + + // retrieve the values + let dividend_var = ÷nd_info.var[0]; + let divisor_var = &divisor_info.var[0]; + + match (dividend_var, divisor_var) { + // two constants + (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { + // convert to bigints + let a = a.to_biguint(); + let b = b.to_biguint(); + + let quotient = a.clone() / b.clone(); + let remainder = a % b; + + // convert back to fields + let quotient = B::Field::from_biguint("ient).unwrap(); + let remainder = B::Field::from_biguint(&remainder).unwrap(); + + Ok(Some(Var::new( + vec![ConstOrCell::Const(quotient), ConstOrCell::Const(remainder)], + span, + ))) + } + + _ => { + let quotient = compiler + .backend + .new_internal_var(Value::Div(dividend_var.clone(), divisor_var.clone()), span); + let remainder = compiler + .backend + .new_internal_var(Value::Mod(dividend_var.clone(), divisor_var.clone()), span); - _ => { - let quotient = compiler - .backend - .new_internal_var(Value::Div(dividend_var.clone(), divisor_var.clone()), span); - let remainder = compiler - .backend - .new_internal_var(Value::Mod(dividend_var.clone(), divisor_var.clone()), span); - - Ok(Some(Var::new( - vec![ConstOrCell::Cell(quotient), ConstOrCell::Cell(remainder)], - span, - ))) + Ok(Some(Var::new( + vec![ConstOrCell::Cell(quotient), ConstOrCell::Cell(remainder)], + span, + ))) + } } } }