Skip to content

Commit

Permalink
relation: constraint-system: Add proof layout interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Joey Kraut authored and joeykraut committed Dec 27, 2023
1 parent b8e2b8a commit 4c1516a
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 265 deletions.
61 changes: 26 additions & 35 deletions plonk/src/multiprover/proof_system/constraint_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,13 @@ where
// do so by adding proof-linking gates at specific indices in the circuit's arithmetization
// of the form a(x) * 0 = 0, where a(x) encodes the witness to be linked. We can then
// invoke a polynomial subprotocol to prove that the a(x) polynomial is the same between
// proofs. The following fields are used to track membership in groups and placement of groups
// in the arithmetization
// proofs. The following fields are used to track membership in groups and placement of
// groups in the arithmetization
/// The proof-linking group layouts for the circuit. Maps a group ID to the
/// indices of the witness contained in the group
link_groups: HashMap<&'static str, Vec<Variable>>,
link_groups: HashMap<String, Vec<Variable>>,
/// The offsets at which to place the link groups in the arithmetization
link_group_offsets: HashMap<String, usize>,

/// The underlying MPC fabric that this circuit is allocated within
fabric: MpcFabric<C>,
Expand Down Expand Up @@ -164,6 +166,7 @@ where
// This is later updated
eval_domain: Radix2EvaluationDomain::new(1 /* num_coeffs */).unwrap(),
link_groups: HashMap::new(),
link_group_offsets: HashMap::new(),
fabric,
};

Expand Down Expand Up @@ -210,7 +213,7 @@ where

for link_group in link_groups {
self.link_groups
.get_mut(link_group.id)
.get_mut(&link_group.id)
.ok_or(CircuitError::LinkGroupNotFound(link_group.id.to_string()))?
.push(var);
}
Expand Down Expand Up @@ -307,9 +310,8 @@ where
pub_input: &AuthenticatedScalarResult<C>,
) -> ScalarResult<C> {
// Compute wire values
let w_vals = (0..=GATE_WIDTH)
.map(|i| &self.witness[self.wire_variables[i][gate_id]])
.collect_vec();
let w_vals =
(0..=GATE_WIDTH).map(|i| &self.witness[self.wire_variables[i][gate_id]]).collect_vec();

// Compute selector values
macro_rules! as_scalars {
Expand Down Expand Up @@ -387,11 +389,8 @@ where

// Compute the mapping from variables to wires.
let mut variable_wires_map = vec![vec![]; m];
for (gate_wire_id, variables) in self
.wire_variables
.iter()
.take(self.num_wire_types())
.enumerate()
for (gate_wire_id, variables) in
self.wire_variables.iter().take(self.num_wire_types()).enumerate()
{
for (gate_id, &var) in variables.iter().enumerate() {
variable_wires_map[var].push((gate_wire_id, gate_id));
Expand Down Expand Up @@ -574,10 +573,7 @@ where
) -> Result<(), CircuitError> {
let n = public_input.len();
if n != self.num_inputs() {
return Err(CircuitError::PubInputLenMismatch(
n,
self.pub_input_gate_ids.len(),
));
return Err(CircuitError::PubInputLenMismatch(n, self.pub_input_gate_ids.len()));
}

let mut gate_results = Vec::new();
Expand Down Expand Up @@ -605,10 +601,7 @@ where
if res == Scalar::zero() {
Ok(())
} else {
Err(CircuitError::GateCheckFailure(
idx,
"gate check failed".to_string(),
))
Err(CircuitError::GateCheckFailure(idx, "gate check failed".to_string()))
}
})
.collect::<Result<Vec<_>, CircuitError>>()
Expand Down Expand Up @@ -641,9 +634,8 @@ where
link_groups: &[LinkGroup],
) -> Result<Variable, CircuitError> {
let authenticated_val = self.fabric.one_authenticated() * Scalar::new(val);
let var = self.create_variable(authenticated_val)?;
let var = self.create_variable_with_link_groups(authenticated_val, link_groups)?;
self.enforce_constant(var, val)?;
self.add_to_link_groups(var, link_groups)?;

Ok(var)
}
Expand Down Expand Up @@ -673,9 +665,12 @@ where
Ok(())
}

// TODO: properly handle offsets
fn create_link_group(&mut self, id: &'static str, _offset: Option<usize>) -> LinkGroup {
self.link_groups.insert(id, Vec::new());
fn create_link_group(&mut self, id: String, offset: Option<usize>) -> LinkGroup {
self.link_groups.insert(id.clone(), Vec::new());
if let Some(offset) = offset {
self.link_group_offsets.insert(id.clone(), offset);
}

LinkGroup { id }
}

Expand All @@ -686,9 +681,8 @@ where
) -> Result<(), CircuitError> {
self.check_finalize_flag(false)?;

for (wire_var, wire_variable) in wire_vars
.iter()
.zip(self.wire_variables.iter_mut().take(GATE_WIDTH + 1))
for (wire_var, wire_variable) in
wire_vars.iter().zip(self.wire_variables.iter_mut().take(GATE_WIDTH + 1))
{
wire_variable.push(*wire_var)
}
Expand Down Expand Up @@ -855,9 +849,8 @@ where
let mut numerator_terms = Vec::with_capacity(self.num_wire_types());
let mut denominator_terms = Vec::with_capacity(self.num_wire_types());
for i in 0..self.num_wire_types() {
let wire_values = (0..(n - 1))
.map(|j| self.witness[self.wire_variable(i, j)].clone())
.collect_vec();
let wire_values =
(0..(n - 1)).map(|j| self.witness[self.wire_variable(i, j)].clone()).collect_vec();
let id_perm_values = (0..(n - 1))
.map(|j| Scalar::new(self.extended_id_permutation[i * n + j]))
.collect_vec();
Expand Down Expand Up @@ -914,10 +907,8 @@ where
.iter()
.take(self.num_wire_types())
.map(|wire_vars| {
let wire_vec: Vec<AuthenticatedScalarResult<C>> = wire_vars
.iter()
.map(|&var| witness[var].clone())
.collect_vec();
let wire_vec: Vec<AuthenticatedScalarResult<C>> =
wire_vars.iter().map(|&var| witness[var].clone()).collect_vec();

let coeffs = AuthenticatedScalarResult::ifft::<
Radix2EvaluationDomain<C::ScalarField>,
Expand Down
109 changes: 32 additions & 77 deletions plonk/src/multiprover/proof_system/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@ use super::{MpcArithmetization, MpcOracles};
/// A type alias for a bundle of commitments and polynomials
/// TODO: Remove this lint allowance
#[allow(type_alias_bounds)]
type MpcCommitmentsAndPolys<E: Pairing> = (
Vec<MultiproverKzgCommitment<E>>,
Vec<AuthenticatedDensePoly<E::G1>>,
);
type MpcCommitmentsAndPolys<E: Pairing> =
(Vec<MultiproverKzgCommitment<E>>, Vec<AuthenticatedDensePoly<E::G1>>);

/// A Plonk IOP prover over a secret shared algebra
/// TODO: Remove this lint allowance
Expand Down Expand Up @@ -85,11 +83,7 @@ impl<E: Pairing> MpcProver<E> {
)
.ok_or(PlonkError::DomainCreationError)?;

Ok(Self {
domain,
quot_domain,
fabric,
})
Ok(Self { domain, quot_domain, fabric })
}

/// Round 1:
Expand Down Expand Up @@ -172,11 +166,8 @@ impl<E: Pairing> MpcProver<E> {
online_oracles: &MpcOracles<E::G1>,
num_wire_types: usize,
) -> MpcProofEvaluations<E::G1> {
let wires_evals: Vec<AuthenticatedScalarResult<E::G1>> = online_oracles
.wire_polys
.iter()
.map(|poly| poly.eval(&challenges.zeta))
.collect();
let wires_evals: Vec<AuthenticatedScalarResult<E::G1>> =
online_oracles.wire_polys.iter().map(|poly| poly.eval(&challenges.zeta)).collect();

let wire_sigma_evals: Vec<ScalarResult<E::G1>> = pk
.sigmas
Expand Down Expand Up @@ -370,23 +361,14 @@ impl<E: Pairing> MpcProver<E> {
.collect();

// The coset we use to compute the quotient polynomial
let coset = self
.quot_domain
.get_coset(E::ScalarField::GENERATOR)
.unwrap();
let coset = self.quot_domain.get_coset(E::ScalarField::GENERATOR).unwrap();

// Compute evaluations of the selectors, permutations, and wiring polynomials
let selectors_coset_fft: Vec<Vec<E::ScalarField>> = pk
.selectors
.iter()
.map(|poly| coset.fft(poly.coeffs()))
.collect();
let selectors_coset_fft: Vec<Vec<E::ScalarField>> =
pk.selectors.iter().map(|poly| coset.fft(poly.coeffs())).collect();

let sigmas_coset_fft: Vec<Vec<E::ScalarField>> = pk
.sigmas
.iter()
.map(|poly| coset.fft(poly.coeffs()))
.collect();
let sigmas_coset_fft: Vec<Vec<E::ScalarField>> =
pk.sigmas.iter().map(|poly| coset.fft(poly.coeffs())).collect();

let wire_polys_coset_fft: Vec<Vec<AuthenticatedScalarResult<E::G1>>> = online_oracles
.wire_polys
Expand Down Expand Up @@ -504,10 +486,7 @@ impl<E: Pairing> MpcProver<E> {
let mut nonzero_wires: Vec<Vec<AuthenticatedScalarResult<E::G1>>> = vec![vec![]; m];
for (i, selector) in selectors.iter().enumerate() {
if !selector.is_zero() {
wires
.iter()
.enumerate()
.for_each(|(j, w)| nonzero_wires[j].push(w[i].clone()));
wires.iter().enumerate().for_each(|(j, w)| nonzero_wires[j].push(w[i].clone()));
nonzero_sel.push(Scalar::new(*selector));
nonzero_indices.push(i);
}
Expand Down Expand Up @@ -593,10 +572,7 @@ impl<E: Pairing> MpcProver<E> {
prod_perm_poly_coset_evals: &[AuthenticatedScalarResult<E::G1>],
challenges: &MpcChallenges<E::G1>,
sigmas_coset_fft: &[Vec<E::ScalarField>],
) -> (
Vec<AuthenticatedScalarResult<E::G1>>,
Vec<AuthenticatedScalarResult<E::G1>>,
) {
) -> (Vec<AuthenticatedScalarResult<E::G1>>, Vec<AuthenticatedScalarResult<E::G1>>) {
let n = pk.domain_size();
let m = self.quot_domain.size();

Expand All @@ -607,9 +583,8 @@ impl<E: Pairing> MpcProver<E> {

// Construct the evaluations shifted by the coset generators
let coset_generators = pk.k().iter().copied().collect_vec();
let eval_points = (0..m)
.map(|i| self.quot_domain.element(i) * E::ScalarField::GENERATOR)
.collect_vec();
let eval_points =
(0..m).map(|i| self.quot_domain.element(i) * E::ScalarField::GENERATOR).collect_vec();

let mut all_evals = Vec::with_capacity(num_wire_types);
for generator in coset_generators.iter() {
Expand Down Expand Up @@ -720,11 +695,8 @@ impl<E: Pairing> MpcProver<E> {
// chunks of degree n + 1 contiguous coefficients
let mut split_quot_polys: Vec<AuthenticatedDensePoly<E::G1>> = (0..num_wire_types)
.map(|i| {
let end = if i < num_wire_types - 1 {
(i + 1) * (n + 2)
} else {
quot_poly.degree() + 1
};
let end =
if i < num_wire_types - 1 { (i + 1) * (n + 2) } else { quot_poly.degree() + 1 };

// Degree-(n+1) polynomial has n + 2 coefficients.
AuthenticatedDensePoly::from_coeffs(quot_poly.coeffs[i * (n + 2)..end].to_vec())
Expand All @@ -738,22 +710,17 @@ impl<E: Pairing> MpcProver<E> {
// with t_lowest_i(X) = t_lowest_i(X) - 0 + b_now_i * X^(n+2)
// and t_highest_i(X) = t_highest_i(X) - b_last_i
let mut last_randomizer = self.fabric.zero_authenticated();
let mut randomizers = self
.fabric
.random_shared_scalars_authenticated(num_wire_types - 1);
let mut randomizers = self.fabric.random_shared_scalars_authenticated(num_wire_types - 1);

split_quot_polys
.iter_mut()
.take(num_wire_types - 1)
.for_each(|poly| {
poly.coeffs[0] = &poly.coeffs[0] - &last_randomizer;
assert_eq!(poly.degree(), n + 1);
split_quot_polys.iter_mut().take(num_wire_types - 1).for_each(|poly| {
poly.coeffs[0] = &poly.coeffs[0] - &last_randomizer;
assert_eq!(poly.degree(), n + 1);

let next_randomizer = randomizers.pop().unwrap();
poly.coeffs.push(next_randomizer.clone());
let next_randomizer = randomizers.pop().unwrap();
poly.coeffs.push(next_randomizer.clone());

last_randomizer = next_randomizer;
});
last_randomizer = next_randomizer;
});

// Mask the highest splitting poly
split_quot_polys[num_wire_types - 1].coeffs[0] =
Expand Down Expand Up @@ -894,9 +861,7 @@ pub fn element_wise_product<C: CurveGroup>(
// If we choose to view the vectors as tiling the columns of matrices, each step
// in this fold replaces the first and second columns with their
// element-wise product
vectors[2..].iter().fold(initial, |acc, vec| {
AuthenticatedScalarResult::batch_mul(&acc, vec)
})
vectors[2..].iter().fold(initial, |acc, vec| AuthenticatedScalarResult::batch_mul(&acc, vec))
}

/// Take the element-wise sum of a set of vectors
Expand All @@ -918,9 +883,7 @@ fn element_wise_sum<C: CurveGroup>(
// If we choose to view the vectors as tiling the columns of matrices, each step
// in this fold replaces the first and second columns with their
// element-wise sum
vectors[2..].iter().fold(initial, |acc, vec| {
AuthenticatedScalarResult::batch_add(&acc, vec)
})
vectors[2..].iter().fold(initial, |acc, vec| AuthenticatedScalarResult::batch_add(&acc, vec))
}

/// Evaluate a public polynomial on a result in the MPC fabric
Expand Down Expand Up @@ -1055,10 +1018,8 @@ pub(crate) mod test {
return AuthenticatedDensePoly::from_coeffs(vec![fabric.zero_authenticated()]);
}

let coeffs = fabric.batch_share_scalar(
poly.coeffs.iter().cloned().map(Scalar::new).collect(),
PARTY0,
);
let coeffs = fabric
.batch_share_scalar(poly.coeffs.iter().cloned().map(Scalar::new).collect(), PARTY0);
AuthenticatedDensePoly::from_coeffs(coeffs)
}

Expand Down Expand Up @@ -1280,7 +1241,7 @@ pub(crate) mod test {
/// Run the fifth round of a single-prover circuit
///
/// Returns the commitments to the opening and shifted opening polynomials
fn run_fifth_round(params: &mut TestParams) -> (Commitment<TestCurve>, Commitment<TestCurve>) {
fn run_fifth_round(params: &TestParams) -> (Commitment<TestCurve>, Commitment<TestCurve>) {
params
.prover
.compute_opening_proofs(
Expand Down Expand Up @@ -1330,10 +1291,7 @@ pub(crate) mod test {
.collect::<Vec<_>>();
let pub_poly = pub_poly.open_authenticated();

(
(wire_comms_open.await, wire_polys_open.await),
pub_poly.await.unwrap(),
)
((wire_comms_open.await, wire_polys_open.await), pub_poly.await.unwrap())
}
})
.await;
Expand Down Expand Up @@ -1372,10 +1330,7 @@ pub(crate) mod test {
.unwrap();

// Open the results
(
perm_commit.open_authenticated().await.unwrap(),
perm_poly.open().await,
)
(perm_commit.open_authenticated().await.unwrap(), perm_poly.open().await)
}
})
.await;
Expand Down Expand Up @@ -1508,7 +1463,7 @@ pub(crate) mod test {
run_second_round(true /* mask */, &mut params);
run_third_round(false /* mask */, &mut params);
run_fourth_round(&mut params);
let (expected_open, expected_shift) = run_fifth_round(&mut params);
let (expected_open, expected_shift) = run_fifth_round(&params);

// Compute the result in an MPC
let ((open, shifted_open), _) = execute_mock_mpc(|fabric| {
Expand Down
Loading

0 comments on commit 4c1516a

Please sign in to comment.