diff --git a/plonk/src/multiprover/proof_system/constraint_system.rs b/plonk/src/multiprover/proof_system/constraint_system.rs index 8b8930c8b..c36dbdcd2 100644 --- a/plonk/src/multiprover/proof_system/constraint_system.rs +++ b/plonk/src/multiprover/proof_system/constraint_system.rs @@ -130,11 +130,13 @@ where // do so by adding proof-linking gates at specific indices in the circuit's arithmetization // of the form a(x) * 0 = 0, where a(x) encodes the witness to be linked. We can then // invoke a polynomial subprotocol to prove that the a(x) polynomial is the same between - // proofs. The following fields are used to track membership in groups and placement of groups - // in the arithmetization + // proofs. The following fields are used to track membership in groups and placement of + // groups in the arithmetization /// The proof-linking group layouts for the circuit. Maps a group ID to the /// indices of the witness contained in the group - link_groups: HashMap<&'static str, Vec>, + link_groups: HashMap>, + /// The offsets at which to place the link groups in the arithmetization + link_group_offsets: HashMap, /// The underlying MPC fabric that this circuit is allocated within fabric: MpcFabric, @@ -164,6 +166,7 @@ where // This is later updated eval_domain: Radix2EvaluationDomain::new(1 /* num_coeffs */).unwrap(), link_groups: HashMap::new(), + link_group_offsets: HashMap::new(), fabric, }; @@ -210,7 +213,7 @@ where for link_group in link_groups { self.link_groups - .get_mut(link_group.id) + .get_mut(&link_group.id) .ok_or(CircuitError::LinkGroupNotFound(link_group.id.to_string()))? .push(var); } @@ -307,9 +310,8 @@ where pub_input: &AuthenticatedScalarResult, ) -> ScalarResult { // Compute wire values - let w_vals = (0..=GATE_WIDTH) - .map(|i| &self.witness[self.wire_variables[i][gate_id]]) - .collect_vec(); + let w_vals = + (0..=GATE_WIDTH).map(|i| &self.witness[self.wire_variables[i][gate_id]]).collect_vec(); // Compute selector values macro_rules! as_scalars { @@ -387,11 +389,8 @@ where // Compute the mapping from variables to wires. let mut variable_wires_map = vec![vec![]; m]; - for (gate_wire_id, variables) in self - .wire_variables - .iter() - .take(self.num_wire_types()) - .enumerate() + for (gate_wire_id, variables) in + self.wire_variables.iter().take(self.num_wire_types()).enumerate() { for (gate_id, &var) in variables.iter().enumerate() { variable_wires_map[var].push((gate_wire_id, gate_id)); @@ -574,10 +573,7 @@ where ) -> Result<(), CircuitError> { let n = public_input.len(); if n != self.num_inputs() { - return Err(CircuitError::PubInputLenMismatch( - n, - self.pub_input_gate_ids.len(), - )); + return Err(CircuitError::PubInputLenMismatch(n, self.pub_input_gate_ids.len())); } let mut gate_results = Vec::new(); @@ -605,10 +601,7 @@ where if res == Scalar::zero() { Ok(()) } else { - Err(CircuitError::GateCheckFailure( - idx, - "gate check failed".to_string(), - )) + Err(CircuitError::GateCheckFailure(idx, "gate check failed".to_string())) } }) .collect::, CircuitError>>() @@ -641,9 +634,8 @@ where link_groups: &[LinkGroup], ) -> Result { let authenticated_val = self.fabric.one_authenticated() * Scalar::new(val); - let var = self.create_variable(authenticated_val)?; + let var = self.create_variable_with_link_groups(authenticated_val, link_groups)?; self.enforce_constant(var, val)?; - self.add_to_link_groups(var, link_groups)?; Ok(var) } @@ -673,9 +665,12 @@ where Ok(()) } - // TODO: properly handle offsets - fn create_link_group(&mut self, id: &'static str, _offset: Option) -> LinkGroup { - self.link_groups.insert(id, Vec::new()); + fn create_link_group(&mut self, id: String, offset: Option) -> LinkGroup { + self.link_groups.insert(id.clone(), Vec::new()); + if let Some(offset) = offset { + self.link_group_offsets.insert(id.clone(), offset); + } + LinkGroup { id } } @@ -686,9 +681,8 @@ where ) -> Result<(), CircuitError> { self.check_finalize_flag(false)?; - for (wire_var, wire_variable) in wire_vars - .iter() - .zip(self.wire_variables.iter_mut().take(GATE_WIDTH + 1)) + for (wire_var, wire_variable) in + wire_vars.iter().zip(self.wire_variables.iter_mut().take(GATE_WIDTH + 1)) { wire_variable.push(*wire_var) } @@ -855,9 +849,8 @@ where let mut numerator_terms = Vec::with_capacity(self.num_wire_types()); let mut denominator_terms = Vec::with_capacity(self.num_wire_types()); for i in 0..self.num_wire_types() { - let wire_values = (0..(n - 1)) - .map(|j| self.witness[self.wire_variable(i, j)].clone()) - .collect_vec(); + let wire_values = + (0..(n - 1)).map(|j| self.witness[self.wire_variable(i, j)].clone()).collect_vec(); let id_perm_values = (0..(n - 1)) .map(|j| Scalar::new(self.extended_id_permutation[i * n + j])) .collect_vec(); @@ -914,10 +907,8 @@ where .iter() .take(self.num_wire_types()) .map(|wire_vars| { - let wire_vec: Vec> = wire_vars - .iter() - .map(|&var| witness[var].clone()) - .collect_vec(); + let wire_vec: Vec> = + wire_vars.iter().map(|&var| witness[var].clone()).collect_vec(); let coeffs = AuthenticatedScalarResult::ifft::< Radix2EvaluationDomain, diff --git a/plonk/src/multiprover/proof_system/prover.rs b/plonk/src/multiprover/proof_system/prover.rs index dca4ab56c..0d11d8001 100644 --- a/plonk/src/multiprover/proof_system/prover.rs +++ b/plonk/src/multiprover/proof_system/prover.rs @@ -36,10 +36,8 @@ use super::{MpcArithmetization, MpcOracles}; /// A type alias for a bundle of commitments and polynomials /// TODO: Remove this lint allowance #[allow(type_alias_bounds)] -type MpcCommitmentsAndPolys = ( - Vec>, - Vec>, -); +type MpcCommitmentsAndPolys = + (Vec>, Vec>); /// A Plonk IOP prover over a secret shared algebra /// TODO: Remove this lint allowance @@ -85,11 +83,7 @@ impl MpcProver { ) .ok_or(PlonkError::DomainCreationError)?; - Ok(Self { - domain, - quot_domain, - fabric, - }) + Ok(Self { domain, quot_domain, fabric }) } /// Round 1: @@ -172,11 +166,8 @@ impl MpcProver { online_oracles: &MpcOracles, num_wire_types: usize, ) -> MpcProofEvaluations { - let wires_evals: Vec> = online_oracles - .wire_polys - .iter() - .map(|poly| poly.eval(&challenges.zeta)) - .collect(); + let wires_evals: Vec> = + online_oracles.wire_polys.iter().map(|poly| poly.eval(&challenges.zeta)).collect(); let wire_sigma_evals: Vec> = pk .sigmas @@ -370,23 +361,14 @@ impl MpcProver { .collect(); // The coset we use to compute the quotient polynomial - let coset = self - .quot_domain - .get_coset(E::ScalarField::GENERATOR) - .unwrap(); + let coset = self.quot_domain.get_coset(E::ScalarField::GENERATOR).unwrap(); // Compute evaluations of the selectors, permutations, and wiring polynomials - let selectors_coset_fft: Vec> = pk - .selectors - .iter() - .map(|poly| coset.fft(poly.coeffs())) - .collect(); + let selectors_coset_fft: Vec> = + pk.selectors.iter().map(|poly| coset.fft(poly.coeffs())).collect(); - let sigmas_coset_fft: Vec> = pk - .sigmas - .iter() - .map(|poly| coset.fft(poly.coeffs())) - .collect(); + let sigmas_coset_fft: Vec> = + pk.sigmas.iter().map(|poly| coset.fft(poly.coeffs())).collect(); let wire_polys_coset_fft: Vec>> = online_oracles .wire_polys @@ -504,10 +486,7 @@ impl MpcProver { let mut nonzero_wires: Vec>> = vec![vec![]; m]; for (i, selector) in selectors.iter().enumerate() { if !selector.is_zero() { - wires - .iter() - .enumerate() - .for_each(|(j, w)| nonzero_wires[j].push(w[i].clone())); + wires.iter().enumerate().for_each(|(j, w)| nonzero_wires[j].push(w[i].clone())); nonzero_sel.push(Scalar::new(*selector)); nonzero_indices.push(i); } @@ -593,10 +572,7 @@ impl MpcProver { prod_perm_poly_coset_evals: &[AuthenticatedScalarResult], challenges: &MpcChallenges, sigmas_coset_fft: &[Vec], - ) -> ( - Vec>, - Vec>, - ) { + ) -> (Vec>, Vec>) { let n = pk.domain_size(); let m = self.quot_domain.size(); @@ -607,9 +583,8 @@ impl MpcProver { // Construct the evaluations shifted by the coset generators let coset_generators = pk.k().iter().copied().collect_vec(); - let eval_points = (0..m) - .map(|i| self.quot_domain.element(i) * E::ScalarField::GENERATOR) - .collect_vec(); + let eval_points = + (0..m).map(|i| self.quot_domain.element(i) * E::ScalarField::GENERATOR).collect_vec(); let mut all_evals = Vec::with_capacity(num_wire_types); for generator in coset_generators.iter() { @@ -720,11 +695,8 @@ impl MpcProver { // chunks of degree n + 1 contiguous coefficients let mut split_quot_polys: Vec> = (0..num_wire_types) .map(|i| { - let end = if i < num_wire_types - 1 { - (i + 1) * (n + 2) - } else { - quot_poly.degree() + 1 - }; + let end = + if i < num_wire_types - 1 { (i + 1) * (n + 2) } else { quot_poly.degree() + 1 }; // Degree-(n+1) polynomial has n + 2 coefficients. AuthenticatedDensePoly::from_coeffs(quot_poly.coeffs[i * (n + 2)..end].to_vec()) @@ -738,22 +710,17 @@ impl MpcProver { // with t_lowest_i(X) = t_lowest_i(X) - 0 + b_now_i * X^(n+2) // and t_highest_i(X) = t_highest_i(X) - b_last_i let mut last_randomizer = self.fabric.zero_authenticated(); - let mut randomizers = self - .fabric - .random_shared_scalars_authenticated(num_wire_types - 1); + let mut randomizers = self.fabric.random_shared_scalars_authenticated(num_wire_types - 1); - split_quot_polys - .iter_mut() - .take(num_wire_types - 1) - .for_each(|poly| { - poly.coeffs[0] = &poly.coeffs[0] - &last_randomizer; - assert_eq!(poly.degree(), n + 1); + split_quot_polys.iter_mut().take(num_wire_types - 1).for_each(|poly| { + poly.coeffs[0] = &poly.coeffs[0] - &last_randomizer; + assert_eq!(poly.degree(), n + 1); - let next_randomizer = randomizers.pop().unwrap(); - poly.coeffs.push(next_randomizer.clone()); + let next_randomizer = randomizers.pop().unwrap(); + poly.coeffs.push(next_randomizer.clone()); - last_randomizer = next_randomizer; - }); + last_randomizer = next_randomizer; + }); // Mask the highest splitting poly split_quot_polys[num_wire_types - 1].coeffs[0] = @@ -894,9 +861,7 @@ pub fn element_wise_product( // If we choose to view the vectors as tiling the columns of matrices, each step // in this fold replaces the first and second columns with their // element-wise product - vectors[2..].iter().fold(initial, |acc, vec| { - AuthenticatedScalarResult::batch_mul(&acc, vec) - }) + vectors[2..].iter().fold(initial, |acc, vec| AuthenticatedScalarResult::batch_mul(&acc, vec)) } /// Take the element-wise sum of a set of vectors @@ -918,9 +883,7 @@ fn element_wise_sum( // If we choose to view the vectors as tiling the columns of matrices, each step // in this fold replaces the first and second columns with their // element-wise sum - vectors[2..].iter().fold(initial, |acc, vec| { - AuthenticatedScalarResult::batch_add(&acc, vec) - }) + vectors[2..].iter().fold(initial, |acc, vec| AuthenticatedScalarResult::batch_add(&acc, vec)) } /// Evaluate a public polynomial on a result in the MPC fabric @@ -1055,10 +1018,8 @@ pub(crate) mod test { return AuthenticatedDensePoly::from_coeffs(vec![fabric.zero_authenticated()]); } - let coeffs = fabric.batch_share_scalar( - poly.coeffs.iter().cloned().map(Scalar::new).collect(), - PARTY0, - ); + let coeffs = fabric + .batch_share_scalar(poly.coeffs.iter().cloned().map(Scalar::new).collect(), PARTY0); AuthenticatedDensePoly::from_coeffs(coeffs) } @@ -1280,7 +1241,7 @@ pub(crate) mod test { /// Run the fifth round of a single-prover circuit /// /// Returns the commitments to the opening and shifted opening polynomials - fn run_fifth_round(params: &mut TestParams) -> (Commitment, Commitment) { + fn run_fifth_round(params: &TestParams) -> (Commitment, Commitment) { params .prover .compute_opening_proofs( @@ -1330,10 +1291,7 @@ pub(crate) mod test { .collect::>(); let pub_poly = pub_poly.open_authenticated(); - ( - (wire_comms_open.await, wire_polys_open.await), - pub_poly.await.unwrap(), - ) + ((wire_comms_open.await, wire_polys_open.await), pub_poly.await.unwrap()) } }) .await; @@ -1372,10 +1330,7 @@ pub(crate) mod test { .unwrap(); // Open the results - ( - perm_commit.open_authenticated().await.unwrap(), - perm_poly.open().await, - ) + (perm_commit.open_authenticated().await.unwrap(), perm_poly.open().await) } }) .await; @@ -1508,7 +1463,7 @@ pub(crate) mod test { run_second_round(true /* mask */, &mut params); run_third_round(false /* mask */, &mut params); run_fourth_round(&mut params); - let (expected_open, expected_shift) = run_fifth_round(&mut params); + let (expected_open, expected_shift) = run_fifth_round(¶ms); // Compute the result in an MPC let ((open, shifted_open), _) = execute_mock_mpc(|fabric| { diff --git a/relation/src/constraint_system.rs b/relation/src/constraint_system.rs index 94546ab1f..5cfbba402 100644 --- a/relation/src/constraint_system.rs +++ b/relation/src/constraint_system.rs @@ -18,8 +18,9 @@ use ark_poly::{ domain::Radix2EvaluationDomain, univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, }; -use ark_std::{boxed::Box, cmp::max, format, string::ToString, vec, vec::Vec}; +use ark_std::{boxed::Box, cmp::max, format, iterable::Iterable, string::ToString, vec, vec::Vec}; use hashbrown::{HashMap, HashSet}; +use itertools::Itertools; use jf_utils::par_utils::parallelizable_slice_iter; #[cfg(feature = "parallel")] use rayon::prelude::*; @@ -72,6 +73,17 @@ pub enum MergeableCircuitType { TypeB, } +/// The layout of a circuit +#[derive(Clone, Debug)] +pub struct CircuitLayout { + /// The number of public inputs to the circuit + pub n_inputs: usize, + /// The number of gates in the circuit + pub n_gates: usize, + /// The offsets of the proof linking groups in the circuit + pub link_group_offsets: HashMap, +} + /// The wire type identifier for range gates. const RANGE_WIRE_ID: usize = 5; /// The wire type identifier for the key index in a lookup gate @@ -96,21 +108,13 @@ struct PlonkParams { impl PlonkParams { fn init(plonk_type: PlonkType, range_bit_len: Option) -> Result { if plonk_type == PlonkType::TurboPlonk { - return Ok(Self { - plonk_type, - range_bit_len: None, - }); + return Ok(Self { plonk_type, range_bit_len: None }); } if range_bit_len.is_none() { - return Err(ParameterError( - "range bit len cannot be none for UltraPlonk".to_string(), - )); + return Err(ParameterError("range bit len cannot be none for UltraPlonk".to_string())); } - Ok(Self { - plonk_type, - range_bit_len, - }) + Ok(Self { plonk_type, range_bit_len }) } } @@ -166,11 +170,13 @@ where // do so by adding proof-linking gates at specific indices in the circuit's arithmetization // of the form a(x) * 0 = 0, where a(x) encodes the witness to be linked. We can then // invoke a polynomial subprotocol to prove that the a(x) polynomial is the same between - // proofs. The following fields are used to track membership in groups and placement of groups - // in the arithmetization + // proofs on some subdomain. The following fields are used to track membership in groups and + // placement of groups in the arithmetization /// The proof-linking group layouts for the circuit. Maps a group ID to the /// indices of the witness contained in the group - link_groups: HashMap<&'static str, Vec>, + link_groups: HashMap>, + /// The offset of each group in the arithmetization + link_group_offsets: HashMap, /// The Plonk parameters. plonk_params: PlonkParams, @@ -214,6 +220,7 @@ impl PlonkCircuit { }, eval_domain: Radix2EvaluationDomain::new(1).unwrap(), link_groups: HashMap::new(), + link_group_offsets: HashMap::new(), plonk_params, num_table_elems: 0, table_gate_ids: vec![], @@ -288,10 +295,10 @@ impl PlonkCircuit { self.check_finalize_flag(false)?; self.check_var_bound(var)?; - for link_group in link_groups { + for group in link_groups { self.link_groups - .get_mut(link_group.id) - .ok_or(LinkGroupNotFound(link_group.id.to_string()))? + .get_mut(&group.id) + .ok_or_else(|| LinkGroupNotFound(group.id.to_string()))? .push(var); } @@ -388,10 +395,7 @@ impl Circuit for PlonkCircuit { fn check_circuit_satisfiability(&self, pub_input: &[F]) -> Result<(), CircuitError> { if pub_input.len() != self.num_inputs() { - return Err(PubInputLenMismatch( - pub_input.len(), - self.pub_input_gate_ids.len(), - )); + return Err(PubInputLenMismatch(pub_input.len(), self.pub_input_gate_ids.len())); } // Check public I/O gates for (i, gate_id) in self.pub_input_gate_ids.iter().enumerate() { @@ -469,9 +473,8 @@ impl Circuit for PlonkCircuit { val: F, link_groups: &[LinkGroup], ) -> Result { - let var = self.create_variable(val)?; + let var = self.create_variable_with_link_groups(val, link_groups)?; self.enforce_constant(var, val)?; - self.add_to_link_groups(var, link_groups)?; Ok(var) } @@ -502,9 +505,11 @@ impl Circuit for PlonkCircuit { Ok(()) } - // TODO: Add offset support - fn create_link_group(&mut self, id: &'static str, _offset: Option) -> LinkGroup { - self.link_groups.insert(id, Vec::new()); + fn create_link_group(&mut self, id: String, offset: Option) -> LinkGroup { + self.link_groups.insert(id.clone(), Vec::new()); + if let Some(offset) = offset { + self.link_group_offsets.insert(id.clone(), offset); + } LinkGroup { id } } @@ -520,11 +525,10 @@ impl Circuit for PlonkCircuit { ) -> Result<(), CircuitError> { self.check_finalize_flag(false)?; - for (wire_var, wire_variable) in wire_vars - .iter() - .zip(self.wire_variables.iter_mut().take(GATE_WIDTH + 1)) + for (wire_var, wire_variable) in + wire_vars.iter().zip(self.wire_variables.iter_mut().take(GATE_WIDTH + 1)) { - wire_variable.push(*wire_var) + wire_variable.push(wire_var) } self.gates.push(gate); @@ -570,7 +574,8 @@ impl PlonkCircuit { /// Re-arrange the order of the gates so that /// 1. io gates are in the front. - /// 2. variable table lookup gate are at the rear so that they do not affect + /// 2. proof linking gates follow + /// 3. variable table lookup gate are at the rear so that they do not affect /// the range gates when merging the lookup tables. /// /// Remember to pad gates before calling the method. @@ -588,6 +593,11 @@ impl PlonkCircuit { *io_gate_id = gate_id; } } + + // Insert proof linking gates into the circuit + let layout = self.gen_circuit_layout()?; + self.insert_link_gates(&layout); + if self.support_lookup() { // move lookup gates to the rear, the relative order of the lookup gates // should not change @@ -610,6 +620,89 @@ impl PlonkCircuit { } Ok(()) } + + /// Place the proof linking gates into the circuit + pub fn gen_circuit_layout(&self) -> Result { + // 1. Place the proof linking groups with specific offsets into the circuit + let mut sorted_offset_pairs = self.link_group_offsets.clone().into_iter().collect_vec(); + sorted_offset_pairs.sort_by_key(|(_, offset)| *offset); + + // 2. Place the rest of the gates into the circuit + // Sort the keys so they appear stably in the trace + for group_id in + self.link_groups.keys().sorted().filter(|id| !self.link_group_offsets.contains_key(*id)) + { + self.place_group(group_id.clone(), &mut sorted_offset_pairs); + } + + // 3. Create a layout and validate it + let link_group_offsets = sorted_offset_pairs.into_iter().collect(); + let layout = CircuitLayout { + n_inputs: self.num_inputs(), + n_gates: self.num_gates(), + link_group_offsets, + }; + if !self.validate_layout(&layout) { + return Err(CircuitError::Layout( + "Invalid circuit layout, please check the link group offsets".to_string(), + )); + } + + Ok(layout) + } + + /// Place a group into a list of already allocated groups + /// + /// Mutates `placed_groups` to insert the new group + fn place_group(&self, group_id: String, placed_groups: &mut Vec<(String, usize)>) { + let group_size = self.link_groups.get(&group_id).unwrap().len(); + let mut offset = self.num_inputs(); // Link gates being after i/o gates + let mut next_idx = 0; // Index into the placed gates + + loop { + // Determine the size of the candidate placement between the current offset and + // the next group that has already been placed + let default = (String::from(""), usize::MAX); + let (next_group_id, next_group_boundary) = + placed_groups.get(next_idx).unwrap_or(&default); + let candidate_size = next_group_boundary - offset; + + if candidate_size >= group_size { + // We have found a placement for the group, insert the group into the + // `placed_groups` + placed_groups.insert(next_idx, (group_id, offset)); + return; + } + + // Move to the next candidate range + let next_group_size = self.link_groups.get(next_group_id).unwrap().len(); + offset = *next_group_boundary + next_group_size; + next_idx += 1; + } + } + + /// Validate a circuit layout against the circuit's arithmetization + fn validate_layout(&self, layout: &CircuitLayout) -> bool { + // Check that none of the link groups are overlapping + let sorted = layout.link_group_offsets.iter().sorted_by_key(|(_, off)| *off).collect_vec(); + for window in sorted.windows(2 /* size */) { + let (id, off1) = window[0]; + let (_, off2) = window[1]; + + let group_len = self.link_groups.get(id).unwrap().len(); + if off1 + group_len > *off2 { + return false; + } + } + + true + } + + /// Insert proof linking gates into the circuit + fn insert_link_gates(&self, _layout: &CircuitLayout) { + todo!() + } + // use downcast to check whether a gate is of IoGate type fn is_io_gate(&self, gate_id: GateId) -> bool { self.gates[gate_id].as_any().is::() @@ -640,9 +733,8 @@ impl PlonkCircuit { fn check_gate(&self, gate_id: Variable, pub_input: &F) -> Result<(), CircuitError> { // Compute wire values - let w_vals: Vec = (0..GATE_WIDTH + 1) - .map(|i| self.witness[self.wire_variables[i][gate_id]]) - .collect(); + let w_vals: Vec = + (0..GATE_WIDTH + 1).map(|i| self.witness[self.wire_variables[i][gate_id]]).collect(); // Compute selector values. let q_lc: [F; GATE_WIDTH] = self.gates[gate_id].q_lc(); let q_mul: [F; N_MUL_SELECTORS] = self.gates[gate_id].q_mul(); @@ -698,11 +790,8 @@ impl PlonkCircuit { // slightly faster than using a `HashMap>` as we // avoid any constant overhead from the hashmap read/write. let mut variable_wires_map = vec![vec![]; m]; - for (gate_wire_id, variables) in self - .wire_variables - .iter() - .take(self.num_wire_types()) - .enumerate() + for (gate_wire_id, variables) in + self.wire_variables.iter().take(self.num_wire_types()).enumerate() { for (gate_id, &var) in variables.iter().enumerate() { variable_wires_map[var].push((gate_wire_id, gate_id)); @@ -996,14 +1085,8 @@ impl PlonkCircuit { /// The method only supports TurboPlonk circuits. #[allow(dead_code)] pub fn merge(&self, other: &Self) -> Result { - assert!( - self.link_groups.is_empty(), - "proof linking not supported for merged circuits" - ); - assert!( - other.link_groups.is_empty(), - "proof linking not supported for merged circuits" - ); + assert!(self.link_groups.is_empty(), "proof linking not supported for merged circuits"); + assert!(other.link_groups.is_empty(), "proof linking not supported for merged circuits"); self.check_finalize_flag(true)?; other.check_finalize_flag(true)?; @@ -1029,22 +1112,15 @@ impl PlonkCircuit { ))); } if self.pub_input_gate_ids[0] != 0 { - return Err(ParameterError( - "the first circuit is not type A".to_string(), - )); + return Err(ParameterError("the first circuit is not type A".to_string())); } if other.pub_input_gate_ids[0] != other.eval_domain_size()? - 1 { - return Err(ParameterError( - "the second circuit is not type B".to_string(), - )); + return Err(ParameterError("the second circuit is not type B".to_string())); } let num_vars = self.num_vars + other.num_vars; let witness: Vec = [self.witness.as_slice(), other.witness.as_slice()].concat(); - let pub_input_gate_ids: Vec = [ - self.pub_input_gate_ids.as_slice(), - other.pub_input_gate_ids.as_slice(), - ] - .concat(); + let pub_input_gate_ids: Vec = + [self.pub_input_gate_ids.as_slice(), other.pub_input_gate_ids.as_slice()].concat(); // merge gates and wire variables // the first circuit occupies the first n gates, the second circuit @@ -1054,21 +1130,13 @@ impl PlonkCircuit { let mut wire_variables = [vec![], vec![], vec![], vec![], vec![], vec![]]; for (j, gate) in self.gates.iter().take(n).enumerate() { gates.push((*gate).clone()); - for (i, wire_vars) in wire_variables - .iter_mut() - .enumerate() - .take(self.num_wire_types) - { + for (i, wire_vars) in wire_variables.iter_mut().enumerate().take(self.num_wire_types) { wire_vars.push(self.wire_variable(i, j)); } } for (j, gate) in other.gates.iter().skip(n).enumerate() { gates.push((*gate).clone()); - for (i, wire_vars) in wire_variables - .iter_mut() - .enumerate() - .take(self.num_wire_types) - { + for (i, wire_vars) in wire_variables.iter_mut().enumerate().take(self.num_wire_types) { wire_vars.push(other.wire_variable(i, n + j) + self.num_vars); } } @@ -1094,6 +1162,7 @@ impl PlonkCircuit { eval_domain: self.eval_domain, // `link_groups` must be empty for both proofs link_groups: HashMap::new(), + link_group_offsets: HashMap::new(), plonk_params: self.plonk_params, num_table_elems: 0, table_gate_ids: vec![], @@ -1219,36 +1288,28 @@ where fn compute_range_table_polynomial(&self) -> Result, CircuitError> { let range_table = self.compute_range_table()?; let domain = &self.eval_domain; - Ok(DensePolynomial::from_coefficients_vec( - domain.ifft(&range_table), - )) + Ok(DensePolynomial::from_coefficients_vec(domain.ifft(&range_table))) } fn compute_key_table_polynomial(&self) -> Result, CircuitError> { self.check_plonk_type(PlonkType::UltraPlonk)?; self.check_finalize_flag(true)?; let domain = &self.eval_domain; - Ok(DensePolynomial::from_coefficients_vec( - domain.ifft(&self.table_key_vec()), - )) + Ok(DensePolynomial::from_coefficients_vec(domain.ifft(&self.table_key_vec()))) } fn compute_table_dom_sep_polynomial(&self) -> Result, CircuitError> { self.check_plonk_type(PlonkType::UltraPlonk)?; self.check_finalize_flag(true)?; let domain = &self.eval_domain; - Ok(DensePolynomial::from_coefficients_vec( - domain.ifft(&self.table_dom_sep_vec()), - )) + Ok(DensePolynomial::from_coefficients_vec(domain.ifft(&self.table_dom_sep_vec()))) } fn compute_q_dom_sep_polynomial(&self) -> Result, CircuitError> { self.check_plonk_type(PlonkType::UltraPlonk)?; self.check_finalize_flag(true)?; let domain = &self.eval_domain; - Ok(DensePolynomial::from_coefficients_vec( - domain.ifft(&self.q_dom_sep()), - )) + Ok(DensePolynomial::from_coefficients_vec(domain.ifft(&self.q_dom_sep()))) } fn compute_merged_lookup_table(&self, tau: F) -> Result, CircuitError> { @@ -1296,9 +1357,7 @@ where )); } if 2 * n - 1 != sorted_vec.len() { - return Err(ParameterError( - "The sorted vector has wrong length".to_string(), - )); + return Err(ParameterError("The sorted vector has wrong length".to_string())); } let mut product_vec = vec![F::one()]; @@ -1652,9 +1711,7 @@ pub(crate) mod test { assert!(circuit.check_circuit_satisfiability(&[]).is_err()); // Check variable out of bound error. - assert!(circuit - .enforce_constant(circuit.num_vars(), F::from(0u32)) - .is_err()); + assert!(circuit.enforce_constant(circuit.num_vars(), F::from(0u32)).is_err()); Ok(()) } @@ -1676,34 +1733,20 @@ pub(crate) mod test { circuit.set_variable_public(b)?; // Different valid public inputs should all pass the circuit check. - assert!(circuit - .check_circuit_satisfiability(&[F::from(1u32), F::from(0u32)]) - .is_ok()); + assert!(circuit.check_circuit_satisfiability(&[F::from(1u32), F::from(0u32)]).is_ok()); *circuit.witness_mut(a) = F::from(0u32); - assert!(circuit - .check_circuit_satisfiability(&[F::from(0u32), F::from(0u32)]) - .is_ok()); + assert!(circuit.check_circuit_satisfiability(&[F::from(0u32), F::from(0u32)]).is_ok()); *circuit.witness_mut(b) = F::from(1u32); - assert!(circuit - .check_circuit_satisfiability(&[F::from(0u32), F::from(1u32)]) - .is_ok()); + assert!(circuit.check_circuit_satisfiability(&[F::from(0u32), F::from(1u32)]).is_ok()); // Invalid public inputs should fail the circuit check. - assert!(circuit - .check_circuit_satisfiability(&[F::from(2u32), F::from(1u32)]) - .is_err()); + assert!(circuit.check_circuit_satisfiability(&[F::from(2u32), F::from(1u32)]).is_err()); *circuit.witness_mut(a) = F::from(2u32); - assert!(circuit - .check_circuit_satisfiability(&[F::from(2u32), F::from(1u32)]) - .is_err()); + assert!(circuit.check_circuit_satisfiability(&[F::from(2u32), F::from(1u32)]).is_err()); *circuit.witness_mut(a) = F::from(0u32); - assert!(circuit - .check_circuit_satisfiability(&[F::from(0u32), F::from(2u32)]) - .is_err()); + assert!(circuit.check_circuit_satisfiability(&[F::from(0u32), F::from(2u32)]).is_err()); *circuit.witness_mut(b) = F::from(2u32); - assert!(circuit - .check_circuit_satisfiability(&[F::from(0u32), F::from(2u32)]) - .is_err()); + assert!(circuit.check_circuit_satisfiability(&[F::from(0u32), F::from(2u32)]).is_err()); Ok(()) } @@ -1874,10 +1917,7 @@ pub(crate) mod test { let group_elems: Vec = domain.elements().collect(); (0..circuit.num_wire_types).for_each(|i| { (0..n).for_each(|j| { - assert_eq!( - k[i] * group_elems[j], - circuit.extended_id_permutation[i * n + j] - ) + assert_eq!(k[i] * group_elems[j], circuit.extended_id_permutation[i * n + j]) }); }); @@ -1903,9 +1943,7 @@ pub(crate) mod test { assert!(circuit.compute_range_table_polynomial().is_err()); assert!(circuit.compute_key_table_polynomial().is_err()); assert!(circuit.compute_merged_lookup_table(F::one()).is_err()); - assert!(circuit - .compute_lookup_sorted_vec_polynomials(F::one(), &[]) - .is_err()); + assert!(circuit.compute_lookup_sorted_vec_polynomials(F::one(), &[]).is_err()); assert!(circuit .compute_lookup_prod_polynomial(&F::one(), &F::one(), &F::one(), &[], &[]) .is_err()); @@ -1928,9 +1966,7 @@ pub(crate) mod test { assert!(circuit.compute_extended_permutation_polynomials().is_err()); assert!(circuit.compute_pub_input_polynomial().is_err()); assert!(circuit.compute_wire_polynomials().is_err()); - assert!(circuit - .compute_prod_permutation_polynomial(&F::one(), &F::one()) - .is_err()); + assert!(circuit.compute_prod_permutation_polynomial(&F::one(), &F::one()).is_err()); // Should not insert gates or add variables after finalizing the circuit. circuit.finalize_for_arithmetization()?; @@ -1962,15 +1998,11 @@ pub(crate) mod test { assert!(circuit.compute_extended_permutation_polynomials().is_err()); assert!(circuit.compute_pub_input_polynomial().is_err()); assert!(circuit.compute_wire_polynomials().is_err()); - assert!(circuit - .compute_prod_permutation_polynomial(&F::one(), &F::one()) - .is_err()); + assert!(circuit.compute_prod_permutation_polynomial(&F::one(), &F::one()).is_err()); assert!(circuit.compute_range_table_polynomial().is_err()); assert!(circuit.compute_key_table_polynomial().is_err()); assert!(circuit.compute_merged_lookup_table(F::one()).is_err()); - assert!(circuit - .compute_lookup_sorted_vec_polynomials(F::one(), &[]) - .is_err()); + assert!(circuit.compute_lookup_sorted_vec_polynomials(F::one(), &[]).is_err()); assert!(circuit .compute_lookup_prod_polynomial(&F::one(), &F::one(), &F::one(), &[], &[]) .is_err()); @@ -2128,9 +2160,8 @@ pub(crate) mod test { // Check wire witness polynomials let wire_polys = circuit.compute_wire_polynomials()?; - for (poly, wire_vars) in wire_polys - .iter() - .zip(circuit.wire_variables.iter().take(circuit.num_wire_types())) + for (poly, wire_vars) in + wire_polys.iter().zip(circuit.wire_variables.iter().take(circuit.num_wire_types())) { let wire_evals: Vec = wire_vars.iter().map(|&var| circuit.witness[var]).collect(); check_polynomial(poly, &wire_evals); diff --git a/relation/src/errors.rs b/relation/src/errors.rs index 9b21346dc..db7fd9ece 100644 --- a/relation/src/errors.rs +++ b/relation/src/errors.rs @@ -20,6 +20,8 @@ pub enum CircuitError { PubInputLenMismatch(usize, usize), /// The {0}-th gate failed: {1} GateCheckFailure(usize, String), + /// The circuit layout specified by the link groups is invalid + Layout(String), /// The {0}-th link group was not allocated LinkGroupNotFound(String), /// Invalid parameters: {0} diff --git a/relation/src/gates/mod.rs b/relation/src/gates/mod.rs index c643f029e..c5818ef4a 100644 --- a/relation/src/gates/mod.rs +++ b/relation/src/gates/mod.rs @@ -83,6 +83,21 @@ impl fmt::Debug for (dyn Gate + 'static) { } } +/// A proof linking gate, for linking two proofs' witnesses together +#[derive(Copy, Clone, Debug)] +pub struct ProofLinkingGate; + +impl Gate for ProofLinkingGate { + fn name(&self) -> &'static str { + "Proof Linking Gate" + } + + // Represents a * 0 = 0 + fn q_mul(&self) -> [F; N_MUL_SELECTORS] { + [F::one(), F::zero()] + } +} + /// A empty gate for circuit padding #[derive(Debug, Clone)] pub struct PaddingGate; diff --git a/relation/src/traits.rs b/relation/src/traits.rs index d9541bf60..45d13a897 100644 --- a/relation/src/traits.rs +++ b/relation/src/traits.rs @@ -31,18 +31,17 @@ macro_rules! felt { /// circuit macro_rules! felts { ($x:expr) => { - $x.iter() - .map(|x| Self::Constant::from_field(x)) - .collect::>() + $x.iter().map(|x| Self::Constant::from_field(x)).collect::>() }; } /// Represents the parameterization of a proof-linking group in the circuit /// /// See `Circuit::create_link_group` for more details on proof-linking +#[derive(Clone, Debug)] pub struct LinkGroup { /// The id of the group - pub id: &'static str, + pub id: String, } /// An interface to add gates to a circuit that generalizes across wiring @@ -211,7 +210,7 @@ pub trait Circuit { /// where a(x) is the witness element. This allows us to prove /// that the a(x) polynomial of one proof equals the a(x) polynomial of /// another proof over some proof-linking domain, represented by the group - fn create_link_group(&mut self, id: &'static str, offset: Option) -> LinkGroup; + fn create_link_group(&mut self, id: String, offset: Option) -> LinkGroup; // --- Gate Allocation --- // @@ -347,11 +346,8 @@ pub trait Circuit { .collect::, CircuitError>>()?; // calculate y as the linear combination of coeffs and vals_in - let y_val = vals_in - .iter() - .zip(coeffs.iter()) - .map(|(val, coeff)| felt!(coeff) * val.clone()) - .sum(); + let y_val = + vals_in.iter().zip(coeffs.iter()).map(|(val, coeff)| felt!(coeff) * val.clone()).sum(); let y = self.create_variable(y_val)?; let wires = [wires_in[0], wires_in[1], wires_in[2], wires_in[3], y]; @@ -456,24 +452,14 @@ pub trait Circuit { let mut accum = padded[0]; for i in 1..padded_len / rate { accum = self.lc( - &[ - accum, - padded[rate * i - 2], - padded[rate * i - 1], - padded[rate * i], - ], + &[accum, padded[rate * i - 2], padded[rate * i - 1], padded[rate * i]], &coeffs, )?; } // Final round - let wires = [ - accum, - padded[padded_len - 3], - padded[padded_len - 2], - padded[padded_len - 1], - sum, - ]; + let wires = + [accum, padded[padded_len - 3], padded[padded_len - 2], padded[padded_len - 1], sum]; self.lc_gate(&wires, &coeffs)?; Ok(sum) @@ -493,9 +479,7 @@ pub trait Circuit { padded_wires.resize(n_lcs, self.zero()); padded_coeffs.resize(n_lcs, F::zero()); - for (wires, coeffs) in padded_wires - .chunks(GATE_WIDTH) - .zip(padded_coeffs.chunks(GATE_WIDTH)) + for (wires, coeffs) in padded_wires.chunks(GATE_WIDTH).zip(padded_coeffs.chunks(GATE_WIDTH)) { partials.push(self.lc( &[wires[0], wires[1], wires[2], wires[3]], diff --git a/rustfmt.toml b/rustfmt.toml index f364755bb..3b73cc5a9 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -5,6 +5,7 @@ use_try_shorthand = true match_block_trailing_comma = true use_field_init_shorthand = true max_width = 100 +use_small_heuristics = "Max" edition = "2018" condense_wildcard_suffixes = true imports_granularity = "Crate"