Skip to content

Commit

Permalink
feat: revamp zero prove functions
Browse files Browse the repository at this point in the history
  • Loading branch information
atanmarko committed Nov 7, 2024
1 parent 70f6116 commit 2a796b3
Showing 1 changed file with 109 additions and 22 deletions.
131 changes: 109 additions & 22 deletions zero/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use anyhow::{Context, Result};
use evm_arithmetization::Field;
use evm_arithmetization::SegmentDataIterator;
use futures::{
future, future::BoxFuture, stream::FuturesUnordered, FutureExt, TryFutureExt, TryStreamExt,
future, future::BoxFuture, stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt,
TryStreamExt,
};
use hashbrown::HashMap;
use num_traits::ToPrimitive as _;
Expand All @@ -23,7 +24,7 @@ use plonky2::plonk::circuit_data::CircuitConfig;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Receiver;
use tokio::sync::{oneshot, Semaphore};
use tokio::sync::{mpsc, oneshot, Semaphore};
use trace_decoder::observer::DummyObserver;
use trace_decoder::{BlockTrace, OtherBlockData, WireDisposition};
use tracing::{error, info};
Expand Down Expand Up @@ -116,6 +117,8 @@ impl BlockProverInput {
WIRE_DISPOSITION,
)?;

let batch_count = block_generation_inputs.len();

// Create segment proof.
let seg_prove_ops = ops::SegmentProof {
save_inputs_on_error,
Expand All @@ -131,29 +134,113 @@ impl BlockProverInput {
save_inputs_on_error,
};

// Segment the batches, prove segments and aggregate them to resulting batch
// proofs.
let batch_proof_futs: FuturesUnordered<_> = block_generation_inputs
.iter()
.enumerate()
.map(|(idx, txn_batch)| {
let segment_data_iterator =
SegmentDataIterator::<Field>::new(txn_batch, Some(max_cpu_len_log));

Directive::map(IndexedStream::from(segment_data_iterator), &seg_prove_ops)
.fold(&seg_agg_ops)
.run(&proof_runtime.heavy_proof)
.map(move |e| {
e.map(|p| (idx, crate::proof_types::BatchAggregatableProof::from(p)))
})
// Generate channels to communicate segments of each batch to proving task
let (segment_senders, segment_receivers): (Vec<_>, Vec<_>) = (0..block_generation_inputs
.len())
.into_iter()
.map(|idx| {
let (segment_tx, segment_rx) = mpsc::channel::<evm_arithmetization::AllData>(1);
(segment_tx, segment_rx)
})
.collect();
.unzip();

let (batch_proof_tx, mut batch_proof_rx) =
mpsc::channel::<(usize, crate::proof_types::BatchAggregatableProof)>(32);

// Span a task for each batch to generate segments for that batch
// and send them to the proving task.
let _segment_generation_task = tokio::spawn(async move {
let mut batch_segment_futures: FuturesUnordered<_> = FuturesUnordered::new();

for (_idx, (txn_batch, segment_tx)) in block_generation_inputs
.into_iter()
.zip(segment_senders)
.enumerate()
{
batch_segment_futures.push(async move {
let mut segment_data_iterator =
SegmentDataIterator::<Field>::new(&txn_batch, Some(max_cpu_len_log));
while let Some(segment_data) = segment_data_iterator.next() {
segment_tx
.send(segment_data)
.await
.context("Failed to send segment data")?;
}
Ok::<(), anyhow::Error>(())
});
}
while let Some(it) = batch_segment_futures.next().await {
it?;
}
Ok::<(), anyhow::Error>(())
});

let _proving_task = tokio::spawn(async move {
let mut batch_proving_futures: FuturesUnordered<_> = FuturesUnordered::new();
// Span a proving task for each batch to generate segment proofs
// After two segment proofs are generated, aggregate them further
for (batch_idx, mut segment_rx) in segment_receivers.into_iter().enumerate() {
let batch_proof_tx = batch_proof_tx.clone();
// Tasks to dispatch and aggregate one batch
batch_proving_futures.push(async move {
let mut pair_segment_data = Vec::new();
let mut seg_aggregatable_proofs = Vec::new();
// Wait for segments and dispatch them to the segment proof worker task
// There will always be pair number of segments, so we dispatch two segments
// and aggregate them as one chained directive to save
// a bit on local and transported data size.
while let Some(segment_data) = segment_rx.recv().await {
if pair_segment_data.len() < 2 {
pair_segment_data.push(segment_data);
continue;
} else {
// Prove the segment
let seg_aggregatable_proof = Directive::map(
IndexedStream::from(&pair_segment_data),
&seg_prove_ops,
)
.fold(&seg_agg_ops)
.run(&proof_runtime.heavy_proof)
.await?;
seg_aggregatable_proofs.push(seg_aggregatable_proof);
}
}
// We have received and proved all the segments,
// now we need to aggregate to the batch proof
// Fold the batch aggregated proof stream into a single proof.
let batch_proof =
Directive::fold(IndexedStream::from(seg_aggregatable_proofs), &seg_agg_ops)
.run(&proof_runtime.light_proof)
.map(move |e| {
e.map(|p| {
(
batch_idx,
crate::proof_types::BatchAggregatableProof::from(p),
)
})
})
.await?;

batch_proof_tx.send(batch_proof).await
});
}
// Wait for all the batch proving tasks to finish
while let Some(it) = batch_proving_futures.next().await {
it?;
}
Ok::<(), anyhow::Error>(())
});

// Collect all the batch proofs for proving tasks
let mut batch_proofs: Vec<crate::proof_types::BatchAggregatableProof> = Vec::new();
while let Some(batch_proof) = batch_proof_rx.recv().await {
batch_proofs.push(batch_proof);
}

// Fold the batch aggregated proof stream into a single proof.
let final_batch_proof =
Directive::fold(IndexedStream::new(batch_proof_futs), &batch_agg_ops)
.run(&proof_runtime.light_proof)
.await?;
let final_batch_proof = Directive::fold(IndexedStream::from(batch_proofs), &batch_agg_ops)
.run(&proof_runtime.light_proof)
.await?;

if let crate::proof_types::BatchAggregatableProof::BatchAgg(proof) = final_batch_proof {
let block_number = block_number
Expand Down

0 comments on commit 2a796b3

Please sign in to comment.