Skip to content

Commit

Permalink
Rust (+ Wasm) MultiPairingCheck calldata builder (keep-starknet-stran…
Browse files Browse the repository at this point in the history
  • Loading branch information
raugfer authored Oct 3, 2024
1 parent 47efa22 commit 1ed6697
Show file tree
Hide file tree
Showing 24 changed files with 4,891 additions and 124 deletions.
21 changes: 20 additions & 1 deletion hydra/garaga/starknet/tests_and_calldata_generators/mpcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from functools import lru_cache

from garaga import garaga_rs
from garaga import modulo_circuit_structs as structs
from garaga.algebra import Polynomial, PyFelt
from garaga.definitions import CurveID, G1G2Pair, get_base_field, get_irreducible_poly
Expand Down Expand Up @@ -325,7 +326,25 @@ def to_cairo_1_test(self):
"""
return code

def serialize_to_calldata(self) -> list[int]:
def _serialize_to_calldata_rust(self) -> list[int]:
return garaga_rs.mpc_calldata_builder(
self.curve_id.value,
[element.value for pair in self.pairs for element in pair.to_pyfelt_list()],
self.n_fixed_g2,
(
[element.value for element in self.public_pair.to_pyfelt_list()]
if self.public_pair is not None
else []
),
)

def serialize_to_calldata(
self,
use_rust=False,
) -> list[int]:
if use_rust:
return self._serialize_to_calldata_rust()

mpcheck_hint, small_Q = self.build_mpcheck_hint()

call_data: list[int] = []
Expand Down
35 changes: 34 additions & 1 deletion tests/hydra/starknet/test_calldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,47 @@
import pytest

from garaga.definitions import CURVES, CurveID, G1Point
from garaga.precompiled_circuits.multi_pairing_check import get_pairing_check_input
from garaga.starknet.tests_and_calldata_generators.mpcheck import MPCheckCalldataBuilder
from garaga.starknet.tests_and_calldata_generators.msm import MSMCalldataBuilder

# Define the curves to be tested
curves = list(CurveID)


@pytest.mark.parametrize("curve_id", [CurveID.BN254, CurveID.BLS12_381])
@pytest.mark.parametrize("mpc_size", [2, 3])
@pytest.mark.parametrize("n_fixed_g2", [2])
@pytest.mark.parametrize("include_m", [True, False])
def test_mpc_calldata_builder(
curve_id,
mpc_size,
n_fixed_g2,
include_m,
):
pairs, public_pair = get_pairing_check_input(
curve_id=curve_id,
n_pairs=mpc_size,
include_m=include_m,
return_pairs=True,
)

mpc = MPCheckCalldataBuilder(
curve_id=curve_id,
pairs=pairs,
n_fixed_g2=n_fixed_g2,
public_pair=public_pair,
)

calldata1 = mpc.serialize_to_calldata(use_rust=False)

calldata2 = mpc.serialize_to_calldata(use_rust=True)

assert calldata1 == calldata2


@pytest.mark.parametrize("curve_id", curves)
@pytest.mark.parametrize("msm_size", range(1, 2))
@pytest.mark.parametrize("msm_size", [1, 2])
@pytest.mark.parametrize("include_digits_decomposition", [True, False])
@pytest.mark.parametrize("include_points_and_scalars", [True, False])
@pytest.mark.parametrize("serialize_as_pure_felt252_array", [True, False])
Expand Down
Loading

0 comments on commit 1ed6697

Please sign in to comment.