Skip to content

Commit

Permalink
feat: Implement compute_shuffled_index in plonky2x and add abstracted…
Browse files Browse the repository at this point in the history
… functions for plonky2x in the utils file (#248)

* Implement compute_shuffled_index with plonky2x.
---------

Co-authored-by: Dimo99 <[email protected]>
Co-authored-by: Aneta Tsvetkova <[email protected]>
  • Loading branch information
3 people authored Nov 7, 2023
1 parent a80d15f commit d8627f2
Show file tree
Hide file tree
Showing 17 changed files with 383 additions and 23 deletions.
4 changes: 4 additions & 0 deletions casper-finality-proofs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ edition = "2021"
name = "weigh_justification_and_finalization"
path = "bin/weigh_justification_and_finalization.rs"

[[bin]]
name = "compute_shuffled_index"
path = "bin/compute_shuffled_index.rs"

[[bin]]
name = "test_engine"
path = "src/test_engine/bin/main.rs"
Expand Down
43 changes: 43 additions & 0 deletions casper-finality-proofs/bin/compute_shuffled_index.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use casper_finality_proofs::compute_shuffled_index::circuit::define;
use plonky2x::prelude::{
bytes, ArrayVariable, ByteVariable, CircuitBuilder, DefaultParameters, U64Variable,
};
use plonky2x::utils;

fn main() {
utils::setup_logger();

let seed_bytes: Vec<u8> =
bytes!("0x4ac96f664a6cafd300b161720809b9e17905d4d8fed7a97ff89cf0080a953fe7");

let seed_bytes_fixed_size: [u8; 32] = seed_bytes.try_into().unwrap();

const SHUFFLE_ROUND_COUNT: u8 = 90;
let mut builder = CircuitBuilder::<DefaultParameters, 2>::new();
define(&mut builder, SHUFFLE_ROUND_COUNT);

let circuit = builder.mock_build();

const START_IDX: u64 = 0;
const COUNT: u64 = 100;
let mapping = [
53, 21, 19, 29, 76, 32, 67, 63, 3, 38, 89, 37, 30, 78, 0, 40, 96, 44, 22, 42, 23, 62, 92,
87, 11, 43, 54, 75, 71, 82, 68, 36, 59, 90, 66, 45, 58, 70, 4, 72, 33, 24, 6, 39, 52, 51,
99, 8, 27, 88, 20, 31, 86, 77, 94, 95, 85, 41, 93, 15, 13, 5, 74, 81, 18, 17, 47, 2, 16, 7,
84, 9, 79, 65, 61, 49, 60, 50, 64, 34, 55, 56, 91, 98, 28, 46, 14, 73, 12, 25, 26, 57, 83,
80, 35, 97, 69, 10, 1, 48,
];
for i in START_IDX..COUNT {
let mut input = circuit.input();

input.write::<U64Variable>(i);
input.write::<U64Variable>(COUNT);
input.write::<ArrayVariable<ByteVariable, 32>>(seed_bytes_fixed_size.to_vec());

let (_witness, mut _output) = circuit.mock_prove(&input);
let shuffled_index_res = _output.read::<U64Variable>();

println!("{} {}", mapping[i as usize], shuffled_index_res);
assert!(mapping[i as usize] == shuffled_index_res);
}
}
31 changes: 31 additions & 0 deletions casper-finality-proofs/src/compute_shuffled_index/circuit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use crate::utils::plonky2x_extensions::{assert_is_true, max};
use plonky2x::prelude::{Bytes32Variable, CircuitBuilder, PlonkParameters, U64Variable};

use super::helpers::{compute_bit, compute_byte, compute_flip, compute_pivot, compute_source};

pub fn define<L: PlonkParameters<D>, const D: usize>(
builder: &mut CircuitBuilder<L, D>,
shuffle_round_count: u8,
) {
let mut index = builder.read::<U64Variable>();
let index_count = builder.read::<U64Variable>();
let seed = builder.read::<Bytes32Variable>();

let index_lt_index_count = builder.lt(index, index_count);
assert_is_true(builder, index_lt_index_count);

for current_round in 0..shuffle_round_count {
let pivot = compute_pivot(builder, seed, index_count, current_round);
let flip = compute_flip(builder, pivot, index_count, index);

let position = max(builder, index, flip);
let source = compute_source(builder, position, seed, current_round);

let byte = compute_byte(builder, source, position);
let bit = compute_bit(builder, byte, position);

index = builder.select(bit, flip, index);
}

builder.write::<U64Variable>(index);
}
113 changes: 113 additions & 0 deletions casper-finality-proofs/src/compute_shuffled_index/helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use itertools::Itertools;
use plonky2::field::types::Field;
use plonky2x::{
frontend::vars::EvmVariable,
prelude::{
BoolVariable, ByteVariable, Bytes32Variable, CircuitBuilder, CircuitVariable,
PlonkParameters, U64Variable, Variable,
},
};

use crate::utils::plonky2x_extensions::{bits_to_variable, exp_from_bits};

/// Returns the first 8 bytes of the hashed concatenation of seed with current_round
pub fn compute_pivot<L: PlonkParameters<D>, const D: usize>(
builder: &mut CircuitBuilder<L, D>,
seed: Bytes32Variable,
index_count: U64Variable,
current_round: u8,
) -> U64Variable {
let current_round_byte: ByteVariable = ByteVariable::constant(builder, current_round);
let concatenation = [seed.as_bytes().as_slice(), &[current_round_byte]]
.concat()
.to_vec();

let hash = builder.curta_sha256(&concatenation);

let hash = U64Variable::decode(
builder,
&hash.as_bytes()[0..8]
.into_iter()
.rev()
.cloned()
.collect_vec(),
);

builder.rem(hash, index_count)
}

/// Returns the computation of (pivot + index_count - index) % index_count
pub fn compute_flip<L: PlonkParameters<D>, const D: usize>(
builder: &mut CircuitBuilder<L, D>,
pivot: U64Variable,
index_count: U64Variable,
index: U64Variable,
) -> U64Variable {
let sum_pivot_index_count = builder.add(pivot, index_count);
let sum_pivot_index_count_sub_index = builder.sub(sum_pivot_index_count, index);

builder.rem(sum_pivot_index_count_sub_index, index_count)
}

/// Returns the hashed concatenation of seed, current_round and position divided by 256
pub fn compute_source<L: PlonkParameters<D>, const D: usize>(
builder: &mut CircuitBuilder<L, D>,
position: U64Variable,
seed: Bytes32Variable,
current_round: u8,
) -> Bytes32Variable {
let current_round_byte = ByteVariable::constant(builder, current_round as u8);
let const_256 = builder.constant::<U64Variable>(256);
let position_div_256 = builder.div(position, const_256);
let position_div_256_bytes = builder
.to_le_bits(position_div_256)
.chunks(8)
.take(4)
.map(|byte| ByteVariable(byte.iter().rev().cloned().collect_vec().try_into().unwrap()))
.collect_vec();

builder.curta_sha256(
&[
seed.as_bytes().as_slice(),
&[current_round_byte],
position_div_256_bytes.as_slice(),
]
.concat(),
)
}

/// Returns the byte in source at index (position % 256) / 8
pub fn compute_byte<L: PlonkParameters<D>, const D: usize>(
builder: &mut CircuitBuilder<L, D>,
source_array: Bytes32Variable,
position: U64Variable,
) -> ByteVariable {
let const_8 = builder.constant::<U64Variable>(8);
let const_256 = builder.constant::<U64Variable>(256);
let position_mod_256 = builder.rem(position, const_256);
let position_mod_256_div_8 = builder.div(position_mod_256, const_8);
let position_mod_256_div_8_bits = builder.to_le_bits(position_mod_256_div_8);
let position_mod_256_div_8_variable = bits_to_variable(builder, &position_mod_256_div_8_bits);

builder.select_array(&source_array.0 .0, position_mod_256_div_8_variable)
}

/// Returns the remainder of byte / 2^(position % 8) and 2 as BoolVariable
pub fn compute_bit<L: PlonkParameters<D>, const D: usize>(
builder: &mut CircuitBuilder<L, D>,
byte: ByteVariable,
position: U64Variable,
) -> BoolVariable {
let const_0: Variable = builder.constant(L::Field::from_canonical_usize(0));
let const_2: Variable = builder.constant(L::Field::from_canonical_usize(2));
let byte_to_variable = byte.to_variable(builder);
let byte_u64 = U64Variable::from_variables(builder, &[byte_to_variable, const_0]);

let position_first_3_bits = &builder.to_le_bits(position)[..3];
let const_2_pow_position_first_3_bits = exp_from_bits(builder, const_2, &position_first_3_bits);
let const_2_pow_position_first_3_bits_u64 =
U64Variable::from_variables(builder, &[const_2_pow_position_first_3_bits, const_0]);
let bit = builder.div(byte_u64, const_2_pow_position_first_3_bits_u64);

builder.to_le_bits(bit)[0]
}
3 changes: 3 additions & 0 deletions casper-finality-proofs/src/compute_shuffled_index/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod circuit;

mod helpers;
1 change: 1 addition & 0 deletions casper-finality-proofs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod compute_shuffled_index;
pub mod constants;
pub mod test_engine;
mod types;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use core::fmt::Debug;
use ethers::types::H256;
use serde_derive::{Deserialize, Serialize};

#[derive(Debug, Default, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct TestData {
pub count: u64,
pub seed: H256,
pub mapping: Vec<u64>,
}
1 change: 1 addition & 0 deletions casper-finality-proofs/src/test_engine/types/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod compute_shuffled_index_data;
30 changes: 30 additions & 0 deletions casper-finality-proofs/src/test_engine/utils/setup.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
use super::test_engine::TestCase;
use crate::test_engine::wrappers::compute_shuffled_index::wrapper_mainnet::{
wrapper as wrapper_mainnet, MAINNET_CIRCUIT as circuit_mainnet,
};
use crate::test_engine::wrappers::compute_shuffled_index::wrapper_minimal::{
wrapper as wrapper_minimal, MINIMAL_CIRCUIT as circuit_minimal,
};
use crate::test_engine::wrappers::wrapper_weigh_justification_and_finalization::{
wrapper as wrapper_weigh_justification_and_finalization,
CIRCUIT as circuit_weigh_justification_and_finalization,
Expand All @@ -8,6 +14,8 @@ use strum::{Display, EnumString};

#[derive(Debug, Eq, Hash, PartialEq, Copy, Clone, EnumString, Display)]
pub enum TestWrappers {
WrapperComputeShuffledIndexConsensusMainnet,
WrapperComputeShuffledIndexConsensusMinimal,
WrapperWeighJustificationAndFinalizationConsensusMainnet,
}

Expand All @@ -26,6 +34,18 @@ pub fn map_test_to_wrapper(
wrapper_weigh_justification_and_finalization(path, should_assert)
}),
),
TestWrappers::WrapperComputeShuffledIndexConsensusMainnet => (
Box::new(|| {
Lazy::force(&circuit_mainnet);
}),
Box::new(|path, should_assert| wrapper_mainnet(&path, should_assert)),
),
TestWrappers::WrapperComputeShuffledIndexConsensusMinimal => (
Box::new(|| {
Lazy::force(&circuit_minimal);
}),
Box::new(|path, should_assert| wrapper_minimal(&path, should_assert)),
),
}
}

Expand All @@ -36,6 +56,16 @@ pub fn init_tests() -> Vec<TestCase> {
"../vendor/consensus-spec-tests/tests/mainnet/capella/epoch_processing/justification_and_finalization/pyspec_tests/".to_string(),
true,
));
tests.push(TestCase::new(
TestWrappers::WrapperComputeShuffledIndexConsensusMainnet,
"../vendor/consensus-spec-tests/tests/mainnet/phase0/shuffling/core/shuffle".to_string(),
true,
));
tests.push(TestCase::new(
TestWrappers::WrapperComputeShuffledIndexConsensusMinimal,
"../vendor/consensus-spec-tests/tests/minimal/phase0/shuffling/core/shuffle".to_string(),
true,
));

tests
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod wrapper_mainnet;
pub mod wrapper_minimal;
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use crate::assert_equal;
use crate::compute_shuffled_index::circuit::define;
use crate::test_engine::types::compute_shuffled_index_data::TestData;
use crate::test_engine::utils::parsers::parse_file::read_fixture;
use once_cell::sync::Lazy;
use plonky2x::backend::circuit::MockCircuitBuild;
use plonky2x::prelude::{Bytes32Variable, U64Variable};
use plonky2x::prelude::{CircuitBuilder, DefaultParameters};

// Singleton-like pattern
pub static MAINNET_CIRCUIT: Lazy<MockCircuitBuild<DefaultParameters, 2>> = Lazy::new(|| {
let mut builder = CircuitBuilder::<DefaultParameters, 2>::new();
define(&mut builder, 90);
builder.mock_build()
});

pub fn wrapper(path: &str, should_assert: bool) -> Result<String, anyhow::Error> {
let json_data: TestData = read_fixture::<TestData>(path);

let mut result_indices: Vec<u64> = Vec::new();

for i in 0..json_data.count {
let mut input = MAINNET_CIRCUIT.input();

input.write::<U64Variable>(i);
input.write::<U64Variable>(json_data.count);
input.write::<Bytes32Variable>(json_data.seed);

let (_witness, mut _output) = MAINNET_CIRCUIT.mock_prove(&input);
let shuffled_index_res = _output.read::<U64Variable>();
if should_assert {
assert_equal!(json_data.mapping[i as usize], shuffled_index_res);
}

result_indices.push(shuffled_index_res);
}

Ok(format!("{:?}", result_indices))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use crate::assert_equal;
use crate::compute_shuffled_index::circuit::define;
use crate::test_engine::types::compute_shuffled_index_data::TestData;
use crate::test_engine::utils::parsers::parse_file::read_fixture;
use once_cell::sync::Lazy;
use plonky2x::backend::circuit::MockCircuitBuild;
use plonky2x::prelude::{Bytes32Variable, U64Variable};
use plonky2x::prelude::{CircuitBuilder, DefaultParameters};

// Singleton-like pattern
pub static MINIMAL_CIRCUIT: Lazy<MockCircuitBuild<DefaultParameters, 2>> = Lazy::new(|| {
let mut builder = CircuitBuilder::<DefaultParameters, 2>::new();
define(&mut builder, 10);
builder.mock_build()
});

pub fn wrapper(path: &str, should_assert: bool) -> Result<String, anyhow::Error> {
let json_data: TestData = read_fixture::<TestData>(path);

let mut result_indices: Vec<u64> = Vec::new();

for i in 0..json_data.count {
let mut input = MINIMAL_CIRCUIT.input();

input.write::<U64Variable>(i);
input.write::<U64Variable>(json_data.count);
input.write::<Bytes32Variable>(json_data.seed);

let (_witness, mut _output) = MINIMAL_CIRCUIT.mock_prove(&input);
let shuffled_index_res = _output.read::<U64Variable>();
if should_assert {
assert_equal!(json_data.mapping[i as usize], shuffled_index_res);
}

result_indices.push(shuffled_index_res);
}

Ok(format!("{:?}", result_indices))
}
1 change: 1 addition & 0 deletions casper-finality-proofs/src/test_engine/wrappers/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod wrapper_weigh_justification_and_finalization;
pub mod compute_shuffled_index;
19 changes: 0 additions & 19 deletions casper-finality-proofs/src/utils/bits.rs

This file was deleted.

3 changes: 1 addition & 2 deletions casper-finality-proofs/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
pub mod bits;
pub mod plonky2x_extensions;
pub mod plonky2x_extensions;
Loading

0 comments on commit d8627f2

Please sign in to comment.