Skip to content

Commit

Permalink
perf: batch LHS of ECIP check. (keep-starknet-strange#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime authored Dec 6, 2024
1 parent 8023ef6 commit b44d00d
Show file tree
Hide file tree
Showing 27 changed files with 9,936 additions and 24,200 deletions.
14 changes: 9 additions & 5 deletions hydra/garaga/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,12 +1043,16 @@ def degrees_infos(self) -> dict[str, dict[str, int]]:
"b": self.b.degrees_infos(),
}

def validate_degrees(self, msm_size: int) -> bool:
def validate_degrees(self, msm_size: int, batched: bool = True) -> bool:
degrees = self.degrees_infos()
assert degrees["a"]["numerator"] <= msm_size + 1
assert degrees["a"]["denominator"] <= msm_size + 2
assert degrees["b"]["numerator"] <= msm_size + 2
assert degrees["b"]["denominator"] <= msm_size + 5
if batched:
extra = 2
else:
extra = 0
assert degrees["a"]["numerator"] <= msm_size + 1 + extra
assert degrees["a"]["denominator"] <= msm_size + 2 + extra
assert degrees["b"]["numerator"] <= msm_size + 2 + extra
assert degrees["b"]["denominator"] <= msm_size + 5 + extra
return True

def print_as_sage_poly(self, var: str = "x", as_hex: bool = False) -> str:
Expand Down
26 changes: 26 additions & 0 deletions hydra/garaga/hints/ecip.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,32 @@ def print_ff(ff: FF):
return string


def n_points_from_n_coeffs(n_coeffs: int, batched: bool) -> int:
if batched:
extra = 4 * 2
else:
extra = 0

# n_coeffs = 10 + 4n_points => 4n_points = n_coeffs - 10
assert n_coeffs >= 10 + extra
assert (n_coeffs - 10 - extra) % 4 == 0
return (n_coeffs - 10 - extra) // 4


def n_coeffs_from_n_points(n_points: int, batched: bool) -> tuple[int, int, int, int]:
if batched:
extra = 2
else:
extra = 0

return (
1 + n_points + extra,
1 + n_points + 1 + extra,
1 + n_points + 1 + extra,
1 + n_points + 4 + extra,
)


if __name__ == "__main__":
import random

Expand Down
26 changes: 17 additions & 9 deletions hydra/garaga/hints/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def fill_uint256(x: int, ids: object):


def padd_function_felt(
f: FunctionFelt, n: int, py_felt: bool = False
f: FunctionFelt, n: int, py_felt: bool = False, batched: bool = False
) -> tuple[list[int], list[int], list[int], list[int]]:
a_num = f.a.numerator.get_coeffs() if py_felt else f.a.numerator.get_value_coeffs()
a_den = (
Expand All @@ -298,15 +298,23 @@ def padd_function_felt(
b_den = (
f.b.denominator.get_coeffs() if py_felt else f.b.denominator.get_value_coeffs()
)
assert len(a_num) <= n + 1
assert len(a_den) <= n + 2
assert len(b_num) <= n + 2
assert len(b_den) <= n + 5
assert len(a_num) <= n + 1 + (
2 if batched else 0
), f"a_num has {len(a_num)} limbs, expected at most {n + 1 + (2 if batched else 0)}"
assert len(a_den) <= n + 2 + (
2 if batched else 0
), f"a_den has {len(a_den)} limbs, expected at most {n + 2 + (2 if batched else 0)}"
assert len(b_num) <= n + 2 + (
2 if batched else 0
), f"b_num has {len(b_num)} limbs, expected at most {n + 2 + (2 if batched else 0)}"
assert len(b_den) <= n + 5 + (
2 if batched else 0
), f"b_den has {len(b_den)} limbs, expected at most {n + 5 + (2 if batched else 0)}"
zero = [f.a.numerator.field.zero()] if py_felt else [0]
a_num = a_num + zero * (n + 1 - len(a_num))
a_den = a_den + zero * (n + 2 - len(a_den))
b_num = b_num + zero * (n + 2 - len(b_num))
b_den = b_den + zero * (n + 5 - len(b_den))
a_num = a_num + zero * (n + 1 + (2 if batched else 0) - len(a_num))
a_den = a_den + zero * (n + 2 + (2 if batched else 0) - len(a_den))
b_num = b_num + zero * (n + 2 + (2 if batched else 0) - len(b_num))
b_den = b_den + zero * (n + 5 + (2 if batched else 0) - len(b_den))
return (a_num, a_den, b_num, b_den)


Expand Down
4 changes: 2 additions & 2 deletions hydra/garaga/modulo_circuit_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,10 +503,10 @@ def struct_name(self) -> str:

@staticmethod
def from_FunctionFelt(
name: str, f: FunctionFelt, msm_size: int
name: str, f: FunctionFelt, msm_size: int, batched: bool = False
) -> "FunctionFeltCircuit":
_a_num, _a_den, _b_num, _b_den = io.padd_function_felt(
f, msm_size, py_felt=True
f, msm_size, py_felt=True, batched=batched
)
return FunctionFeltCircuit(
name=name,
Expand Down
7 changes: 5 additions & 2 deletions hydra/garaga/precompiled_circuits/all_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,15 @@ class CircuitID(Enum):
},
CircuitID.EVAL_FUNCTION_CHALLENGE_DUPL: {
"class": EvalFunctionChallengeDuplCircuit,
"params": [{"n_points": k} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],
"params": [
{"n_points": k, "batched": True} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
]
+ [{"n_points": k, "batched": False} for k in [1, 2]],
"filename": "ec",
},
CircuitID.INIT_FUNCTION_CHALLENGE_DUPL: {
"class": InitFunctionChallengeDuplCircuit,
"params": [{"n_points": k} for k in [11]],
"params": [{"n_points": k, "batched": True} for k in [11]],
"filename": "ec",
},
CircuitID.ACC_FUNCTION_CHALLENGE_DUPL: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import garaga.modulo_circuit_structs as structs
from garaga.definitions import CURVES, CurveID, G1Point, G2Point
from garaga.hints import neg_3
from garaga.hints.ecip import slope_intercept
from garaga.hints.ecip import (
n_coeffs_from_n_points,
n_points_from_n_coeffs,
slope_intercept,
)
from garaga.modulo_circuit import WriteOps
from garaga.modulo_circuit_structs import G1PointCircuit, G2PointCircuit, u384
from garaga.precompiled_circuits.compilable_circuits.base import (
Expand Down Expand Up @@ -398,26 +402,19 @@ def __init__(
n_points: int = 1,
auto_run: bool = True,
compilation_mode: int = 0,
batched: bool = False,
generic_circuit: bool = True,
) -> None:
self.n_points = n_points
self.batched = batched
self.generic_circuit = generic_circuit
super().__init__(
name=f"eval_fn_challenge_dupl_{n_points}P",
name=f"eval_fn_challenge_dupl_{n_points}P" + ("_rlc" if batched else ""),
curve_id=curve_id,
auto_run=auto_run,
compilation_mode=compilation_mode,
)

@staticmethod
def _n_coeffs_from_n_points(n_points: int) -> tuple[int, int, int, int]:
return (1 + n_points, 1 + n_points + 1, 1 + n_points + 1, 1 + n_points + 4)

@staticmethod
def _n_points_from_n_coeffs(n_coeffs: int) -> int:
# n_coeffs = 10 + 4n_points => 4n_points = n_coeffs - 10
assert n_coeffs >= 10
assert (n_coeffs - 10) % 4 == 0
return (n_coeffs - 10) // 4

def build_input(self) -> list[PyFelt]:
input = []
circuit = SlopeInterceptSamePointCircuit(self.curve_id, auto_run=False)
Expand All @@ -426,14 +423,17 @@ def build_input(self) -> list[PyFelt]:
[xA, _yA, _A]
).output
input.extend([xA0.felt, _yA.felt, xA2.felt, yA2.felt, coeff0.felt, coeff2.felt])
n_coeffs = self._n_coeffs_from_n_points(self.n_points)
n_coeffs = n_coeffs_from_n_points(self.n_points, self.batched)
for _ in range(sum(n_coeffs)):
input.append(self.field(randint(0, CURVES[self.curve_id].p - 1)))
return input

def _run_circuit_inner(self, input: list[PyFelt]) -> ModuloCircuit:
circuit = ECIPCircuits(
self.name, self.curve_id, compilation_mode=self.compilation_mode
self.name,
self.curve_id,
compilation_mode=self.compilation_mode,
generic_circuit=self.generic_circuit,
)

xA0, yA0 = circuit.write_struct(
Expand All @@ -454,14 +454,14 @@ def split_list(input_list, lengths):
start_idx += length
return result

n_points = self._n_points_from_n_coeffs(len(all_coeffs))
n_points = n_points_from_n_coeffs(len(all_coeffs), self.batched)
_log_div_a_num, _log_div_a_den, _log_div_b_num, _log_div_b_den = split_list(
all_coeffs, self._n_coeffs_from_n_points(n_points)
all_coeffs, n_coeffs_from_n_points(n_points, self.batched)
)
log_div_a_num, log_div_a_den, log_div_b_num, log_div_b_den = (
circuit.write_struct(
structs.FunctionFeltCircuit(
name="SumDlogDiv",
name="SumDlogDiv" + ("Batched" if self.batched else ""),
elmts=[
structs.u384Span("log_div_a_num", _log_div_a_num),
structs.u384Span("log_div_a_den", _log_div_a_den),
Expand Down Expand Up @@ -494,31 +494,22 @@ def __init__(
curve_id: int,
n_points: int = 1,
auto_run: bool = True,
batched: bool = False,
compilation_mode: int = 0,
) -> None:
self.n_points = n_points
self.batched = batched
super().__init__(
name=f"init_fn_challenge_dupl_{n_points}P",
name=f"init_fn_challenge_dupl_{n_points}P" + ("_rlc" if batched else ""),
curve_id=curve_id,
auto_run=auto_run,
compilation_mode=compilation_mode,
)

@staticmethod
def _n_coeffs_from_n_points(n_points: int) -> tuple[int, int, int, int]:
return (1 + n_points, 1 + n_points + 1, 1 + n_points + 1, 1 + n_points + 4)

@staticmethod
def _n_points_from_n_coeffs(n_coeffs: int) -> int:
# n_coeffs = 10 + 4n_points => 4n_points = n_coeffs - 10
assert n_coeffs >= 10
assert (n_coeffs - 10) % 4 == 0
return (n_coeffs - 10) // 4

def build_input(self) -> list[PyFelt]:
input = []
input.extend([self.field.random(), self.field.random()]) # xA0, xA2
n_coeffs = self._n_coeffs_from_n_points(self.n_points)
n_coeffs = n_coeffs_from_n_points(self.n_points, self.batched)
for _ in range(sum(n_coeffs)):
input.append(self.field(randint(0, CURVES[self.curve_id].p - 1)))
return input
Expand All @@ -539,9 +530,9 @@ def split_list(input_list, lengths):
start_idx += length
return result

n_points = self._n_points_from_n_coeffs(len(all_coeffs))
n_points = n_points_from_n_coeffs(len(all_coeffs), self.batched)
_log_div_a_num, _log_div_a_den, _log_div_b_num, _log_div_b_den = split_list(
all_coeffs, self._n_coeffs_from_n_points(n_points)
all_coeffs, n_coeffs_from_n_points(n_points, self.batched)
)

log_div_a_num, log_div_a_den, log_div_b_num, log_div_b_den = (
Expand Down
10 changes: 8 additions & 2 deletions hydra/garaga/precompiled_circuits/ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,17 @@ def _derive_point_from_x(


class ECIPCircuits(ModuloCircuit):
def __init__(self, name: str, curve_id: int, compilation_mode: int = 0):
def __init__(
self,
name: str,
curve_id: int,
compilation_mode: int = 0,
generic_circuit: bool = True,
):
super().__init__(
name=name,
curve_id=curve_id,
generic_circuit=True,
generic_circuit=generic_circuit,
compilation_mode=compilation_mode,
)
self.curve = CURVES[curve_id]
Expand Down
5 changes: 2 additions & 3 deletions hydra/garaga/starknet/groth16_contract_generator/calldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def groth16_calldata_from_vk_and_proof(
vk: Groth16VerifyingKey, proof: Groth16Proof, use_rust: bool = True
vk: Groth16VerifyingKey, proof: Groth16Proof, use_rust: bool = False
) -> list[int]:
if use_rust:
return _groth16_calldata_from_vk_and_proof_rust(vk, proof)
Expand Down Expand Up @@ -45,13 +45,13 @@ def groth16_calldata_from_vk_and_proof(
curve_id=vk.curve_id,
points=[vk.ic[3], vk.ic[4]],
scalars=[proof.public_inputs[2], proof.public_inputs[3]],
risc0_mode=True,
)
calldata.extend(
msm.serialize_to_calldata(
include_digits_decomposition=True,
include_points_and_scalars=False,
serialize_as_pure_felt252_array=True,
risc0_mode=True,
)
)
else:
Expand All @@ -66,7 +66,6 @@ def groth16_calldata_from_vk_and_proof(
include_digits_decomposition=True,
include_points_and_scalars=False,
serialize_as_pure_felt252_array=True,
risc0_mode=False,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from garaga.starknet.cli.utils import create_directory, get_package_version
from garaga.starknet.groth16_contract_generator.parsing_utils import Groth16VerifyingKey

ECIP_OPS_CLASS_HASH = 0x70C1D1C709C75E3CF51D79D19CF7C84A0D4521F3A2B8BF7BFF5CB45EE0DD289
ECIP_OPS_CLASS_HASH = 0x223A0051C2E31EDE1FD33DB4F01BC979901FD80F3429017710176CCE6AADA3B


def precompute_lines_from_vk(vk: Groth16VerifyingKey) -> StructArray:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ def build_hash_to_curve_hint(message: bytes) -> HashToCurveHint:
# print(f"cofactor: {cofactor}, hex :{hex(cofactor)}")

msm_builder = MSMCalldataBuilder(
curve_id=CurveID.BLS12_381, points=[sum_pt], scalars=[cofactor]
curve_id=CurveID.BLS12_381, points=[sum_pt], scalars=[cofactor], risc0_mode=True
)
msm_hint, derive_point_from_x_hint = msm_builder.build_msm_hints(risc0_mode=True)
msm_hint, derive_point_from_x_hint = msm_builder.build_msm_hints()

return HashToCurveHint(
f0_hint=f0_hint,
Expand Down
Loading

0 comments on commit b44d00d

Please sign in to comment.