From 315592bf3f018e7dd2f1524f821839b2cdb62afa Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Tue, 14 Jan 2025 10:09:48 -0300 Subject: [PATCH] Implement bus constraints for small fields (#2334) This PR adds a `std::math::extension_field` module to our PIL standard library, which abstracts over which extension field is used. I changed all protocols - most importantly the bus - to use the new abstraction. As a result, we can now have bus constraints on smaller fields like BabyBear, e.g.: ``` $ cargo run pil test_data/asm/block_to_block_with_bus.asm -o output -f --field bb ``` TODOs left to future PRs: - Add M31 to `std::field::KnownField` - Implement witness generation for Bus with Fp4 --- pipeline/tests/powdr_std.rs | 6 ++ std/math/extension_field.asm | 82 +++++++++++++++++++ std/math/fp2.asm | 28 ++----- std/math/mod.asm | 3 +- std/protocols/bus.asm | 76 +++++++++-------- std/protocols/fingerprint.asm | 49 +++++++---- std/protocols/lookup.asm | 45 +++++----- std/protocols/permutation.asm | 49 ++++++----- test_data/asm/block_to_block_with_bus.asm | 2 - ...lock_to_block_with_bus_different_sizes.asm | 2 - test_data/std/fingerprint_test.asm | 11 ++- 11 files changed, 218 insertions(+), 135 deletions(-) create mode 100644 std/math/extension_field.asm diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index 7dadb21b83..30190079cf 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -181,6 +181,9 @@ fn permutation_via_challenges() { let pipeline = make_simple_prepared_pipeline::(f, LinkerMode::Bus); test_mock_backend(pipeline.clone()); test_plonky3_pipeline(pipeline); + + let pipeline = make_simple_prepared_pipeline::(f, LinkerMode::Bus); + test_plonky3_pipeline(pipeline); } #[test] @@ -189,6 +192,9 @@ fn lookup_via_challenges() { let pipeline = make_simple_prepared_pipeline::(f, LinkerMode::Bus); test_mock_backend(pipeline.clone()); test_plonky3_pipeline(pipeline); + + let pipeline = make_simple_prepared_pipeline::(f, LinkerMode::Bus); + test_plonky3_pipeline(pipeline); } #[test] diff --git a/std/math/extension_field.asm b/std/math/extension_field.asm new file mode 100644 index 0000000000..98803ac6a5 --- /dev/null +++ b/std/math/extension_field.asm @@ -0,0 +1,82 @@ +use std::field::known_field; +use std::field::KnownField; + +use std::array::len; +use std::check::panic; + +/// Whether we need to operate on an extension field (because the base field is too small). +let needs_extension: -> bool = || required_extension_size() > 1; + +/// How many field elements / field extensions are recommended for the current base field. +let required_extension_size: -> int = || match known_field() { + Option::Some(KnownField::Goldilocks) => 2, + Option::Some(KnownField::BN254) => 1, + Option::Some(KnownField::BabyBear) => 4, + Option::Some(KnownField::KoalaBear) => 4, + None => panic("The permutation/lookup argument is not implemented for the current field!") +}; + +/// Wrapper around Fp2 and Fp4 to abstract which extension field is used. +/// Once PIL supports traits, we can remove this type and the functions below. +enum Ext { + Fp2(std::math::fp2::Fp2), + Fp4(std::math::fp4::Fp4) +} + +let add_ext: Ext, Ext -> Ext = |a, b| match (a, b) { + (Ext::Fp2(aa), Ext::Fp2(bb)) => Ext::Fp2(std::math::fp2::add_ext(aa, bb)), + (Ext::Fp4(aa), Ext::Fp4(bb)) => Ext::Fp4(std::math::fp4::add_ext(aa, bb)), + _ => panic("Operands have different types") +}; + +let sub_ext: Ext, Ext -> Ext = |a, b| match (a, b) { + (Ext::Fp2(aa), Ext::Fp2(bb)) => Ext::Fp2(std::math::fp2::sub_ext(aa, bb)), + (Ext::Fp4(aa), Ext::Fp4(bb)) => Ext::Fp4(std::math::fp4::sub_ext(aa, bb)), + _ => panic("Operands have different types") +}; + +let mul_ext: Ext, Ext -> Ext = |a, b| match (a, b) { + (Ext::Fp2(aa), Ext::Fp2(bb)) => Ext::Fp2(std::math::fp2::mul_ext(aa, bb)), + (Ext::Fp4(aa), Ext::Fp4(bb)) => Ext::Fp4(std::math::fp4::mul_ext(aa, bb)), + _ => panic("Operands have different types") +}; + +let eval_ext: Ext -> Ext = query |a| match a { + Ext::Fp2(aa) => Ext::Fp2(std::math::fp2::eval_ext(aa)), + Ext::Fp4(aa) => Ext::Fp4(std::math::fp4::eval_ext(aa)), +}; + +let inv_ext: Ext -> Ext = query |a| match a { + Ext::Fp2(aa) => Ext::Fp2(std::math::fp2::inv_ext(aa)), + Ext::Fp4(aa) => Ext::Fp4(std::math::fp4::inv_ext(aa)), +}; + +let unpack_ext_array: Ext -> T[] = |a| match a { + Ext::Fp2(aa) => std::math::fp2::unpack_ext_array(aa), + Ext::Fp4(aa) => std::math::fp4::unpack_ext_array(aa), +}; + +let next_ext: Ext -> Ext = |a| match a { + Ext::Fp2(aa) => Ext::Fp2(std::math::fp2::next_ext(aa)), + Ext::Fp4(aa) => Ext::Fp4(std::math::fp4::next_ext(aa)), +}; + +let from_base: T -> Ext = |x| match required_extension_size() { + 1 => Ext::Fp2(std::math::fp2::from_base(x)), + 2 => Ext::Fp2(std::math::fp2::from_base(x)), + 4 => Ext::Fp4(std::math::fp4::from_base(x)), + _ => panic("Expected 1, 2, or 4") +}; + +let from_array: T[] -> Ext = |arr| match len(arr) { + 1 => Ext::Fp2(std::math::fp2::from_array(arr)), + 2 => Ext::Fp2(std::math::fp2::from_array(arr)), + 4 => Ext::Fp4(std::math::fp4::Fp4::Fp4(arr[0], arr[1], arr[2], arr[3])), + _ => panic("Expected 1, 2, or 4") +}; + +let constrain_eq_ext: Ext, Ext -> Constr[] = |a, b| match (a, b) { + (Ext::Fp2(aa), Ext::Fp2(bb)) => std::math::fp2::constrain_eq_ext(aa, bb), + (Ext::Fp4(aa), Ext::Fp4(bb)) => std::math::fp4::constrain_eq_ext(aa, bb), + _ => panic("Operands have different types") +}; \ No newline at end of file diff --git a/std/math/fp2.asm b/std/math/fp2.asm index 9c551fde9e..6a1cdc399e 100644 --- a/std/math/fp2.asm +++ b/std/math/fp2.asm @@ -8,6 +8,7 @@ use std::convert::expr; use std::field::known_field; use std::field::KnownField; use std::math::ff::inv_field; +use std::math::extension_field::needs_extension; use std::prover::eval; /// Corresponding Sage code to test irreducibility @@ -139,31 +140,14 @@ let unpack_ext_array: Fp2 -> T[] = |a| match a { Fp2::Fp2(a0, a1) => [a0, a1] }; -/// Whether we need to operate on the F_{p^2} extension field (because the current field is too small). -let needs_extension: -> bool = || required_extension_size() > 1; - -/// How many field elements / field extensions are recommended for the current base field. -let required_extension_size: -> int = || match known_field() { - Option::Some(KnownField::Goldilocks) => 2, - Option::Some(KnownField::BN254) => 1, - None => panic("The permutation/lookup argument is not implemented for the current field!") -}; - -/// Matches whether the length of a given array is correct to operate on the extension field -let is_extension = |arr| match len(arr) { - 1 => false, - 2 => true, - _ => panic("Expected 1 or 2 accumulator columns!") -}; - /// Constructs an extension field element `a0 + a1 * X` from either `[a0, a1]` or `[a0]` (setting `a1`to zero in that case) -let fp2_from_array = |arr| { - if is_extension(arr) { - Fp2::Fp2(arr[0], arr[1]) - } else { +let from_array: T[] -> Fp2 = |arr| match len(arr) { + 1 => { let _ = assert(!needs_extension(), || "The field is too small and needs to move to the extension field. Pass two elements instead!"); from_base(arr[0]) - } + }, + 2 => Fp2::Fp2(arr[0], arr[1]), + _ => panic("Expected array of length 1 or 2") }; mod test { diff --git a/std/math/mod.asm b/std/math/mod.asm index 6799ec5ce7..b4ff1fed1c 100644 --- a/std/math/mod.asm +++ b/std/math/mod.asm @@ -1,3 +1,4 @@ mod ff; mod fp2; -mod fp4; \ No newline at end of file +mod fp4; +mod extension_field; \ No newline at end of file diff --git a/std/protocols/bus.asm b/std/protocols/bus.asm index 0f0709e322..8b0dfb9ab1 100644 --- a/std/protocols/bus.asm +++ b/std/protocols/bus.asm @@ -1,20 +1,19 @@ use std::check::assert; use std::array; -use std::math::fp2::Fp2; -use std::math::fp2::add_ext; -use std::math::fp2::sub_ext; -use std::math::fp2::mul_ext; -use std::math::fp2::inv_ext; -use std::math::fp2::eval_ext; -use std::math::fp2::unpack_ext_array; -use std::math::fp2::next_ext; -use std::math::fp2::from_base; -use std::math::fp2::needs_extension; -use std::math::fp2::fp2_from_array; -use std::math::fp2::constrain_eq_ext; +use std::math::extension_field::Ext; +use std::math::extension_field::add_ext; +use std::math::extension_field::sub_ext; +use std::math::extension_field::mul_ext; +use std::math::extension_field::inv_ext; +use std::math::extension_field::eval_ext; +use std::math::extension_field::unpack_ext_array; +use std::math::extension_field::next_ext; +use std::math::extension_field::from_base; +use std::math::extension_field::from_array; +use std::math::extension_field::constrain_eq_ext; use std::protocols::fingerprint::fingerprint_with_id; use std::protocols::fingerprint::fingerprint_with_id_inter; -use std::math::fp2::required_extension_size; +use std::math::extension_field::required_extension_size; use std::prover::eval; use std::field::known_field; use std::field::KnownField; @@ -32,48 +31,53 @@ use std::check::panic; /// - latch: a binary expression which indicates where the multiplicity can be non-zero. let bus_interaction: expr, expr[], expr, expr -> () = constr |id, tuple, multiplicity, latch| { - std::check::assert(required_extension_size() <= 2, || "Invalid extension size"); - // Add phantom bus interaction let full_tuple = [id] + tuple; Constr::PhantomBusInteraction(multiplicity, full_tuple, latch); + let extension_field_size = required_extension_size(); + // Alpha is used to compress the LHS and RHS arrays. - let alpha = fp2_from_array(array::new(required_extension_size(), |i| challenge(0, i + 1))); + let alpha = from_array(array::new(extension_field_size, |i| challenge(0, i + 1))); // Beta is used to update the accumulator. - let beta = fp2_from_array(array::new(required_extension_size(), |i| challenge(0, i + 3))); + let beta = from_array(array::new(extension_field_size, |i| challenge(0, i + 1 + extension_field_size))); // Implemented as: folded = (beta - fingerprint(id, tuple...)); - let folded = match known_field() { - Option::Some(KnownField::Goldilocks) => { - // Materialized as a witness column for two reasons: - // - It makes sure the constraint degree is independent of the input tuple. - // - We can access folded', even if the tuple contains next references. - // Note that if all expressions are degree-1 and there is no next reference, - // this is wasteful, but we can't check that here. - let folded = fp2_from_array( - array::new(required_extension_size(), - |i| std::prover::new_witness_col_at_stage("folded", 1)) - ); - constrain_eq_ext(folded, sub_ext(beta, fingerprint_with_id_inter(id, tuple, alpha))); - folded - }, + let materialize_folded = match known_field() { + // Materialized as a witness column for two reasons: + // - It makes sure the constraint degree is independent of the input tuple. + // - We can access folded', even if the tuple contains next references. + // Note that if all expressions are degree-1 and there is no next reference, + // this is wasteful, but we can't check that here. + Option::Some(KnownField::Goldilocks) => true, + Option::Some(KnownField::BabyBear) => true, + Option::Some(KnownField::KoalaBear) => true, // The case above triggers our hand-written witness generation, but on Bn254, we'd not be // on the extension field and use the automatic witness generation. // However, it does not work with a materialized folded tuple. At the same time, Halo2 // (the only prover that supports BN254) does not have a hard degree bound. So, we can - // in-line the expression here. - Option::Some(KnownField::BN254) => sub_ext(beta, fingerprint_with_id_inter(id, tuple, alpha)), + // in-line the expression here. + Option::Some(KnownField::BN254) => false, _ => panic("Unexpected field!") }; + let folded = if materialize_folded { + let folded = from_array( + array::new(extension_field_size, + |i| std::prover::new_witness_col_at_stage("folded", 1)) + ); + constrain_eq_ext(folded, sub_ext(beta, fingerprint_with_id_inter(id, tuple, alpha))); + folded + } else { + sub_ext(beta, fingerprint_with_id_inter(id, tuple, alpha)) + }; let folded_next = next_ext(folded); let m_ext = from_base(multiplicity); let m_ext_next = next_ext(m_ext); - let acc = array::new(required_extension_size(), |i| std::prover::new_witness_col_at_stage("acc", 1)); - let acc_ext = fp2_from_array(acc); + let acc = array::new(extension_field_size, |i| std::prover::new_witness_col_at_stage("acc", 1)); + let acc_ext = from_array(acc); let next_acc = next_ext(acc_ext); let is_first: col = std::well_known::is_first; @@ -94,7 +98,7 @@ let bus_interaction: expr, expr[], expr, expr -> () = constr |id, tuple, multipl /// using extension field arithmetic. /// This is intended to be used as a hint in the extension field case; for the base case /// automatic witgen is smart enough to figure out the value of the accumulator. -let compute_next_z: expr, expr, expr[], expr, Fp2, Fp2, Fp2 -> fe[] = query |is_first, id, tuple, multiplicity, acc, alpha, beta| { +let compute_next_z: expr, expr, expr[], expr, Ext, Ext, Ext -> fe[] = query |is_first, id, tuple, multiplicity, acc, alpha, beta| { let m_next = eval(multiplicity'); let m_ext_next = from_base(m_next); diff --git a/std/protocols/fingerprint.asm b/std/protocols/fingerprint.asm index d39e6611e9..5b2274d80a 100644 --- a/std/protocols/fingerprint.asm +++ b/std/protocols/fingerprint.asm @@ -1,24 +1,23 @@ use std::array; use std::array::len; -use std::math::fp2::Fp2; -use std::math::fp2::add_ext; -use std::math::fp2::mul_ext; -use std::math::fp2::pow_ext; -use std::math::fp2::from_base; -use std::math::fp2::eval_ext; +use std::math::extension_field::Ext; +use std::math::extension_field::add_ext; +use std::math::extension_field::mul_ext; +use std::math::extension_field::from_base; +use std::math::extension_field::eval_ext; use std::check::assert; /// Maps [x_1, x_2, ..., x_n] to its Read-Solomon fingerprint, using a challenge alpha: $\sum_{i=1}^n alpha**{(n - i)} * x_i$ /// To generate an expression that computes the fingerprint, use `fingerprint_inter` instead. /// Note that alpha is passed as an expressions, so that it is only evaluated if needed (i.e., if len(expr_array) > 1). -let fingerprint: fe[], Fp2 -> Fp2 = query |expr_array, alpha| if array::len(expr_array) == 1 { +let fingerprint: fe[], Ext -> Ext = query |expr_array, alpha| if array::len(expr_array) == 1 { // No need to evaluate `alpha` (which would be removed by the optimizer). from_base(expr_array[0]) } else { fingerprint_impl(expr_array, eval_ext(alpha), len(expr_array)) }; -let fingerprint_impl: fe[], Fp2, int -> Fp2 = query |expr_array, alpha, l| if l == 1 { +let fingerprint_impl: fe[], Ext, int -> Ext = query |expr_array, alpha, l| if l == 1 { // Base case from_base(expr_array[0]) } else { @@ -30,7 +29,7 @@ let fingerprint_impl: fe[], Fp2, int -> Fp2 = query |expr_array, alpha, /// Like `fingerprint`, but "materializes" the intermediate results as intermediate columns. /// Inlining them would lead to an exponentially-sized expression. -let fingerprint_inter: expr[], Fp2 -> Fp2 = |expr_array, alpha| if len(expr_array) == 1 { +let fingerprint_inter: expr[], Ext -> Ext = |expr_array, alpha| if len(expr_array) == 1 { // Base case from_base(expr_array[0]) } else { @@ -38,35 +37,49 @@ let fingerprint_inter: expr[], Fp2 -> Fp2 = |expr_array, alpha| if l // Recursively compute the fingerprint as fingerprint(expr_array[:-1], alpha) * alpha + expr_array[-1] let intermediate_fingerprint = match fingerprint_inter(array::sub_array(expr_array, 0, len(expr_array) - 1), alpha) { - Fp2::Fp2(a0, a1) => { + Ext::Fp2(std::math::fp2::Fp2::Fp2(a0, a1)) => { let intermediate_fingerprint_0: inter = a0; let intermediate_fingerprint_1: inter = a1; - Fp2::Fp2(intermediate_fingerprint_0, intermediate_fingerprint_1) + Ext::Fp2(std::math::fp2::Fp2::Fp2(intermediate_fingerprint_0, intermediate_fingerprint_1)) + }, + Ext::Fp4(std::math::fp4::Fp4::Fp4(a0, a1, a2, a3)) => { + let intermediate_fingerprint_0: inter = a0; + let intermediate_fingerprint_1: inter = a1; + let intermediate_fingerprint_2: inter = a2; + let intermediate_fingerprint_3: inter = a3; + Ext::Fp4(std::math::fp4::Fp4::Fp4(intermediate_fingerprint_0, intermediate_fingerprint_1, intermediate_fingerprint_2, intermediate_fingerprint_3)) } }; add_ext(mul_ext(alpha, intermediate_fingerprint), from_base(expr_array[len(expr_array) - 1])) }; /// Maps [id, x_1, x_2, ..., x_n] to its Read-Solomon fingerprint, using a challenge alpha: $\sum_{i=1}^n alpha**{(n - i)} * x_i$ -let fingerprint_with_id: fe, fe[], Fp2 -> Fp2 = query |id, expr_array, alpha| fingerprint([id] + expr_array, alpha); +let fingerprint_with_id: fe, fe[], Ext -> Ext = query |id, expr_array, alpha| fingerprint([id] + expr_array, alpha); /// Maps [id, x_1, x_2, ..., x_n] to its Read-Solomon fingerprint, using a challenge alpha: $\sum_{i=1}^n alpha**{(n - i)} * x_i$ -let fingerprint_with_id_inter: expr, expr[], Fp2 -> Fp2 = |id, expr_array, alpha| fingerprint_inter([id] + expr_array, alpha); +let fingerprint_with_id_inter: expr, expr[], Ext -> Ext = |id, expr_array, alpha| fingerprint_inter([id] + expr_array, alpha); mod test { use super::fingerprint; use std::check::assert; - use std::math::fp2::Fp2; - use std::math::fp2::from_base; + use std::math::extension_field::Ext; + use std::math::extension_field::from_base; + use std::check::panic; /// Helper function to assert that the fingerprint of a tuple is equal to the expected value. let assert_fingerprint_equal: fe[], expr, fe -> () = query |tuple, challenge, expected| { let result = fingerprint(tuple, from_base(challenge)); match result { - Fp2::Fp2(actual, should_be_zero) => { - assert(should_be_zero == 0, || "Returned an extension field element"); + Ext::Fp2(std::math::fp2::Fp2::Fp2(actual, zero)) => { + assert(zero == 0, || "Returned an extension field element"); + assert(expected == actual, || "expected != actual"); + }, + Ext::Fp4(std::math::fp4::Fp4::Fp4(actual, zero1, zero2, zero3)) => { + assert(zero1 == 0, || "Returned an extension field element"); + assert(zero2 == 0, || "Returned an extension field element"); + assert(zero3 == 0, || "Returned an extension field element"); assert(expected == actual, || "expected != actual"); - } + }, } }; diff --git a/std/protocols/lookup.asm b/std/protocols/lookup.asm index 7bfdf5a087..6546b8ff13 100644 --- a/std/protocols/lookup.asm +++ b/std/protocols/lookup.asm @@ -5,20 +5,19 @@ use std::array::map; use std::check::assert; use std::check::panic; use std::constraints::to_phantom_lookup; -use std::math::fp2::Fp2; -use std::math::fp2::add_ext; -use std::math::fp2::sub_ext; -use std::math::fp2::mul_ext; -use std::math::fp2::unpack_ext; -use std::math::fp2::unpack_ext_array; -use std::math::fp2::next_ext; -use std::math::fp2::inv_ext; -use std::math::fp2::eval_ext; -use std::math::fp2::from_base; -use std::math::fp2::fp2_from_array; -use std::math::fp2::constrain_eq_ext; -use std::math::fp2::required_extension_size; -use std::math::fp2::needs_extension; +use std::math::extension_field::Ext; +use std::math::extension_field::add_ext; +use std::math::extension_field::sub_ext; +use std::math::extension_field::mul_ext; +use std::math::extension_field::unpack_ext_array; +use std::math::extension_field::next_ext; +use std::math::extension_field::inv_ext; +use std::math::extension_field::eval_ext; +use std::math::extension_field::from_base; +use std::math::extension_field::from_array; +use std::math::extension_field::constrain_eq_ext; +use std::math::extension_field::required_extension_size; +use std::math::extension_field::needs_extension; use std::protocols::fingerprint::fingerprint; use std::protocols::fingerprint::fingerprint_inter; use std::utils::unwrap_or_else; @@ -34,7 +33,7 @@ let unpack_lookup_constraint: Constr -> (expr, expr[], expr, expr[]) = |lookup_c }; /// Compute z' = z + 1/(beta-a_i) * lhs_selector - m_i/(beta-b_i) * rhs_selector, using extension field arithmetic -let compute_next_z: Fp2, Fp2, Fp2, Constr, expr -> fe[] = query |acc, alpha, beta, lookup_constraint, multiplicities| { +let compute_next_z: Ext, Ext, Ext, Constr, expr -> fe[] = query |acc, alpha, beta, lookup_constraint, multiplicities| { let (lhs_selector, lhs, rhs_selector, rhs) = unpack_lookup_constraint(lookup_constraint); let lhs_denom = sub_ext(eval_ext(beta), fingerprint(array::eval(lhs), alpha)); @@ -60,11 +59,13 @@ let compute_next_z: Fp2, Fp2, Fp2, Constr, expr -> fe[] = quer /// higher-stage witness columns. /// Use this function if the backend does not support lookup constraints natively. let lookup: Constr -> () = constr |lookup_constraint| { - std::check::assert(required_extension_size() <= 2, || "Invalid extension size"); + + let extension_field_size = required_extension_size(); + // Alpha is used to compress the LHS and RHS arrays. - let alpha = fp2_from_array(array::new(required_extension_size(), |i| challenge(0, i + 1))); + let alpha = from_array(array::new(extension_field_size, |i| challenge(0, i + 1))); // Beta is used to update the accumulator. - let beta = fp2_from_array(array::new(required_extension_size(), |i| challenge(0, i + 3))); + let beta = from_array(array::new(extension_field_size, |i| challenge(0, i + 1 + extension_field_size))); let (lhs_selector, lhs, rhs_selector, rhs) = unpack_lookup_constraint(lookup_constraint); @@ -73,8 +74,8 @@ let lookup: Constr -> () = constr |lookup_constraint| { let multiplicities; let m_ext = from_base(multiplicities); - let acc = array::new(required_extension_size(), |i| std::prover::new_witness_col_at_stage("acc", 1)); - let acc_ext = fp2_from_array(acc); + let acc = array::new(extension_field_size, |i| std::prover::new_witness_col_at_stage("acc", 1)); + let acc_ext = from_array(acc); let next_acc = next_ext(acc_ext); // Update rule: @@ -102,11 +103,9 @@ let lookup: Constr -> () = constr |lookup_constraint| { let is_first: col = std::well_known::is_first; - let (acc_1, acc_2) = unpack_ext(acc_ext); // First and last acc needs to be 0 // (because of wrapping, the acc[0] and acc[N] are the same) - is_first * acc_1 = 0; - is_first * acc_2 = 0; + array::new(array::len(acc), |i| is_first * acc[i] = 0); constrain_eq_ext(update_expr, from_base(0)); // Add an annotation for witness generation diff --git a/std/protocols/permutation.asm b/std/protocols/permutation.asm index 37eaa40213..d08af97931 100644 --- a/std/protocols/permutation.asm +++ b/std/protocols/permutation.asm @@ -1,20 +1,19 @@ use std::array::map; use std::check::assert; use std::check::panic; -use std::math::fp2::Fp2; -use std::math::fp2::add_ext; -use std::math::fp2::sub_ext; -use std::math::fp2::mul_ext; -use std::math::fp2::unpack_ext; -use std::math::fp2::unpack_ext_array; -use std::math::fp2::next_ext; -use std::math::fp2::inv_ext; -use std::math::fp2::eval_ext; -use std::math::fp2::from_base; -use std::math::fp2::fp2_from_array; -use std::math::fp2::constrain_eq_ext; -use std::math::fp2::required_extension_size; -use std::math::fp2::needs_extension; +use std::math::extension_field::Ext; +use std::math::extension_field::add_ext; +use std::math::extension_field::sub_ext; +use std::math::extension_field::mul_ext; +use std::math::extension_field::unpack_ext_array; +use std::math::extension_field::next_ext; +use std::math::extension_field::inv_ext; +use std::math::extension_field::eval_ext; +use std::math::extension_field::from_base; +use std::math::extension_field::from_array; +use std::math::extension_field::constrain_eq_ext; +use std::math::extension_field::required_extension_size; +use std::math::extension_field::needs_extension; use std::protocols::fingerprint::fingerprint; use std::protocols::fingerprint::fingerprint_inter; use std::prover::eval; @@ -34,14 +33,14 @@ let unpack_permutation_constraint: Constr -> (expr, expr[], expr, expr[]) = |per /// Takes a boolean selector (0/1) and a value, returns equivalent of `if selector { value } else { 1 }` /// Implemented as: selector * (value - 1) + 1 -let selected_or_one: T, Fp2 -> Fp2 = |selector, value| add_ext(mul_ext(from_base(selector), sub_ext(value, from_base(1))), from_base(1)); +let selected_or_one: T, Ext -> Ext = |selector, value| add_ext(mul_ext(from_base(selector), sub_ext(value, from_base(1))), from_base(1)); /// Compute acc' = acc * selected_or_one(sel_a, beta - a) / selected_or_one(sel_b, beta - b), /// using extension field arithmetic (where expressions for sel_a, a, sel_b, b are derived from /// the provided permutation constraint). /// This is intended to be used as a hint in the extension field case; for the base case /// automatic witgen is smart enough to figure out the value of the accumulator. -let compute_next_z: Fp2, Fp2, Fp2, Constr -> fe[] = query |acc, alpha, beta, permutation_constraint| { +let compute_next_z: Ext, Ext, Ext, Constr -> fe[] = query |acc, alpha, beta, permutation_constraint| { let (lhs_selector, lhs, rhs_selector, rhs) = unpack_permutation_constraint(permutation_constraint); @@ -75,11 +74,13 @@ let compute_next_z: Fp2, Fp2, Fp2, Constr -> fe[] = query |acc /// accumulator is the same as the first one, because of wrapping. /// For small fields, this computation should happen in the extension field. let permutation: Constr -> () = constr |permutation_constraint| { - std::check::assert(required_extension_size() <= 2, || "Invalid extension size"); + + let extension_field_size = required_extension_size(); + // Alpha is used to compress the LHS and RHS arrays - let alpha = fp2_from_array(std::array::new(required_extension_size(), |i| challenge(0, i + 1))); + let alpha = from_array(std::array::new(extension_field_size, |i| challenge(0, i + 1))); // Beta is used to update the accumulator - let beta = fp2_from_array(std::array::new(required_extension_size(), |i| challenge(0, i + 3))); + let beta = from_array(std::array::new(extension_field_size, |i| challenge(0, i + 1 + extension_field_size))); let (lhs_selector, lhs, rhs_selector, rhs) = unpack_permutation_constraint(permutation_constraint); @@ -89,8 +90,8 @@ let permutation: Constr -> () = constr |permutation_constraint| { let lhs_folded = selected_or_one(lhs_selector, sub_ext(beta, fingerprint_inter(lhs, alpha))); let rhs_folded = selected_or_one(rhs_selector, sub_ext(beta, fingerprint_inter(rhs, alpha))); - let acc = std::array::new(required_extension_size(), |i| std::prover::new_witness_col_at_stage("acc", 1)); - let acc_ext = fp2_from_array(acc); + let acc = std::array::new(extension_field_size, |i| std::prover::new_witness_col_at_stage("acc", 1)); + let acc_ext = from_array(acc); let next_acc = next_ext(acc_ext); // Update rule: @@ -103,12 +104,10 @@ let permutation: Constr -> () = constr |permutation_constraint| { let is_first: col = std::well_known::is_first; - let (acc_1, acc_2) = unpack_ext(acc_ext); - // First and last acc needs to be 1 // (because of wrapping, the acc[0] and acc[N] are the same) - is_first * (acc_1 - 1) = 0; - is_first * acc_2 = 0; + is_first * (acc[0] - 1) = 0; + array::new(array::len(acc) - 1, |i| is_first * acc[i + 1] = 0); constrain_eq_ext(update_expr, from_base(0)); // Add an annotation for witness generation diff --git a/test_data/asm/block_to_block_with_bus.asm b/test_data/asm/block_to_block_with_bus.asm index ed78aefca3..f7423a6e9c 100644 --- a/test_data/asm/block_to_block_with_bus.asm +++ b/test_data/asm/block_to_block_with_bus.asm @@ -1,7 +1,5 @@ use std::protocols::bus::bus_receive; use std::protocols::bus::bus_send; -use std::math::fp2::Fp2; -use std::math::fp2::from_base; use std::prelude::Query; use std::prover::challenge; diff --git a/test_data/asm/block_to_block_with_bus_different_sizes.asm b/test_data/asm/block_to_block_with_bus_different_sizes.asm index 05b001a61c..bac8ca29f4 100644 --- a/test_data/asm/block_to_block_with_bus_different_sizes.asm +++ b/test_data/asm/block_to_block_with_bus_different_sizes.asm @@ -1,7 +1,5 @@ use std::protocols::bus::bus_receive; use std::protocols::bus::bus_send; -use std::math::fp2::Fp2; -use std::math::fp2::from_base; use std::prelude::Query; use std::prover::challenge; diff --git a/test_data/std/fingerprint_test.asm b/test_data/std/fingerprint_test.asm index bb15305d44..fc4ea96529 100644 --- a/test_data/std/fingerprint_test.asm +++ b/test_data/std/fingerprint_test.asm @@ -1,8 +1,7 @@ -use std::math::fp2::from_base; use std::math::fp2::Fp2; -use std::math::fp2::eval_ext; -use std::math::fp2::unpack_ext_array; -use std::math::fp2::constrain_eq_ext; +use std::math::extension_field::Ext; +use std::math::extension_field::unpack_ext_array; +use std::math::extension_field::constrain_eq_ext; use std::prover::challenge; use std::protocols::fingerprint::fingerprint; use std::protocols::fingerprint::fingerprint_inter; @@ -21,8 +20,8 @@ machine Main with degree: 2048 { // Add `fingerprint_value` witness columns and constrain them using `fingerprint_inter` col witness stage(1) fingerprint_value0, fingerprint_value1; - let fingerprint_value = Fp2::Fp2(fingerprint_value0, fingerprint_value1); - let alpha = Fp2::Fp2(challenge(0, 0), challenge(0, 1)); + let fingerprint_value = Ext::Fp2(Fp2::Fp2(fingerprint_value0, fingerprint_value1)); + let alpha = Ext::Fp2(Fp2::Fp2(challenge(0, 0), challenge(0, 1))); constrain_eq_ext(fingerprint_inter(tuple, alpha), fingerprint_value); // Add `fingerprint_value_hint` witness columns and compute the fingerprint in a hint using `fingerprint`