From bad3117bd5601c4deb0f90c2f0ac46b657af50ee Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Thu, 16 Nov 2023 13:49:36 -0800 Subject: [PATCH] relation: traits: Add logic gates to `ConstraintSystem` --- primitives/src/circuit/merkle_tree/mod.rs | 4 +- relation/src/gadgets/arithmetic.rs | 2 +- .../gadgets/ecc/emulated/short_weierstrass.rs | 2 +- relation/src/gadgets/ecc/glv.rs | 6 +- relation/src/gadgets/ecc/mod.rs | 4 +- relation/src/gadgets/emulated.rs | 2 +- relation/src/gadgets/logic.rs | 119 +----------------- relation/src/traits.rs | 114 ++++++++++++++++- 8 files changed, 128 insertions(+), 125 deletions(-) diff --git a/primitives/src/circuit/merkle_tree/mod.rs b/primitives/src/circuit/merkle_tree/mod.rs index b685195c0..3ae93b1f9 100644 --- a/primitives/src/circuit/merkle_tree/mod.rs +++ b/primitives/src/circuit/merkle_tree/mod.rs @@ -212,8 +212,8 @@ fn constrain_sibling_order( node_is_right: BoolVar, ) -> Result<[Variable; 3], CircuitError> { let one = F::one(); - let left_node = circuit.conditional_select(node_is_left, sib1, node)?; - let right_node = circuit.conditional_select(node_is_right, sib2, node)?; + let left_node = circuit.mux(node_is_left, node, sib1)?; + let right_node = circuit.mux(node_is_right, node, sib2)?; let left_plus_right = circuit.add(left_node, right_node)?; let mid_node = circuit.lc( &[node, sib1, sib2, left_plus_right], diff --git a/relation/src/gadgets/arithmetic.rs b/relation/src/gadgets/arithmetic.rs index ad593f2d6..96f33d3e5 100644 --- a/relation/src/gadgets/arithmetic.rs +++ b/relation/src/gadgets/arithmetic.rs @@ -765,7 +765,7 @@ mod test { let bit_true = circuit.create_boolean_variable(true)?; let x_0 = circuit.create_variable(F::from(23u32))?; let x_1 = circuit.create_variable(F::from(24u32))?; - circuit.conditional_select(bit_true, x_0, x_1)?; + circuit.mux(bit_true, x_1, x_0)?; // range gate let b = circuit.create_variable(F::from(1023u32))?; diff --git a/relation/src/gadgets/ecc/emulated/short_weierstrass.rs b/relation/src/gadgets/ecc/emulated/short_weierstrass.rs index ba53620ad..43b834f89 100644 --- a/relation/src/gadgets/ecc/emulated/short_weierstrass.rs +++ b/relation/src/gadgets/ecc/emulated/short_weierstrass.rs @@ -115,7 +115,7 @@ impl PlonkCircuit { ) -> Result, CircuitError> { let select_x = self.conditional_select_emulated(b, &p0.0, &p1.0)?; let select_y = self.conditional_select_emulated(b, &p0.1, &p1.1)?; - let select_infinity = BoolVar(self.conditional_select(b, p0.2 .0, p1.2 .0)?); + let select_infinity = BoolVar(self.mux(b, p1.2 .0, p0.2 .0)?); Ok(EmulatedSWPointVariable::( select_x, diff --git a/relation/src/gadgets/ecc/glv.rs b/relation/src/gadgets/ecc/glv.rs index 5e47bbf35..37e54ddd7 100644 --- a/relation/src/gadgets/ecc/glv.rs +++ b/relation/src/gadgets/ecc/glv.rs @@ -516,8 +516,7 @@ where }; // (f.3) either f.1 or f.2 is satisfied - let sat = - circuit.conditional_select(k2_sign_var, k2_is_neg_sat.into(), k2_is_pos_sat.into())?; + let sat = circuit.mux(k2_sign_var, k2_is_pos_sat.into(), k2_is_neg_sat.into())?; circuit.enforce_true(sat)?; // (g) tmp2 + lambda_2 * k2_sign * k2 + s2 = t * t_sign * r2 @@ -547,8 +546,7 @@ where }; // (g.3) either g.1 or g.2 is satisfied - let sat = - circuit.conditional_select(k2_sign_var, k2_is_neg_sat.into(), k2_is_pos_sat.into())?; + let sat = circuit.mux(k2_sign_var, k2_is_pos_sat.into(), k2_is_neg_sat.into())?; circuit.enforce_true(sat)?; // extract the output diff --git a/relation/src/gadgets/ecc/mod.rs b/relation/src/gadgets/ecc/mod.rs index 19089d6f9..f4fee8ef4 100644 --- a/relation/src/gadgets/ecc/mod.rs +++ b/relation/src/gadgets/ecc/mod.rs @@ -236,8 +236,8 @@ impl PlonkCircuit { self.check_point_var_bound(point0)?; self.check_point_var_bound(point1)?; - let selected_x = self.conditional_select(b, point0.0, point1.0)?; - let selected_y = self.conditional_select(b, point0.1, point1.1)?; + let selected_x = self.mux(b, point1.0, point0.0)?; + let selected_y = self.mux(b, point1.1, point0.1)?; Ok(PointVariable(selected_x, selected_y)) } diff --git a/relation/src/gadgets/emulated.rs b/relation/src/gadgets/emulated.rs index 10685acbc..54cacb5ca 100644 --- a/relation/src/gadgets/emulated.rs +++ b/relation/src/gadgets/emulated.rs @@ -602,7 +602,7 @@ impl PlonkCircuit { let mut vals = vec![]; for (&x_0, &x_1) in p0.0.iter().zip(p1.0.iter()) { - let selected = self.conditional_select(b, x_0, x_1)?; + let selected = self.mux(b, x_1, x_0)?; vals.push(selected); } diff --git a/relation/src/gadgets/logic.rs b/relation/src/gadgets/logic.rs index 3fbbc88bd..ed19cc77f 100644 --- a/relation/src/gadgets/logic.rs +++ b/relation/src/gadgets/logic.rs @@ -6,26 +6,11 @@ //! Logic related circuit implementations -use crate::{ - errors::CircuitError, - gates::{CondSelectGate, LogicOrGate, LogicOrOutputGate}, - traits::*, - BoolVar, PlonkCircuit, Variable, -}; +use crate::{errors::CircuitError, traits::*, BoolVar, PlonkCircuit, Variable}; use ark_ff::PrimeField; -use ark_std::{boxed::Box, string::ToString}; +use ark_std::string::ToString; impl PlonkCircuit { - /// Constrain that `a` is true or `b` is true. - /// Return error if variables are invalid. - pub fn logic_or_gate(&mut self, a: BoolVar, b: BoolVar) -> Result<(), CircuitError> { - self.check_var_bound(a.into())?; - self.check_var_bound(b.into())?; - let wire_vars = &[a.into(), b.into(), 0, 0, 0]; - self.insert_gate(wire_vars, Box::new(LogicOrGate))?; - Ok(()) - } - /// Obtain a bool variable representing whether two input variables are /// equal. Return error if variables are invalid. pub fn is_equal(&mut self, a: Variable, b: Variable) -> Result { @@ -74,96 +59,6 @@ impl PlonkCircuit { let one_var = self.one(); self.mul_gate(var, inv_var, one_var) } - - /// Obtain a variable representing the result of a logic negation gate. - /// Return the index of the variable. Return error if the input variable - /// is invalid. - pub fn logic_neg(&mut self, a: BoolVar) -> Result { - self.is_zero(a.into()) - } - - /// Obtain a variable representing the result of a logic AND gate. Return - /// the index of the variable. Return error if the input variables are - /// invalid. - pub fn logic_and(&mut self, a: BoolVar, b: BoolVar) -> Result { - let c = self - .create_boolean_variable_unchecked(self.witness(a.into())? * self.witness(b.into())?)?; - self.mul_gate(a.into(), b.into(), c.into())?; - Ok(c) - } - - /// Given a list of boolean variables, obtain a variable representing the - /// result of a logic AND gate. Return the index of the variable. Return - /// error if the input variables are invalid. - pub fn logic_and_all(&mut self, vars: &[BoolVar]) -> Result { - if vars.is_empty() { - return Err(CircuitError::ParameterError( - "logic_and_all: empty variable list".to_string(), - )); - } - let mut res = vars[0]; - for &var in vars.iter().skip(1) { - res = self.logic_and(res, var)?; - } - Ok(res) - } - - /// Obtain a variable representing the result of a logic OR gate. Return the - /// index of the variable. Return error if the input variables are - /// invalid. - pub fn logic_or(&mut self, a: BoolVar, b: BoolVar) -> Result { - self.check_var_bound(a.into())?; - self.check_var_bound(b.into())?; - - let a_val = self.witness(a.into())?; - let b_val = self.witness(b.into())?; - let c_val = a_val + b_val - a_val * b_val; - - let c = self.create_boolean_variable_unchecked(c_val)?; - let wire_vars = &[a.into(), b.into(), 0, 0, c.into()]; - self.insert_gate(wire_vars, Box::new(LogicOrOutputGate))?; - - Ok(c) - } - - /// Assuming values represented by `a` is boolean. - /// Constrain `a` is true - pub fn enforce_true(&mut self, a: Variable) -> Result<(), CircuitError> { - self.enforce_constant(a, F::one()) - } - - /// Assuming values represented by `a` is boolean. - /// Constrain `a` is false - pub fn enforce_false(&mut self, a: Variable) -> Result<(), CircuitError> { - self.enforce_constant(a, F::zero()) - } - - /// Obtain a variable that equals `x_0` if `b` is zero, or `x_1` if `b` is - /// one. Return error if variables are invalid. - pub fn conditional_select( - &mut self, - b: BoolVar, - x_0: Variable, - x_1: Variable, - ) -> Result { - self.check_var_bound(b.into())?; - self.check_var_bound(x_0)?; - self.check_var_bound(x_1)?; - - // y = x_bit - let y = if self.witness(b.into())? == F::zero() { - self.create_variable(self.witness(x_0)?)? - } else if self.witness(b.into())? == F::one() { - self.create_variable(self.witness(x_1)?)? - } else { - return Err(CircuitError::ParameterError( - "b in Conditional Selection gate is not a boolean variable".to_string(), - )); - }; - let wire_vars = [b.into(), x_0, b.into(), x_1, y]; - self.insert_gate(&wire_vars, Box::new(CondSelectGate))?; - Ok(y) - } } #[cfg(test)] @@ -389,8 +284,8 @@ mod test { let x_0 = circuit.create_variable(F::from(23u32))?; let x_1 = circuit.create_variable(F::from(24u32))?; - let select_true = circuit.conditional_select(bit_true, x_0, x_1)?; - let select_false = circuit.conditional_select(bit_false, x_0, x_1)?; + let select_true = circuit.mux(bit_true, x_1, x_0)?; + let select_false = circuit.mux(bit_false, x_1, x_0)?; assert_eq!(circuit.witness(select_true)?, circuit.witness(x_1)?); assert_eq!(circuit.witness(select_false)?, circuit.witness(x_0)?); @@ -400,9 +295,7 @@ mod test { *circuit.witness_mut(bit_false.into()) = F::one(); assert!(circuit.check_circuit_satisfiability(&[]).is_err()); // Check variable out of bound error. - assert!(circuit - .conditional_select(bit_false, circuit.num_vars(), x_1) - .is_err()); + assert!(circuit.mux(bit_false, x_1, circuit.num_vars()).is_err()); // build two fixed circuits with different variable assignments, checking that // the arithmetized extended permutation polynomial is variable @@ -422,7 +315,7 @@ mod test { let bit_var = circuit.create_boolean_variable(bit)?; let x_0_var = circuit.create_variable(x_0)?; let x_1_var = circuit.create_variable(x_1)?; - circuit.conditional_select(bit_var, x_0_var, x_1_var)?; + circuit.mux(bit_var, x_1_var, x_0_var)?; circuit.finalize_for_arithmetization()?; Ok(circuit) } diff --git a/relation/src/traits.rs b/relation/src/traits.rs index bf2ec9b82..8c346034e 100644 --- a/relation/src/traits.rs +++ b/relation/src/traits.rs @@ -14,7 +14,8 @@ use crate::{ errors::CircuitError, gates::{ AdditionGate, BoolGate, ConstantAdditionGate, ConstantGate, ConstantMultiplicationGate, - EqualityGate, Gate, LinCombGate, MulAddGate, MultiplicationGate, MuxGate, SubtractionGate, + EqualityGate, Gate, LinCombGate, LogicOrGate, LogicOrOutputGate, MulAddGate, + MultiplicationGate, MuxGate, SubtractionGate, }, next_multiple, BoolVar, SortedLookupVecAndPolys, Variable, }; @@ -509,6 +510,117 @@ pub trait Circuit { // | Logic Gates | // --------------- + /// Constrain that `a` is true or `b` is true. + /// Return error if variables are invalid. + fn logic_or_gate(&mut self, a: BoolVar, b: BoolVar) -> Result<(), CircuitError> { + self.check_var(a.into())?; + self.check_var(b.into())?; + + let wire_vars = &[a.into(), b.into(), 0, 0, 0]; + self.insert_gate(wire_vars, Box::new(LogicOrGate))?; + + Ok(()) + } + + /// Obtain a variable representing the result of a logic OR gate. Return the + /// index of the variable. Return error if the input variables are + /// invalid. + fn logic_or(&mut self, a: BoolVar, b: BoolVar) -> Result { + self.check_var(a.into())?; + self.check_var(b.into())?; + + let a_val = self.witness(a.into())?; + let b_val = self.witness(b.into())?; + let c_val = a_val.clone() + b_val.clone() - a_val * b_val; + + let c = self.create_variable(c_val)?; + let wire_vars = &[a.into(), b.into(), 0, 0, c]; + self.insert_gate(wire_vars, Box::new(LogicOrOutputGate))?; + + // We do not need to constrain the output to be boolean as the inputs already + // are, and the range of this gate is {0, 1} + Ok(BoolVar(c)) + } + + /// Given a list of boolean variables, obtain a variable representing the + /// result of a logic OR gate. Return the index of the variable. Return + /// error if the input variables are invalid. + fn logic_or_all(&mut self, vars: &[BoolVar]) -> Result { + if vars.is_empty() { + return Err(CircuitError::ParameterError( + "logic_or_all: empty variable list".to_string(), + )); + } + + let mut res = vars[0]; + for var in vars.iter().skip(1) { + res = self.logic_or(res, *var)?; + } + + Ok(res) + } + + /// Constrain that `a` is true and `b` is true + fn logic_and_gate(&mut self, a: BoolVar, b: BoolVar) -> Result<(), CircuitError> { + self.mul_gate(a.into(), b.into(), self.one()) + } + + /// Obtain a variable representing the result of a logic AND gate. Return + /// the index of the variable. Return error if the input variables are + /// invalid. + fn logic_and(&mut self, a: BoolVar, b: BoolVar) -> Result { + let c = self.mul(a.into(), b.into())?; + + // We do not need to constrain the output to be boolean as the inputs already + // are, and the range of this gate is {0, 1} + Ok(BoolVar(c)) + } + + /// Given a list of boolean variables, obtain a variable representing the + /// result of a logic AND gate. Return the index of the variable. Return + /// error if the input variables are invalid. + fn logic_and_all(&mut self, vars: &[BoolVar]) -> Result { + if vars.is_empty() { + return Err(CircuitError::ParameterError( + "logic_and_all: empty variable list".to_string(), + )); + } + + let mut res = vars[0]; + for &var in vars.iter().skip(1) { + res = self.logic_and(res, var)?; + } + + Ok(res) + } + + /// Obtain a variable representing the result of a logic negation gate. + /// Return the index of the variable. Return error if the input variable + /// is invalid. + fn logic_neg(&mut self, a: BoolVar) -> Result { + self.check_var(a.into())?; + + // 1 - a + let one = self.one(); + let res = self.add_with_coeffs(one, a.into(), &F::one(), &-F::one())?; + + // We do not need to constrain the output to be boolean as the inputs already + // are, and the range of this gate is {0, 1} + Ok(BoolVar(res)) + } + + /// Assuming values represented by `a` is boolean + /// Constrain `a` is true + fn enforce_true(&mut self, a: Variable) -> Result<(), CircuitError> { + self.enforce_constant(a, F::one()) + } + + /// Assuming values represented by `a` is boolean + /// Constrain `a` is false + fn enforce_false(&mut self, a: Variable) -> Result<(), CircuitError> { + self.enforce_constant(a, F::zero()) + } + /// Constrains variable `c` to be `if sel { a } else { b }` fn mux_gate( &mut self,