From f98e992ae3d22474e1f908f37d5324200ef6402d Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Mon, 23 Oct 2023 13:20:06 +0000 Subject: [PATCH] Block machine queries --- compiler/tests/asm.rs | 8 ++ executor/src/witgen/generator.rs | 4 +- executor/src/witgen/machines/block_machine.rs | 6 +- executor/src/witgen/processor.rs | 79 +++++++++++++------ executor/src/witgen/query_processor.rs | 19 ++--- executor/src/witgen/sequence_iterator.rs | 36 +++++---- executor/src/witgen/vm_processor.rs | 38 ++++++--- test_data/asm/sqrt.asm | 54 +++++++++++++ 8 files changed, 175 insertions(+), 69 deletions(-) create mode 100644 test_data/asm/sqrt.asm diff --git a/compiler/tests/asm.rs b/compiler/tests/asm.rs index a97ad36076..8b0a6ea819 100644 --- a/compiler/tests/asm.rs +++ b/compiler/tests/asm.rs @@ -228,6 +228,14 @@ fn test_bit_access() { //gen_estark_proof(f, slice_to_vec(&i)); } +#[test] +fn test_sqrt() { + let f = "sqrt.asm"; + verify_asm::(f, Default::default()); + gen_halo2_proof(f, Default::default()); + gen_estark_proof(f, Default::default()); +} + #[test] fn functional_instructions() { let f = "functional_instructions.asm"; diff --git a/executor/src/witgen/generator.rs b/executor/src/witgen/generator.rs index 151f7cf85d..1499280509 100644 --- a/executor/src/witgen/generator.rs +++ b/executor/src/witgen/generator.rs @@ -8,7 +8,6 @@ use crate::witgen::rows::CellValue; use super::affine_expression::AffineExpression; use super::column_map::WitnessColumnMap; use super::global_constraints::GlobalConstraints; -use super::identity_processor::IdentityProcessor; use super::machines::Machine; use super::processor::Processor; @@ -104,7 +103,6 @@ impl<'a, T: FieldElement> Generator<'a, T> { // Note that using `Processor` instead of `VmProcessor` is more convenient here because // it does not assert that the row is "complete" afterwards (i.e., that all identities // are satisfied assuming 0 for unknown values). - let mut identity_processor = IdentityProcessor::new(self.fixed_data, mutable_state); let row_factory = RowFactory::new(self.fixed_data, self.global_range_constraints.clone()); let data = vec![ row_factory.fresh_row(self.fixed_data.degree - 1), @@ -113,7 +111,7 @@ impl<'a, T: FieldElement> Generator<'a, T> { let mut processor = Processor::new( self.fixed_data.degree - 1, data, - &mut identity_processor, + mutable_state, &self.identities, self.fixed_data, row_factory, diff --git a/executor/src/witgen/machines/block_machine.rs b/executor/src/witgen/machines/block_machine.rs index bed97c77bc..21e141e28d 100644 --- a/executor/src/witgen/machines/block_machine.rs +++ b/executor/src/witgen/machines/block_machine.rs @@ -259,8 +259,6 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> { log::trace!(" {}", l); } - let mut identity_processor = IdentityProcessor::new(self.fixed_data, mutable_state); - // First check if we already store the value. // This can happen in the loop detection case, where this function is just called // to validate the constraints. @@ -281,6 +279,7 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> { UnknownStrategy::Unknown, ); + let mut identity_processor = IdentityProcessor::new(self.fixed_data, mutable_state); let result = identity_processor.process_link(left, right, &row_pair)?; if result.is_complete() { @@ -357,11 +356,10 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> { let block = (0..(self.block_size + 2)) .map(|i| self.row_factory.fresh_row(i as DegreeType + row_offset)) .collect(); - let mut identity_processor = IdentityProcessor::new(self.fixed_data, mutable_state); let mut processor = Processor::new( row_offset, block, - &mut identity_processor, + mutable_state, &self.identities, self.fixed_data, self.row_factory.clone(), diff --git a/executor/src/witgen/processor.rs b/executor/src/witgen/processor.rs index a34ceea868..26c5b323a3 100644 --- a/executor/src/witgen/processor.rs +++ b/executor/src/witgen/processor.rs @@ -6,14 +6,15 @@ use ast::{ }; use number::FieldElement; -use crate::witgen::Constraint; +use crate::witgen::{query_processor::QueryProcessor, Constraint}; use super::{ affine_expression::AffineExpression, + column_map::WitnessColumnMap, identity_processor::IdentityProcessor, rows::{Row, RowFactory, RowPair, RowUpdater, UnknownStrategy}, - sequence_iterator::{IdentityInSequence, ProcessingSequenceIterator, SequenceStep}, - Constraints, EvalError, EvalValue, FixedData, QueryCallback, + sequence_iterator::{Action, ProcessingSequenceIterator, SequenceStep}, + Constraints, EvalError, EvalValue, FixedData, MutableState, QueryCallback, }; type Left<'a, T> = Vec>; @@ -50,14 +51,16 @@ pub struct Processor<'a, 'b, 'c, T: FieldElement, Q: QueryCallback, CalldataA data: Vec>, /// The list of identities identities: &'c [&'a Identity>], - /// The identity processor - identity_processor: &'c mut IdentityProcessor<'a, 'b, 'c, T, Q>, + /// The mutable state + mutable_state: &'c mut MutableState<'a, 'b, T, Q>, /// The fixed data (containing information about all columns) fixed_data: &'a FixedData<'a, T>, /// The row factory row_factory: RowFactory<'a, T>, /// The set of witness columns that are actually part of this machine. witness_cols: &'c HashSet, + /// Whether a given witness column is relevant for this machine (faster than doing a contains check on witness_cols) + is_relevant_witness: WitnessColumnMap, /// The outer query, if any. If there is none, processing an outer query will fail. outer_query: Option>, _marker: PhantomData, @@ -69,20 +72,27 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> pub fn new( row_offset: u64, data: Vec>, - identity_processor: &'c mut IdentityProcessor<'a, 'b, 'c, T, Q>, + mutable_state: &'c mut MutableState<'a, 'b, T, Q>, identities: &'c [&'a Identity>], fixed_data: &'a FixedData<'a, T>, row_factory: RowFactory<'a, T>, witness_cols: &'c HashSet, ) -> Self { + let is_relevant_witness = WitnessColumnMap::from( + fixed_data + .witness_cols + .keys() + .map(|poly_id| witness_cols.contains(&poly_id)), + ); Self { row_offset, data, - identity_processor, + mutable_state, identities, fixed_data, row_factory, witness_cols, + is_relevant_witness, outer_query: None, _marker: PhantomData, } @@ -97,11 +107,12 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> _marker: PhantomData, row_offset: self.row_offset, data: self.data, - identity_processor: self.identity_processor, + mutable_state: self.mutable_state, identities: self.identities, fixed_data: self.fixed_data, row_factory: self.row_factory, witness_cols: self.witness_cols, + is_relevant_witness: self.is_relevant_witness, } } @@ -124,6 +135,7 @@ impl<'a, 'b, T: FieldElement, Q: QueryCallback, CalldataAvailable> /// If any identity was unsatisfied, returns an error. #[allow(dead_code)] pub fn check_constraints(&mut self) -> Result<(), EvalError> { + let mut identity_processor = IdentityProcessor::new(self.fixed_data, self.mutable_state); for i in 0..(self.data.len() - 1) { let row_pair = RowPair::new( &self.data[i], @@ -133,8 +145,7 @@ impl<'a, 'b, T: FieldElement, Q: QueryCallback, CalldataAvailable> UnknownStrategy::Zero, ); for identity in self.identities { - self.identity_processor - .process_identity(identity, &row_pair)?; + identity_processor.process_identity(identity, &row_pair)?; } } Ok(()) @@ -148,27 +159,46 @@ impl<'a, 'b, T: FieldElement, Q: QueryCallback, CalldataAvailable> ) -> Result, EvalError> { let mut outer_assignments = vec![]; - while let Some(SequenceStep { - row_delta, - identity, - }) = sequence_iterator.next() - { + while let Some(SequenceStep { row_delta, action }) = sequence_iterator.next() { let row_index = (1 + row_delta) as usize; - let progress = match identity { - IdentityInSequence::Internal(identity_index) => { + let progress = match action { + Action::InternalIdentity(identity_index) => { self.process_identity(row_index, identity_index)? } - IdentityInSequence::OuterQuery => { + Action::OuterQuery => { let (progress, new_outer_assignments) = self.process_outer_query(row_index)?; outer_assignments.extend(new_outer_assignments); progress } + Action::ProverQueries => self.process_queries(row_index), }; sequence_iterator.report_progress(progress); } Ok(outer_assignments) } + fn process_queries(&mut self, row_index: usize) -> bool { + let mut progress = false; + for poly_id in self.fixed_data.witness_cols.keys() { + if self.is_relevant_witness[&poly_id] { + let mut query_processor = + QueryProcessor::new(self.fixed_data, self.mutable_state.query_callback); + let global_row_index = self.row_offset + row_index as u64; + let row_pair = RowPair::new( + &self.data[row_index], + &self.data[row_index + 1], + global_row_index, + self.fixed_data, + UnknownStrategy::Unknown, + ); + + let updates = query_processor.process_query(&row_pair, &poly_id); + progress |= self.apply_updates(row_index, &updates, || "query".to_string()); + } + } + progress + } + /// Given a row and identity index, computes any updates, applies them and returns /// whether any progress was made. fn process_identity( @@ -189,8 +219,8 @@ impl<'a, 'b, T: FieldElement, Q: QueryCallback, CalldataAvailable> ); // Compute updates - let updates = self - .identity_processor + let mut identity_processor = IdentityProcessor::new(self.fixed_data, self.mutable_state); + let updates = identity_processor .process_identity(identity, &row_pair) .map_err(|e| { log::warn!("Error in identity: {identity}"); @@ -229,8 +259,8 @@ impl<'a, 'b, T: FieldElement, Q: QueryCallback, CalldataAvailable> UnknownStrategy::Unknown, ); - let updates = self - .identity_processor + let mut identity_processor = IdentityProcessor::new(self.fixed_data, self.mutable_state); + let updates = identity_processor .process_link(left, right, &row_pair) .map_err(|e| { log::warn!("Error in outer query: {e}"); @@ -303,7 +333,7 @@ mod tests { witgen::{ column_map::FixedColumnMap, global_constraints::GlobalConstraints, - identity_processor::{IdentityProcessor, Machines}, + identity_processor::Machines, machines::FixedLookup, rows::RowFactory, sequence_iterator::{DefaultSequenceIterator, ProcessingSequenceIterator}, @@ -354,7 +384,6 @@ mod tests { machines: Machines::from(machines.iter_mut()), query_callback: &mut query_callback, }; - let mut identity_processor = IdentityProcessor::new(&fixed_data, &mut mutable_state); let row_offset = 0; let identities = analyzed.identities.iter().collect::>(); let witness_cols = fixed_data.witness_cols.keys().collect(); @@ -362,7 +391,7 @@ mod tests { let mut processor = Processor::new( row_offset, data, - &mut identity_processor, + &mut mutable_state, &identities, &fixed_data, row_factory, diff --git a/executor/src/witgen/query_processor.rs b/executor/src/witgen/query_processor.rs index ba3c10ab5e..fb17eb044e 100644 --- a/executor/src/witgen/query_processor.rs +++ b/executor/src/witgen/query_processor.rs @@ -1,5 +1,5 @@ use ast::{ - analyzed::{Expression, PolynomialReference, Reference}, + analyzed::{Expression, PolyID, PolynomialReference, Reference}, parsed::{MatchArm, MatchPattern}, }; use number::FieldElement; @@ -23,19 +23,20 @@ where } } - pub fn process_queries_on_current_row( + pub fn process_query( &mut self, rows: &RowPair, + poly_id: &PolyID, ) -> EvalValue<&'a PolynomialReference, T> { - let mut eval_value = EvalValue::complete(vec![]); - for column in self.fixed_data.witness_cols.values() { - if let Some(query) = column.query.as_ref() { - if rows.get_value(&query.poly).is_none() { - eval_value.combine(self.process_witness_query(query, rows)); - } + let column = &self.fixed_data.witness_cols[poly_id]; + + if let Some(query) = column.query.as_ref() { + if rows.get_value(&query.poly).is_none() { + return self.process_witness_query(query, rows); } } - eval_value + // Either no query or the value is already known. + EvalValue::complete(vec![]) } fn process_witness_query( diff --git a/executor/src/witgen/sequence_iterator.rs b/executor/src/witgen/sequence_iterator.rs index e763202b8c..cfb2ee9d20 100644 --- a/executor/src/witgen/sequence_iterator.rs +++ b/executor/src/witgen/sequence_iterator.rs @@ -7,7 +7,7 @@ use super::affine_expression::AffineExpression; #[derive(Clone, Debug)] pub struct SequenceStep { pub row_delta: i64, - pub identity: IdentityInSequence, + pub action: Action, } /// Goes through all rows of the block machine (plus the ones before and after) @@ -24,8 +24,9 @@ pub struct DefaultSequenceIterator { progress_in_current_round: bool, /// The current row delta index. cur_row_delta_index: usize, - /// The current identity index. - cur_identity_index: usize, + /// Index of the current action. Actions are: + /// [process identity 1, ..., process identity , process queries, process outer query (if on outer_query_row)] + cur_action_index: usize, /// The number of rounds for the current row delta. /// If this number gets too large, we will assume that we're in an infinite loop and exit. current_round_count: usize, @@ -49,7 +50,7 @@ impl DefaultSequenceIterator { is_first: true, progress_in_current_round: false, cur_row_delta_index: 0, - cur_identity_index: 0, + cur_action_index: 0, current_round_count: 0, progress_steps: vec![], } @@ -60,25 +61,25 @@ impl DefaultSequenceIterator { /// Otherwise, starts with identity 0 and moves to the next row if no progress was made. fn update_state(&mut self) { if !self.is_first { - if self.is_last_identity() { + if self.is_last_action() { self.start_next_round(); } else { // Stay at row delta, move to next identity. - self.cur_identity_index += 1; + self.cur_action_index += 1; } } self.is_first = false; } - fn is_last_identity(&self) -> bool { + fn is_last_action(&self) -> bool { let row_delta = self.row_deltas[self.cur_row_delta_index]; let is_on_row_with_outer_query = self.outer_query_row == Some(row_delta); if is_on_row_with_outer_query { - // In the last row, we want to process one more identity, the outer query. - self.cur_identity_index == self.identities_count + // In the last row, we want to do one more action, processing the outer query. + self.cur_action_index == self.identities_count + 1 } else { - self.cur_identity_index == self.identities_count - 1 + self.cur_action_index == self.identities_count } } @@ -96,7 +97,7 @@ impl DefaultSequenceIterator { // Stay and current row delta, starting with identity 0. self.current_round_count += 1; } - self.cur_identity_index = 0; + self.cur_action_index = 0; self.progress_in_current_round = false; } @@ -123,19 +124,20 @@ impl DefaultSequenceIterator { fn current_step(&self) -> SequenceStep { SequenceStep { row_delta: self.row_deltas[self.cur_row_delta_index], - identity: if self.cur_identity_index < self.identities_count { - IdentityInSequence::Internal(self.cur_identity_index) - } else { - IdentityInSequence::OuterQuery + action: match self.cur_action_index.cmp(&self.identities_count) { + std::cmp::Ordering::Less => Action::InternalIdentity(self.cur_action_index), + std::cmp::Ordering::Equal => Action::ProverQueries, + std::cmp::Ordering::Greater => Action::OuterQuery, }, } } } #[derive(Clone, Copy, Debug)] -pub enum IdentityInSequence { - Internal(usize), +pub enum Action { + InternalIdentity(usize), OuterQuery, + ProverQueries, } #[derive(PartialOrd, Ord, PartialEq, Eq, Debug)] diff --git a/executor/src/witgen/vm_processor.rs b/executor/src/witgen/vm_processor.rs index 6e1eb0c2d0..efd74cae3c 100644 --- a/executor/src/witgen/vm_processor.rs +++ b/executor/src/witgen/vm_processor.rs @@ -10,6 +10,7 @@ use crate::witgen::identity_processor::{self, IdentityProcessor}; use crate::witgen::rows::RowUpdater; use crate::witgen::IncompleteCause; +use super::column_map::WitnessColumnMap; use super::query_processor::QueryProcessor; use super::rows::{Row, RowFactory, RowPair, UnknownStrategy}; @@ -40,6 +41,8 @@ pub struct VmProcessor<'a, T: FieldElement> { row_offset: DegreeType, /// The witness columns belonging to this machine witnesses: HashSet, + /// Whether a given witness column is relevant for this machine (faster than doing a contains check on witnesses) + is_relevant_witness: WitnessColumnMap, fixed_data: &'a FixedData<'a, T>, /// The subset of identities that contains a reference to the next row /// (precomputed once for performance reasons) @@ -68,9 +71,17 @@ impl<'a, T: FieldElement> VmProcessor<'a, T> { .iter() .partition(|identity| identity.contains_next_ref()); + let is_relevant_witness = WitnessColumnMap::from( + fixed_data + .witness_cols + .keys() + .map(|poly_id| witnesses.contains(&poly_id)), + ); + VmProcessor { row_offset, witnesses, + is_relevant_witness, fixed_data, identities_with_next_ref: identities_with_next, identities_without_next_ref: identities_without_next, @@ -261,17 +272,22 @@ impl<'a, T: FieldElement> VmProcessor<'a, T> { UnknownStrategy::Unknown, &mut identity_processor, )?; - let mut query_processor = - QueryProcessor::new(self.fixed_data, mutable_state.query_callback); - let row_pair = RowPair::new( - self.row(row_index), - self.row(row_index + 1), - row_index + self.row_offset, - self.fixed_data, - UnknownStrategy::Unknown, - ); - let updates = query_processor.process_queries_on_current_row(&row_pair); - progress |= self.apply_updates(row_index, &updates, || "query".to_string()); + for poly_id in self.fixed_data.witness_cols.keys() { + if self.is_relevant_witness[&poly_id] { + let mut query_processor = + QueryProcessor::new(self.fixed_data, mutable_state.query_callback); + let row_pair = RowPair::new( + self.row(row_index), + self.row(row_index + 1), + row_index + self.row_offset, + self.fixed_data, + UnknownStrategy::Unknown, + ); + + let updates = query_processor.process_query(&row_pair, &poly_id); + progress |= self.apply_updates(row_index, &updates, || "query".to_string()); + } + } if !progress { break; diff --git a/test_data/asm/sqrt.asm b/test_data/asm/sqrt.asm new file mode 100644 index 0000000000..e3915843bb --- /dev/null +++ b/test_data/asm/sqrt.asm @@ -0,0 +1,54 @@ +machine Sqrt(latch, operation_id) { + + operation sqrt<0> x -> y; + + col fixed operation_id = [0]*; + col fixed latch = [1]*; + // Only works for small results, to keep the degree of this example small. + col fixed range(i) { i % 8 }; + col witness x; + + // Witness generation is not smart enough to figure out that + // there is a unique witness, so we provide it as a hint. + // This is a dummy example that hard-codes the answer for an input of 4. + // Once we have a sqrt function that we can run to compute the query result, + // this can be used to compute the hint from x. + col witness y(i) query ("hint", 2); + + y * y = x; + + // Note that this is required to make the witness unique + // (y := -y would also satisfy y * y = x, but we want the positive solution). + { y } in { range }; +} + + +machine Main { + degree 8; + + Sqrt sqrt; + + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + reg A; + + col witness XInv; + col witness XIsZero; + XIsZero = 1 - X * XInv; + XIsZero * X = 0; + XIsZero * (1 - XIsZero) = 0; + + instr assert_zero X { XIsZero = 1 } + + instr sqrt X -> Y = sqrt.sqrt + + + function main { + + A <== sqrt(4); + assert_zero A - 2; + + return; + } +} \ No newline at end of file