Skip to content

Commit

Permalink
Make MillerLoopResultScalingFactor struct generic on T. (keep-starkne…
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime authored Sep 11, 2024
1 parent a055ebb commit b973b2e
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 38 deletions.
29 changes: 26 additions & 3 deletions hydra/garaga/modulo_circuit_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,22 +1015,45 @@ def extract_from_circuit_output(
) -> str:
raise NotImplementedError("Never used in practice")

def serialize_input_signature(self) -> str:
bits = self.bits
if bits <= 288:
return f"{self.name}:MillerLoopResultScalingFactor<u288>"
else:
return f"{self.name}:MillerLoopResultScalingFactor<u384>"

def dump_to_circuit_input(self) -> str:
code = ""
bits = self.bits
if bits <= 288:
next_fn = "next_u288"
else:
next_fn = "next_2"
for mem_name in self.members_names:
code += f"circuit_inputs = circuit_inputs.next_2({self.name}.{mem_name});\n"
code += (
f"circuit_inputs = circuit_inputs.{next_fn}({self.name}.{mem_name});\n"
)
return code

def serialize(self, raw: bool = False) -> str:
assert len(self.elmts) == 6
raw_struct = f"{self.__class__.__name__}{{{','.join([f'{self.members_names[i]}: {int_to_u384(self.elmts[i].value)}' for i in range(len(self))])}}}"
bits = self.bits
if bits <= 288:
curve_id = 0
else:
curve_id = 1
raw_struct = f"{self.__class__.__name__}{{{','.join([f'{self.members_names[i]}: {int_to_u2XX(self.elmts[i].value, curve_id=curve_id)}' for i in range(len(self))])}}}"
if raw:
return raw_struct
else:
return f"let {self.name}:{self.__class__.__name__} = {raw_struct};\n"

def _serialize_to_calldata(self) -> list[int]:
return io.bigint_split_array(self.elmts, prepend_length=False)
bits = self.bits
if bits <= 288:
return io.bigint_split_array(self.elmts, n_limbs=3, prepend_length=False)
else:
return io.bigint_split_array(self.elmts, n_limbs=4, prepend_length=False)

def __len__(self) -> int:
if self.elmts is not None:
Expand Down
16 changes: 8 additions & 8 deletions src/src/circuits/multi_pairing_check.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2213,7 +2213,7 @@ fn run_BLS12_381_MP_CHECK_INIT_BIT_3P_2F_circuit(
return (Q0, new_lhs);
}
fn run_BLS12_381_MP_CHECK_PREPARE_LAMBDA_ROOT_circuit(
lambda_root_inverse: E12D<u384>, z: u384, scaling_factor: MillerLoopResultScalingFactor
lambda_root_inverse: E12D<u384>, z: u384, scaling_factor: MillerLoopResultScalingFactor<u384>
) -> (u384, u384, u384) {
// CONSTANT stack
let in0 = CE::<CI<0>> {}; // 0x0
Expand Down Expand Up @@ -5334,7 +5334,7 @@ fn run_BN254_MP_CHECK_INIT_BIT_3P_2F_circuit(
fn run_BN254_MP_CHECK_PREPARE_LAMBDA_ROOT_circuit(
lambda_root: E12D<u288>,
z: u384,
scaling_factor: MillerLoopResultScalingFactor,
scaling_factor: MillerLoopResultScalingFactor<u288>,
c_inv: E12D<u288>,
c_0: u384
) -> (u384, u384, u384, u384, u384, u384, u384) {
Expand Down Expand Up @@ -5678,12 +5678,12 @@ fn run_BN254_MP_CHECK_PREPARE_LAMBDA_ROOT_circuit(
circuit_inputs = circuit_inputs.next_u288(lambda_root.w10); // in55
circuit_inputs = circuit_inputs.next_u288(lambda_root.w11); // in56
circuit_inputs = circuit_inputs.next_2(z); // in57
circuit_inputs = circuit_inputs.next_2(scaling_factor.w0); // in58
circuit_inputs = circuit_inputs.next_2(scaling_factor.w2); // in59
circuit_inputs = circuit_inputs.next_2(scaling_factor.w4); // in60
circuit_inputs = circuit_inputs.next_2(scaling_factor.w6); // in61
circuit_inputs = circuit_inputs.next_2(scaling_factor.w8); // in62
circuit_inputs = circuit_inputs.next_2(scaling_factor.w10); // in63
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w0); // in58
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w2); // in59
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w4); // in60
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w6); // in61
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w8); // in62
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w10); // in63
circuit_inputs = circuit_inputs.next_u288(c_inv.w0); // in64
circuit_inputs = circuit_inputs.next_u288(c_inv.w1); // in65
circuit_inputs = circuit_inputs.next_u288(c_inv.w2); // in66
Expand Down
14 changes: 7 additions & 7 deletions src/src/definitions.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,13 @@ impl E12DSerde288 of Serde<E12D<u288>> {
}

#[derive(Copy, Drop, Debug, PartialEq, Serde)]
struct MillerLoopResultScalingFactor {
w0: u384,
w2: u384,
w4: u384,
w6: u384,
w8: u384,
w10: u384,
struct MillerLoopResultScalingFactor<T> {
w0: T,
w2: T,
w4: T,
w6: T,
w8: T,
w10: T,
}
#[derive(Copy, Drop, Debug, PartialEq, Serde)]
struct E12DMulQuotient {
Expand Down
4 changes: 2 additions & 2 deletions src/src/groth16.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ fn multi_pairing_check_bn254_3P_2F_with_extra_miller_loop_result(
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair2, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u288(mpcheck_hint.lambda_root, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u288(mpcheck_hint.lambda_root_inverse, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor(mpcheck_hint.w, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor_u288(mpcheck_hint.w, s0, s1, s2);
// Hash Ris to obtain base random coefficient c0
let (s0, s1, s2) = hashing::hash_E12D_u288_transcript(mpcheck_hint.Ris, s0, s1, s2);

Expand Down Expand Up @@ -514,7 +514,7 @@ fn multi_pairing_check_bls12_381_3P_2F_with_extra_miller_loop_result(
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair1, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair2, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u384(hint.lambda_root_inverse, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor(hint.w, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor_u384(hint.w, s0, s1, s2);
// Hash Ris to obtain base random coefficient c0
let (s0, s1, s2) = hashing::hash_E12D_u384_transcript(hint.Ris, s0, s1, s2);
let mut c_i: u384 = s1.into();
Expand Down
8 changes: 4 additions & 4 deletions src/src/pairing_check.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ use garaga::basic_field_ops::{compute_yInvXnegOverY_BN254, compute_yInvXnegOverY
struct MPCheckHintBN254 {
lambda_root: E12D<u288>,
lambda_root_inverse: E12D<u288>,
w: MillerLoopResultScalingFactor,
w: MillerLoopResultScalingFactor<u288>,
Ris: Span<E12D<u288>>,
big_Q: Array<u288>,
}

#[derive(Drop, Serde)]
struct MPCheckHintBLS12_381 {
lambda_root_inverse: E12D<u384>,
w: MillerLoopResultScalingFactor,
w: MillerLoopResultScalingFactor<u384>,
Ris: Span<E12D<u384>>,
big_Q: Array<u384>,
}
Expand All @@ -70,7 +70,7 @@ fn multi_pairing_check_bn254_2P_2F(
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair1, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u288(hint.lambda_root, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u288(hint.lambda_root_inverse, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor(hint.w, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor_u288(hint.w, s0, s1, s2);
// Hash Ris to obtain base random coefficient c0
let (s0, s1, s2) = hashing::hash_E12D_u288_transcript(hint.Ris, s0, s1, s2);
let mut c_i: u384 = s1.into();
Expand Down Expand Up @@ -233,7 +233,7 @@ fn multi_pairing_check_bls12_381_2P_2F(
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair0, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair1, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u384(hint.lambda_root_inverse, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor(hint.w, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor_u384(hint.w, s0, s1, s2);
// Hash Ris to obtain base random coefficient c0
let (s0, s1, s2) = hashing::hash_E12D_u384_transcript(hint.Ris, s0, s1, s2);

Expand Down
24 changes: 12 additions & 12 deletions src/src/tests/pairing_tests.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4097,12 +4097,12 @@ mod pairing_tests {
}
},
w: MillerLoopResultScalingFactor {
w0: u384 { limb0: 0x1, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w2: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w4: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w6: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w8: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w10: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 }
w0: u288 { limb0: 0x1, limb1: 0x0, limb2: 0x0 },
w2: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w4: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w6: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w8: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w10: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 }
},
Ris: array![
E12D {
Expand Down Expand Up @@ -11917,12 +11917,12 @@ mod pairing_tests {
}
},
w: MillerLoopResultScalingFactor {
w0: u384 { limb0: 0x1, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w2: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w4: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w6: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w8: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w10: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 }
w0: u288 { limb0: 0x1, limb1: 0x0, limb2: 0x0 },
w2: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w4: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w6: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w8: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w10: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 }
},
Ris: array![
E12D {
Expand Down
30 changes: 28 additions & 2 deletions src/src/utils/hashing.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ pub fn hash_E12D_u288(

// Apply sponge construction to a MillerLoopResultScalingFactor element from an initial state (s0,
// s1, s2)
pub fn hash_MillerLoopResultScalingFactor(
elmt: MillerLoopResultScalingFactor, mut s0: felt252, mut s1: felt252, mut s2: felt252
pub fn hash_MillerLoopResultScalingFactor_u384(
elmt: MillerLoopResultScalingFactor<u384>, mut s0: felt252, mut s1: felt252, mut s2: felt252
) -> (felt252, felt252, felt252) {
let base: felt252 = 79228162514264337593543950336; // 2**96

Expand All @@ -203,6 +203,32 @@ pub fn hash_MillerLoopResultScalingFactor(
return (_s0, _s1, _s2);
}

pub fn hash_MillerLoopResultScalingFactor_u288(
elmt: MillerLoopResultScalingFactor<u288>, mut s0: felt252, mut s1: felt252, mut s2: felt252
) -> (felt252, felt252, felt252) {
let base: felt252 = 79228162514264337593543950336; // 2**96

let in_1 = s0 + elmt.w0.limb0.into() + base * elmt.w0.limb1.into();
let in_2 = s1 + elmt.w0.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, s2);
let in_1 = _s0 + elmt.w2.limb0.into() + base * elmt.w2.limb1.into();
let in_2 = _s1 + elmt.w2.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, _s2);
let in_1 = _s0 + elmt.w4.limb0.into() + base * elmt.w4.limb1.into();
let in_2 = _s1 + elmt.w4.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, _s2);
let in_1 = _s0 + elmt.w6.limb0.into() + base * elmt.w6.limb1.into();
let in_2 = _s1 + elmt.w6.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, _s2);
let in_1 = _s0 + elmt.w8.limb0.into() + base * elmt.w8.limb1.into();
let in_2 = _s1 + elmt.w8.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, _s2);
let in_1 = _s0 + elmt.w10.limb0.into() + base * elmt.w10.limb1.into();
let in_2 = _s1 + elmt.w10.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, _s2);
return (_s0, _s1, _s2);
}

// Apply sponge construction to a sequence of E12D elements from an initial state (s0, s1, s2)
pub fn hash_E12D_u384_transcript(
transcript: Span<E12D<u384>>, mut s0: felt252, mut s1: felt252, mut s2: felt252
Expand Down

0 comments on commit b973b2e

Please sign in to comment.