Skip to content

Commit

Permalink
perf: BN254 Miller Loop "10" "01" compression (keep-starknet-strange#261
Browse files Browse the repository at this point in the history
)
  • Loading branch information
feltroidprime authored Dec 3, 2024
1 parent 5f3b232 commit 1837724
Show file tree
Hide file tree
Showing 18 changed files with 7,444 additions and 11,926 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/maturin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ jobs:
strategy:
matrix:
platform:
- runner: macos-12
- runner: macos-14
target: x86_64
- runner: macos-14
target: aarch64
Expand Down
32 changes: 31 additions & 1 deletion hydra/garaga/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,5 +1227,35 @@ def replace_consecutive_zeros(lst):
return result


def recode_naf_bits(lst):
result = []
i = 0
while i < len(lst):
if i < len(lst) - 1 and lst[i] == 0 and (lst[i + 1] == 1 or lst[i + 1] == -1):
# "01" or "0-1"
if lst[i + 1] == 1:
result.append(3) # Replace "01" with 3
else:
result.append(4) # Replace "0-1" with 4
i += 2
elif i < len(lst) - 1 and (lst[i] == 1 or lst[i] == -1) and lst[i + 1] == 0:
# "10" or "-10"
if lst[i] == 1:
result.append(1) # Replace 10 with 6
else:
result.append(2) # Replace -10 with 7
i += 2
elif i < len(lst) - 1 and lst[i] == 0 and lst[i + 1] == 0:
result.append(0) # Replace consecutive zeros with 0
i += 2
else:
raise ValueError(f"Unexpected bit sequence at index {i}")
return result


if __name__ == "__main__":
pass
r = recode_naf_bits(jy00(6 * 0x44E992B44A6909F1 + 2)[2:])
print(r, len(r))

# bls = [int(x) for x in bin(0xD201000000010000)[2:]][2:]
# recode_naf_bits(bls)
26 changes: 24 additions & 2 deletions hydra/garaga/precompiled_circuits/all_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
FixedG2MPCheckBit0,
FixedG2MPCheckBit00,
FixedG2MPCheckBit1,
FixedG2MPCheckBit01,
FixedG2MPCheckBit10,
FixedG2MPCheckFinalizeBN,
FixedG2MPCheckInitBit,
FP12MulAssertOne,
Expand Down Expand Up @@ -90,6 +92,8 @@ class CircuitID(Enum):
MP_CHECK_BIT0_LOOP = int.from_bytes(b"mp_check_bit0_loop", "big")
MP_CHECK_BIT00_LOOP = int.from_bytes(b"mp_check_bit00_loop", "big")
MP_CHECK_BIT1_LOOP = int.from_bytes(b"mp_check_bit1_loop", "big")
MP_CHECK_BIT01_LOOP = int.from_bytes(b"mp_check_bit01_loop", "big")
MP_CHECK_BIT10_LOOP = int.from_bytes(b"mp_check_bit10_loop", "big")
MP_CHECK_PREPARE_PAIRS = int.from_bytes(b"mp_check_prepare_pairs", "big")
MP_CHECK_PREPARE_LAMBDA_ROOT = int.from_bytes(
b"mp_check_prepare_lambda_root", "big"
Expand Down Expand Up @@ -206,7 +210,7 @@ class CircuitID(Enum):
{"n_pairs": 3, "n_fixed_g2": 2}, # Groth16
],
"filename": "multi_pairing_check",
"curve_ids": [CurveID.BN254, CurveID.BLS12_381],
"curve_ids": [CurveID.BLS12_381],
},
CircuitID.MP_CHECK_BIT00_LOOP: {
"class": FixedG2MPCheckBit00,
Expand All @@ -224,7 +228,25 @@ class CircuitID(Enum):
{"n_pairs": 3, "n_fixed_g2": 2}, # Groth16
],
"filename": "multi_pairing_check",
"curve_ids": [CurveID.BN254, CurveID.BLS12_381],
"curve_ids": [CurveID.BLS12_381],
},
CircuitID.MP_CHECK_BIT01_LOOP: {
"class": FixedG2MPCheckBit01,
"params": [
{"n_pairs": 2, "n_fixed_g2": 2}, # BLS SIG / KZG Verif
{"n_pairs": 3, "n_fixed_g2": 2}, # Groth16
],
"filename": "multi_pairing_check",
"curve_ids": [CurveID.BN254],
},
CircuitID.MP_CHECK_BIT10_LOOP: {
"class": FixedG2MPCheckBit10,
"params": [
{"n_pairs": 2, "n_fixed_g2": 2}, # BLS SIG / KZG Verif
{"n_pairs": 3, "n_fixed_g2": 2}, # Groth16
],
"filename": "multi_pairing_check",
"curve_ids": [CurveID.BN254],
},
CircuitID.MP_CHECK_PREPARE_PAIRS: {
"class": MPCheckPreparePairs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,25 +120,31 @@ def input_map(

def _base_input_map(self, bit_type: str) -> dict:
"""
Base input map for the bit 0, 1, and 00 cases.
Base input map for the bit 0, 1, 00, 01, and 10 cases.
"""
input_map = {}

# Add pair inputs
for k in range(self.n_fixed_g2):
input_map[f"yInv_{k}"] = u384
input_map[f"xNegOverY_{k}"] = u384
input_map[f"G2_line_{k}"] = G2Line
if bit_type == "1":
input_map[f"Q_or_Q_neg_line{k}"] = G2Line
input_map[f"G2_line_dbl_{k}"] = G2Line
if bit_type in ("1"):
input_map[f"G2_line_add_{k}"] = G2Line
if bit_type == "10":
input_map[f"G2_line_add_1_{k}"] = G2Line
input_map[f"G2_line_dbl_0_{k}"] = G2Line
if bit_type == "01":
input_map[f"G2_line_dbl_1{k}"] = G2Line
input_map[f"G2_line_add_1{k}"] = G2Line
if bit_type == "00":
input_map[f"G2_line_2nd_0_{k}"] = G2Line

for k in range(self.n_fixed_g2, self.n_pairs):
input_map[f"yInv_{k}"] = u384
input_map[f"xNegOverY_{k}"] = u384
input_map[f"Q_{k}"] = G2PointCircuit
if bit_type == "1":
if bit_type in ("1", "01", "10"):
input_map[f"Q_or_Q_neg_{k}"] = G2PointCircuit

# Add common inputs
Expand All @@ -147,7 +153,7 @@ def _base_input_map(self, bit_type: str) -> dict:
input_map["f_i_plus_one_of_z"] = u384

# Add bit-specific inputs
if bit_type == "1":
if bit_type in ("1", "01", "10"):
input_map["c_or_cinv_of_z"] = u384

input_map["z"] = u384
Expand Down Expand Up @@ -235,12 +241,15 @@ def _execute_circuit_logic(self, circuit, vars) -> ModuloCircuit:
Implement the circuit logic using the processed input variables.
"""

def _execute_circuit_bit_logic_base(self, circuit, vars, bit_type):
def _execute_circuit_bit_logic_base(self, circuit: ModuloCircuit, vars, bit_type):
n_pairs = self.n_pairs
assert n_pairs >= 2, f"n_pairs must be >= 2, got {n_pairs}"

current_points, q_or_q_neg_points = parse_precomputed_g1_consts_and_g2_points(
circuit, vars, n_pairs, bit_1=(bit_type == "1")
circuit,
vars,
n_pairs,
bit_1=(bit_type == "1" or bit_type == "01" or bit_type == "10"),
)

circuit.create_lines_z_powers(vars["z"])
Expand All @@ -256,8 +265,12 @@ def _execute_circuit_bit_logic_base(self, circuit, vars, bit_type):
circuit, current_points, q_or_q_neg_points, sum_i_prod_k_P, bit_type
)

if bit_type == "1":
if bit_type in ("1", "01"):
sum_i_prod_k_P = circuit.mul(sum_i_prod_k_P, vars["c_or_cinv_of_z"])
elif bit_type == "10":
sum_i_prod_k_P = circuit.mul(
sum_i_prod_k_P, circuit.square(vars["c_or_cinv_of_z"])
)

f_i_plus_one_of_z = vars["f_i_plus_one_of_z"]
new_lhs = circuit.mul(
Expand Down Expand Up @@ -304,6 +317,55 @@ def _process_points(
)
new_new_points.append(T)
return new_new_points, sum_i_prod_k_P
elif bit_type == "01":
for k in range(self.n_pairs):
T, l1 = circuit.double_step(current_points[k], k)
sum_i_prod_k_P = self._multiply_line_evaluations(
circuit, sum_i_prod_k_P, [l1], k
)
new_points.append(T)

sum_i_prod_k_P = circuit.mul(
sum_i_prod_k_P,
sum_i_prod_k_P,
"Compute (f^2 * Π(i,k) (line_i,k(z))) ^ 2 = f^4 * (Π(i,k) (line_i,k(z)))^2",
)

new_new_points = []
for k in range(self.n_pairs):
T, l1, l2 = circuit.double_and_add_step(
new_points[k], q_or_q_neg_points[k], k
)
sum_i_prod_k_P = self._multiply_line_evaluations(
circuit, sum_i_prod_k_P, [l1, l2], k
)
new_new_points.append(T)

return new_new_points, sum_i_prod_k_P

elif bit_type == "10":
for k in range(self.n_pairs):
T, l1, l2 = circuit.double_and_add_step(
current_points[k], q_or_q_neg_points[k], k
)
sum_i_prod_k_P = self._multiply_line_evaluations(
circuit, sum_i_prod_k_P, [l1, l2], k
)
new_points.append(T)

sum_i_prod_k_P = circuit.mul(
sum_i_prod_k_P,
sum_i_prod_k_P,
"Compute (f^2 * Π(i,k) (line_i,k(z))) ^ 2 = f^4 * (Π(i,k) (line_i,k(z)))^2",
)
new_new_points = []
for k in range(self.n_pairs):
T, l1 = circuit.double_step(new_points[k], k)
sum_i_prod_k_P = self._multiply_line_evaluations(
circuit, sum_i_prod_k_P, [l1], k
)
new_new_points.append(T)
return new_new_points, sum_i_prod_k_P
elif bit_type == "0":
for k in range(self.n_pairs):
T, l1 = circuit.double_step(current_points[k], k)
Expand All @@ -320,9 +382,18 @@ def _process_points(
circuit, sum_i_prod_k_P, [l1, l2], k
)
new_points.append(T)

else:
raise ValueError(f"Invalid bit type: {bit_type}")
return new_points, sum_i_prod_k_P

def _multiply_line_evaluations(self, circuit, sum_i_prod_k_P, lines, k):
def _multiply_line_evaluations(
self,
circuit: multi_pairing_check.MultiPairingCheckCircuit,
sum_i_prod_k_P,
lines,
k,
):
for i, l in enumerate(lines):
sum_i_prod_k_P = circuit.mul(
sum_i_prod_k_P,
Expand Down Expand Up @@ -358,17 +429,27 @@ def _extend_output(self, circuit, new_points, lhs_i_plus_one, ci_plus_one):
circuit.extend_struct_output(u384(name="ci_plus_one", elmts=[ci_plus_one]))


class FixedG2MPCheckBit0(BaseFixedG2PointsMPCheck):
class FixedG2MPCheckBitBase(BaseFixedG2PointsMPCheck):
"""Base class for bit checking circuits with default parameters."""

BIT_TYPE = None # Override in subclasses
DEFAULT_PAIRS = 3
DEFAULT_FIXED_G2 = 2

def __init__(
self,
curve_id: int,
n_pairs: int,
n_fixed_g2: int,
n_pairs: int = None,
n_fixed_g2: int = None,
auto_run: bool = True,
compilation_mode: int = 1,
):
assert compilation_mode == 1, "Compilation mode 1 is required for this circuit"
n_pairs = n_pairs if n_pairs is not None else self.DEFAULT_PAIRS
n_fixed_g2 = n_fixed_g2 if n_fixed_g2 is not None else self.DEFAULT_FIXED_G2

super().__init__(
name=f"mp_check_bit0_{n_pairs}P_{n_fixed_g2}F",
name=f"mp_check_bit{self.BIT_TYPE}_{n_pairs}P_{n_fixed_g2}F",
curve_id=curve_id,
n_pairs=n_pairs,
n_fixed_g2=n_fixed_g2,
Expand All @@ -378,63 +459,30 @@ def __init__(

@property
def input_map(self):
return self._base_input_map("0")
return self._base_input_map(self.BIT_TYPE)

def _execute_circuit_logic(self, circuit, vars) -> ModuloCircuit:
return self._execute_circuit_bit_logic_base(circuit, vars, "0")
return self._execute_circuit_bit_logic_base(circuit, vars, self.BIT_TYPE)


class FixedG2MPCheckBit00(BaseFixedG2PointsMPCheck):
def __init__(
self,
curve_id: int,
auto_run: bool = True,
compilation_mode: int = 1,
n_pairs: int = 3,
n_fixed_g2: int = 2,
):
super().__init__(
name=f"mp_check_bit00_{n_pairs}P_{n_fixed_g2}F",
curve_id=curve_id,
n_pairs=n_pairs,
n_fixed_g2=n_fixed_g2,
auto_run=auto_run,
compilation_mode=compilation_mode,
)
class FixedG2MPCheckBit0(FixedG2MPCheckBitBase):
BIT_TYPE = "0"

@property
def input_map(self):
return self._base_input_map("00")

def _execute_circuit_logic(self, circuit, vars) -> ModuloCircuit:
return self._execute_circuit_bit_logic_base(circuit, vars, "00")
class FixedG2MPCheckBit00(FixedG2MPCheckBitBase):
BIT_TYPE = "00"


class FixedG2MPCheckBit1(BaseFixedG2PointsMPCheck):
def __init__(
self,
curve_id: int,
auto_run: bool = True,
n_pairs: int = 3,
n_fixed_g2: int = 2,
compilation_mode: int = 1,
):
assert compilation_mode == 1, "Compilation mode 1 is required for this circuit"
super().__init__(
name=f"mp_check_bit1_{n_pairs}P_{n_fixed_g2}F",
curve_id=curve_id,
n_pairs=n_pairs,
n_fixed_g2=n_fixed_g2,
auto_run=auto_run,
compilation_mode=compilation_mode,
)
class FixedG2MPCheckBit1(FixedG2MPCheckBitBase):
BIT_TYPE = "1"

@property
def input_map(self):
return self._base_input_map("1")

def _execute_circuit_logic(self, circuit, vars) -> ModuloCircuit:
return self._execute_circuit_bit_logic_base(circuit, vars, "1")
class FixedG2MPCheckBit01(FixedG2MPCheckBitBase):
BIT_TYPE = "01"


class FixedG2MPCheckBit10(FixedG2MPCheckBitBase):
BIT_TYPE = "10"


class FixedG2MPCheckInitBit(BaseFixedG2PointsMPCheck):
Expand Down
Loading

0 comments on commit 1837724

Please sign in to comment.