Skip to content

Commit

Permalink
Implement bus constraints for small fields (#2334)
Browse files Browse the repository at this point in the history
This PR adds a `std::math::extension_field` module to our PIL standard
library, which abstracts over which extension field is used. I changed
all protocols - most importantly the bus - to use the new abstraction.
As a result, we can now have bus constraints on smaller fields like
BabyBear, e.g.:
```
$ cargo run pil test_data/asm/block_to_block_with_bus.asm -o output -f --field bb
```

TODOs left to future PRs:
- Add M31 to `std::field::KnownField`
- Implement witness generation for Bus with Fp4
  • Loading branch information
georgwiese authored Jan 14, 2025
1 parent 9f0285d commit 315592b
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 135 deletions.
6 changes: 6 additions & 0 deletions pipeline/tests/powdr_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ fn permutation_via_challenges() {
let pipeline = make_simple_prepared_pipeline::<GoldilocksField>(f, LinkerMode::Bus);
test_mock_backend(pipeline.clone());
test_plonky3_pipeline(pipeline);

let pipeline = make_simple_prepared_pipeline::<BabyBearField>(f, LinkerMode::Bus);
test_plonky3_pipeline(pipeline);
}

#[test]
Expand All @@ -189,6 +192,9 @@ fn lookup_via_challenges() {
let pipeline = make_simple_prepared_pipeline::<GoldilocksField>(f, LinkerMode::Bus);
test_mock_backend(pipeline.clone());
test_plonky3_pipeline(pipeline);

let pipeline = make_simple_prepared_pipeline::<BabyBearField>(f, LinkerMode::Bus);
test_plonky3_pipeline(pipeline);
}

#[test]
Expand Down
82 changes: 82 additions & 0 deletions std/math/extension_field.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use std::field::known_field;
use std::field::KnownField;

use std::array::len;
use std::check::panic;

/// Whether we need to operate on an extension field (because the base field is too small).
let needs_extension: -> bool = || required_extension_size() > 1;

/// How many field elements / field extensions are recommended for the current base field.
let required_extension_size: -> int = || match known_field() {
Option::Some(KnownField::Goldilocks) => 2,
Option::Some(KnownField::BN254) => 1,
Option::Some(KnownField::BabyBear) => 4,
Option::Some(KnownField::KoalaBear) => 4,
None => panic("The permutation/lookup argument is not implemented for the current field!")
};

/// Wrapper around Fp2 and Fp4 to abstract which extension field is used.
/// Once PIL supports traits, we can remove this type and the functions below.
enum Ext<T> {
Fp2(std::math::fp2::Fp2<T>),
Fp4(std::math::fp4::Fp4<T>)
}

let<T: Add> add_ext: Ext<T>, Ext<T> -> Ext<T> = |a, b| match (a, b) {
(Ext::Fp2(aa), Ext::Fp2(bb)) => Ext::Fp2(std::math::fp2::add_ext(aa, bb)),
(Ext::Fp4(aa), Ext::Fp4(bb)) => Ext::Fp4(std::math::fp4::add_ext(aa, bb)),
_ => panic("Operands have different types")
};

let<T: Sub> sub_ext: Ext<T>, Ext<T> -> Ext<T> = |a, b| match (a, b) {
(Ext::Fp2(aa), Ext::Fp2(bb)) => Ext::Fp2(std::math::fp2::sub_ext(aa, bb)),
(Ext::Fp4(aa), Ext::Fp4(bb)) => Ext::Fp4(std::math::fp4::sub_ext(aa, bb)),
_ => panic("Operands have different types")
};

let<T: Add + FromLiteral + Mul> mul_ext: Ext<T>, Ext<T> -> Ext<T> = |a, b| match (a, b) {
(Ext::Fp2(aa), Ext::Fp2(bb)) => Ext::Fp2(std::math::fp2::mul_ext(aa, bb)),
(Ext::Fp4(aa), Ext::Fp4(bb)) => Ext::Fp4(std::math::fp4::mul_ext(aa, bb)),
_ => panic("Operands have different types")
};

let eval_ext: Ext<expr> -> Ext<fe> = query |a| match a {
Ext::Fp2(aa) => Ext::Fp2(std::math::fp2::eval_ext(aa)),
Ext::Fp4(aa) => Ext::Fp4(std::math::fp4::eval_ext(aa)),
};

let inv_ext: Ext<fe> -> Ext<fe> = query |a| match a {
Ext::Fp2(aa) => Ext::Fp2(std::math::fp2::inv_ext(aa)),
Ext::Fp4(aa) => Ext::Fp4(std::math::fp4::inv_ext(aa)),
};

let<T> unpack_ext_array: Ext<T> -> T[] = |a| match a {
Ext::Fp2(aa) => std::math::fp2::unpack_ext_array(aa),
Ext::Fp4(aa) => std::math::fp4::unpack_ext_array(aa),
};

let next_ext: Ext<expr> -> Ext<expr> = |a| match a {
Ext::Fp2(aa) => Ext::Fp2(std::math::fp2::next_ext(aa)),
Ext::Fp4(aa) => Ext::Fp4(std::math::fp4::next_ext(aa)),
};

let<T: FromLiteral> from_base: T -> Ext<T> = |x| match required_extension_size() {
1 => Ext::Fp2(std::math::fp2::from_base(x)),
2 => Ext::Fp2(std::math::fp2::from_base(x)),
4 => Ext::Fp4(std::math::fp4::from_base(x)),
_ => panic("Expected 1, 2, or 4")
};

let<T: FromLiteral> from_array: T[] -> Ext<T> = |arr| match len(arr) {
1 => Ext::Fp2(std::math::fp2::from_array(arr)),
2 => Ext::Fp2(std::math::fp2::from_array(arr)),
4 => Ext::Fp4(std::math::fp4::Fp4::Fp4(arr[0], arr[1], arr[2], arr[3])),
_ => panic("Expected 1, 2, or 4")
};

let constrain_eq_ext: Ext<expr>, Ext<expr> -> Constr[] = |a, b| match (a, b) {
(Ext::Fp2(aa), Ext::Fp2(bb)) => std::math::fp2::constrain_eq_ext(aa, bb),
(Ext::Fp4(aa), Ext::Fp4(bb)) => std::math::fp4::constrain_eq_ext(aa, bb),
_ => panic("Operands have different types")
};
28 changes: 6 additions & 22 deletions std/math/fp2.asm
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::convert::expr;
use std::field::known_field;
use std::field::KnownField;
use std::math::ff::inv_field;
use std::math::extension_field::needs_extension;
use std::prover::eval;

/// Corresponding Sage code to test irreducibility
Expand Down Expand Up @@ -139,31 +140,14 @@ let<T> unpack_ext_array: Fp2<T> -> T[] = |a| match a {
Fp2::Fp2(a0, a1) => [a0, a1]
};

/// Whether we need to operate on the F_{p^2} extension field (because the current field is too small).
let needs_extension: -> bool = || required_extension_size() > 1;

/// How many field elements / field extensions are recommended for the current base field.
let required_extension_size: -> int = || match known_field() {
Option::Some(KnownField::Goldilocks) => 2,
Option::Some(KnownField::BN254) => 1,
None => panic("The permutation/lookup argument is not implemented for the current field!")
};

/// Matches whether the length of a given array is correct to operate on the extension field
let is_extension = |arr| match len(arr) {
1 => false,
2 => true,
_ => panic("Expected 1 or 2 accumulator columns!")
};

/// Constructs an extension field element `a0 + a1 * X` from either `[a0, a1]` or `[a0]` (setting `a1`to zero in that case)
let fp2_from_array = |arr| {
if is_extension(arr) {
Fp2::Fp2(arr[0], arr[1])
} else {
let<T: FromLiteral> from_array: T[] -> Fp2<T> = |arr| match len(arr) {
1 => {
let _ = assert(!needs_extension(), || "The field is too small and needs to move to the extension field. Pass two elements instead!");
from_base(arr[0])
}
},
2 => Fp2::Fp2(arr[0], arr[1]),
_ => panic("Expected array of length 1 or 2")
};

mod test {
Expand Down
3 changes: 2 additions & 1 deletion std/math/mod.asm
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod ff;
mod fp2;
mod fp4;
mod fp4;
mod extension_field;
76 changes: 40 additions & 36 deletions std/protocols/bus.asm
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
use std::check::assert;
use std::array;
use std::math::fp2::Fp2;
use std::math::fp2::add_ext;
use std::math::fp2::sub_ext;
use std::math::fp2::mul_ext;
use std::math::fp2::inv_ext;
use std::math::fp2::eval_ext;
use std::math::fp2::unpack_ext_array;
use std::math::fp2::next_ext;
use std::math::fp2::from_base;
use std::math::fp2::needs_extension;
use std::math::fp2::fp2_from_array;
use std::math::fp2::constrain_eq_ext;
use std::math::extension_field::Ext;
use std::math::extension_field::add_ext;
use std::math::extension_field::sub_ext;
use std::math::extension_field::mul_ext;
use std::math::extension_field::inv_ext;
use std::math::extension_field::eval_ext;
use std::math::extension_field::unpack_ext_array;
use std::math::extension_field::next_ext;
use std::math::extension_field::from_base;
use std::math::extension_field::from_array;
use std::math::extension_field::constrain_eq_ext;
use std::protocols::fingerprint::fingerprint_with_id;
use std::protocols::fingerprint::fingerprint_with_id_inter;
use std::math::fp2::required_extension_size;
use std::math::extension_field::required_extension_size;
use std::prover::eval;
use std::field::known_field;
use std::field::KnownField;
Expand All @@ -32,48 +31,53 @@ use std::check::panic;
/// - latch: a binary expression which indicates where the multiplicity can be non-zero.
let bus_interaction: expr, expr[], expr, expr -> () = constr |id, tuple, multiplicity, latch| {

std::check::assert(required_extension_size() <= 2, || "Invalid extension size");

// Add phantom bus interaction
let full_tuple = [id] + tuple;
Constr::PhantomBusInteraction(multiplicity, full_tuple, latch);

let extension_field_size = required_extension_size();

// Alpha is used to compress the LHS and RHS arrays.
let alpha = fp2_from_array(array::new(required_extension_size(), |i| challenge(0, i + 1)));
let alpha = from_array(array::new(extension_field_size, |i| challenge(0, i + 1)));
// Beta is used to update the accumulator.
let beta = fp2_from_array(array::new(required_extension_size(), |i| challenge(0, i + 3)));
let beta = from_array(array::new(extension_field_size, |i| challenge(0, i + 1 + extension_field_size)));

// Implemented as: folded = (beta - fingerprint(id, tuple...));
let folded = match known_field() {
Option::Some(KnownField::Goldilocks) => {
// Materialized as a witness column for two reasons:
// - It makes sure the constraint degree is independent of the input tuple.
// - We can access folded', even if the tuple contains next references.
// Note that if all expressions are degree-1 and there is no next reference,
// this is wasteful, but we can't check that here.
let folded = fp2_from_array(
array::new(required_extension_size(),
|i| std::prover::new_witness_col_at_stage("folded", 1))
);
constrain_eq_ext(folded, sub_ext(beta, fingerprint_with_id_inter(id, tuple, alpha)));
folded
},
let materialize_folded = match known_field() {
// Materialized as a witness column for two reasons:
// - It makes sure the constraint degree is independent of the input tuple.
// - We can access folded', even if the tuple contains next references.
// Note that if all expressions are degree-1 and there is no next reference,
// this is wasteful, but we can't check that here.
Option::Some(KnownField::Goldilocks) => true,
Option::Some(KnownField::BabyBear) => true,
Option::Some(KnownField::KoalaBear) => true,
// The case above triggers our hand-written witness generation, but on Bn254, we'd not be
// on the extension field and use the automatic witness generation.
// However, it does not work with a materialized folded tuple. At the same time, Halo2
// (the only prover that supports BN254) does not have a hard degree bound. So, we can
// in-line the expression here.
Option::Some(KnownField::BN254) => sub_ext(beta, fingerprint_with_id_inter(id, tuple, alpha)),
// in-line the expression here.
Option::Some(KnownField::BN254) => false,
_ => panic("Unexpected field!")
};
let folded = if materialize_folded {
let folded = from_array(
array::new(extension_field_size,
|i| std::prover::new_witness_col_at_stage("folded", 1))
);
constrain_eq_ext(folded, sub_ext(beta, fingerprint_with_id_inter(id, tuple, alpha)));
folded
} else {
sub_ext(beta, fingerprint_with_id_inter(id, tuple, alpha))
};

let folded_next = next_ext(folded);

let m_ext = from_base(multiplicity);
let m_ext_next = next_ext(m_ext);

let acc = array::new(required_extension_size(), |i| std::prover::new_witness_col_at_stage("acc", 1));
let acc_ext = fp2_from_array(acc);
let acc = array::new(extension_field_size, |i| std::prover::new_witness_col_at_stage("acc", 1));
let acc_ext = from_array(acc);
let next_acc = next_ext(acc_ext);

let is_first: col = std::well_known::is_first;
Expand All @@ -94,7 +98,7 @@ let bus_interaction: expr, expr[], expr, expr -> () = constr |id, tuple, multipl
/// using extension field arithmetic.
/// This is intended to be used as a hint in the extension field case; for the base case
/// automatic witgen is smart enough to figure out the value of the accumulator.
let compute_next_z: expr, expr, expr[], expr, Fp2<expr>, Fp2<expr>, Fp2<expr> -> fe[] = query |is_first, id, tuple, multiplicity, acc, alpha, beta| {
let compute_next_z: expr, expr, expr[], expr, Ext<expr>, Ext<expr>, Ext<expr> -> fe[] = query |is_first, id, tuple, multiplicity, acc, alpha, beta| {

let m_next = eval(multiplicity');
let m_ext_next = from_base(m_next);
Expand Down
49 changes: 31 additions & 18 deletions std/protocols/fingerprint.asm
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
use std::array;
use std::array::len;
use std::math::fp2::Fp2;
use std::math::fp2::add_ext;
use std::math::fp2::mul_ext;
use std::math::fp2::pow_ext;
use std::math::fp2::from_base;
use std::math::fp2::eval_ext;
use std::math::extension_field::Ext;
use std::math::extension_field::add_ext;
use std::math::extension_field::mul_ext;
use std::math::extension_field::from_base;
use std::math::extension_field::eval_ext;
use std::check::assert;

/// Maps [x_1, x_2, ..., x_n] to its Read-Solomon fingerprint, using a challenge alpha: $\sum_{i=1}^n alpha**{(n - i)} * x_i$
/// To generate an expression that computes the fingerprint, use `fingerprint_inter` instead.
/// Note that alpha is passed as an expressions, so that it is only evaluated if needed (i.e., if len(expr_array) > 1).
let fingerprint: fe[], Fp2<expr> -> Fp2<fe> = query |expr_array, alpha| if array::len(expr_array) == 1 {
let fingerprint: fe[], Ext<expr> -> Ext<fe> = query |expr_array, alpha| if array::len(expr_array) == 1 {
// No need to evaluate `alpha` (which would be removed by the optimizer).
from_base(expr_array[0])
} else {
fingerprint_impl(expr_array, eval_ext(alpha), len(expr_array))
};

let fingerprint_impl: fe[], Fp2<fe>, int -> Fp2<fe> = query |expr_array, alpha, l| if l == 1 {
let fingerprint_impl: fe[], Ext<fe>, int -> Ext<fe> = query |expr_array, alpha, l| if l == 1 {
// Base case
from_base(expr_array[0])
} else {
Expand All @@ -30,43 +29,57 @@ let fingerprint_impl: fe[], Fp2<fe>, int -> Fp2<fe> = query |expr_array, alpha,

/// Like `fingerprint`, but "materializes" the intermediate results as intermediate columns.
/// Inlining them would lead to an exponentially-sized expression.
let fingerprint_inter: expr[], Fp2<expr> -> Fp2<expr> = |expr_array, alpha| if len(expr_array) == 1 {
let fingerprint_inter: expr[], Ext<expr> -> Ext<expr> = |expr_array, alpha| if len(expr_array) == 1 {
// Base case
from_base(expr_array[0])
} else {
assert(len(expr_array) > 1, || "fingerprint requires at least one element");

// Recursively compute the fingerprint as fingerprint(expr_array[:-1], alpha) * alpha + expr_array[-1]
let intermediate_fingerprint = match fingerprint_inter(array::sub_array(expr_array, 0, len(expr_array) - 1), alpha) {
Fp2::Fp2(a0, a1) => {
Ext::Fp2(std::math::fp2::Fp2::Fp2(a0, a1)) => {
let intermediate_fingerprint_0: inter = a0;
let intermediate_fingerprint_1: inter = a1;
Fp2::Fp2(intermediate_fingerprint_0, intermediate_fingerprint_1)
Ext::Fp2(std::math::fp2::Fp2::Fp2(intermediate_fingerprint_0, intermediate_fingerprint_1))
},
Ext::Fp4(std::math::fp4::Fp4::Fp4(a0, a1, a2, a3)) => {
let intermediate_fingerprint_0: inter = a0;
let intermediate_fingerprint_1: inter = a1;
let intermediate_fingerprint_2: inter = a2;
let intermediate_fingerprint_3: inter = a3;
Ext::Fp4(std::math::fp4::Fp4::Fp4(intermediate_fingerprint_0, intermediate_fingerprint_1, intermediate_fingerprint_2, intermediate_fingerprint_3))
}
};
add_ext(mul_ext(alpha, intermediate_fingerprint), from_base(expr_array[len(expr_array) - 1]))
};

/// Maps [id, x_1, x_2, ..., x_n] to its Read-Solomon fingerprint, using a challenge alpha: $\sum_{i=1}^n alpha**{(n - i)} * x_i$
let fingerprint_with_id: fe, fe[], Fp2<expr> -> Fp2<fe> = query |id, expr_array, alpha| fingerprint([id] + expr_array, alpha);
let fingerprint_with_id: fe, fe[], Ext<expr> -> Ext<fe> = query |id, expr_array, alpha| fingerprint([id] + expr_array, alpha);

/// Maps [id, x_1, x_2, ..., x_n] to its Read-Solomon fingerprint, using a challenge alpha: $\sum_{i=1}^n alpha**{(n - i)} * x_i$
let fingerprint_with_id_inter: expr, expr[], Fp2<expr> -> Fp2<expr> = |id, expr_array, alpha| fingerprint_inter([id] + expr_array, alpha);
let fingerprint_with_id_inter: expr, expr[], Ext<expr> -> Ext<expr> = |id, expr_array, alpha| fingerprint_inter([id] + expr_array, alpha);

mod test {
use super::fingerprint;
use std::check::assert;
use std::math::fp2::Fp2;
use std::math::fp2::from_base;
use std::math::extension_field::Ext;
use std::math::extension_field::from_base;
use std::check::panic;

/// Helper function to assert that the fingerprint of a tuple is equal to the expected value.
let assert_fingerprint_equal: fe[], expr, fe -> () = query |tuple, challenge, expected| {
let result = fingerprint(tuple, from_base(challenge));
match result {
Fp2::Fp2(actual, should_be_zero) => {
assert(should_be_zero == 0, || "Returned an extension field element");
Ext::Fp2(std::math::fp2::Fp2::Fp2(actual, zero)) => {
assert(zero == 0, || "Returned an extension field element");
assert(expected == actual, || "expected != actual");
},
Ext::Fp4(std::math::fp4::Fp4::Fp4(actual, zero1, zero2, zero3)) => {
assert(zero1 == 0, || "Returned an extension field element");
assert(zero2 == 0, || "Returned an extension field element");
assert(zero3 == 0, || "Returned an extension field element");
assert(expected == actual, || "expected != actual");
}
},
}
};

Expand Down
Loading

0 comments on commit 315592b

Please sign in to comment.