diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 89f908ac1f..0b0036418f 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -33,6 +33,7 @@ pub fn optimize(mut pil_file: Analyzed) -> Analyzed { remove_constant_witness_columns(&mut pil_file); remove_constant_intermediate_columns(&mut pil_file); simplify_identities(&mut pil_file); + remove_equal_constrained_witness_columns(&mut pil_file); remove_trivial_identities(&mut pil_file); remove_duplicate_identities(&mut pil_file); @@ -132,7 +133,7 @@ fn remove_unreferenced_definitions(pil_file: &mut Analyzed) Box::new(value.iter().flat_map(|v| { v.all_children().flat_map(|e| { if let AlgebraicExpression::Reference(AlgebraicReference { poly_id, .. }) = e { - Some(poly_id_to_definition_name[poly_id].into()) + Some(poly_id_to_definition_name[poly_id].0.into()) } else { None } @@ -167,33 +168,37 @@ fn remove_unreferenced_definitions(pil_file: &mut Analyzed) pil_file.remove_trait_impls(&impls_to_remove); } -/// Builds a lookup-table that can be used to turn array elements -/// (in form of their poly ids) into the names of the arrays. +/// Builds a lookup-table that can be used to turn all poly ids into the names of the symbols that define them. +/// For array elements, this contains the array name and the index of the element in the array. fn build_poly_id_to_definition_name_lookup( pil_file: &Analyzed, -) -> BTreeMap { - let mut poly_id_to_definition_name = BTreeMap::new(); - #[allow(clippy::iter_over_hash_type)] - for (name, (symbol, _)) in &pil_file.definitions { - if matches!(symbol.kind, SymbolKind::Poly(_)) { - symbol.array_elements().for_each(|(_, id)| { - poly_id_to_definition_name.insert(id, name); - }); - } - } - #[allow(clippy::iter_over_hash_type)] - for (name, (symbol, _)) in &pil_file.intermediate_columns { - symbol.array_elements().for_each(|(_, id)| { - poly_id_to_definition_name.insert(id, name); - }); - } - poly_id_to_definition_name +) -> BTreeMap)> { + pil_file + .definitions + .iter() + .map(|(name, (symbol, _))| (name, symbol)) + .chain( + pil_file + .intermediate_columns + .iter() + .map(|(name, (symbol, _))| (name, symbol)), + ) + .filter(|(_, symbol)| matches!(symbol.kind, SymbolKind::Poly(_))) + .flat_map(|(name, symbol)| { + symbol + .array_elements() + .enumerate() + .map(move |(idx, (_, id))| { + let array_pos = symbol.is_array().then_some(idx); + (id, (name, array_pos)) + }) + }) + .collect() } - /// Collect all names that are referenced in identities and public declarations. fn collect_required_symbols<'a, T: FieldElement>( pil_file: &'a Analyzed, - poly_id_to_definition_name: &BTreeMap, + poly_id_to_definition_name: &BTreeMap)>, ) -> HashSet> { let mut required_names: HashSet> = Default::default(); required_names.extend( @@ -212,7 +217,7 @@ fn collect_required_symbols<'a, T: FieldElement>( for id in &pil_file.identities { id.pre_visit_expressions(&mut |e: &AlgebraicExpression| { if let AlgebraicExpression::Reference(AlgebraicReference { poly_id, .. }) = e { - required_names.insert(poly_id_to_definition_name[poly_id].into()); + required_names.insert(poly_id_to_definition_name[poly_id].0.into()); } }); } @@ -615,23 +620,31 @@ fn constrained_to_constant( None } -/// Removes identities that evaluate to zero and lookups with empty columns. +/// Removes identities that evaluate to zero (including constraints of the form "X = X") and lookups with empty columns. fn remove_trivial_identities(pil_file: &mut Analyzed) { let to_remove = pil_file .identities .iter() .enumerate() .filter_map(|(index, identity)| match identity { - Identity::Polynomial(PolynomialIdentity { expression, .. }) => { - if let AlgebraicExpression::Number(n) = expression { - if *n == 0.into() { - return Some(index); - } - // Otherwise the constraint is not satisfiable, - // but better to get the error elsewhere. + Identity::Polynomial(PolynomialIdentity { expression, .. }) => match expression { + AlgebraicExpression::Number(n) => { + // Return None for non-satisfiable constraints - better to get the error elsewhere + (*n == 0.into()).then_some(index) } - None - } + AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { + left, + op: AlgebraicBinaryOperator::Sub, + right, + }) => match (left.as_ref(), right.as_ref()) { + ( + AlgebraicExpression::Reference(left), + AlgebraicExpression::Reference(right), + ) => (left == right).then_some(index), + _ => None, + }, + _ => None, + }, Identity::Lookup(LookupIdentity { left, right, .. }) | Identity::Permutation(PermutationIdentity { left, right, .. }) | Identity::PhantomLookup(PhantomLookupIdentity { left, right, .. }) @@ -770,3 +783,109 @@ fn remove_duplicate_identities(pil_file: &mut Analyzed) { .collect(); pil_file.remove_identities(&to_remove); } + +/// Identifies witness columns that are directly constrained to be equal to other witness columns +/// through polynomial identities of the form "x = y" and returns a tuple ((name, id), (name, id)) +/// for each pair of identified columns +fn equal_constrained( + expression: &AlgebraicExpression, + poly_id_to_array_elem: &BTreeMap)>, +) -> Option<((String, PolyID), (String, PolyID))> { + match expression { + AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { + left, + op: AlgebraicBinaryOperator::Sub, + right, + }) => match (left.as_ref(), right.as_ref()) { + (AlgebraicExpression::Reference(l), AlgebraicExpression::Reference(r)) => { + let is_valid = |x: &AlgebraicReference, left: bool| { + x.is_witness() + && if left { + // We don't allow the left-hand side to be an array element + // to preserve array integrity (e.g. `x = y` is valid, but `x[0] = y` is not) + poly_id_to_array_elem.get(&x.poly_id).unwrap().1.is_none() + } else { + true + } + }; + + if is_valid(l, true) && is_valid(r, false) && r.next == l.next { + Some(if l.poly_id > r.poly_id { + ((l.name.clone(), l.poly_id), (r.name.clone(), r.poly_id)) + } else { + ((r.name.clone(), r.poly_id), (l.name.clone(), l.poly_id)) + }) + } else { + None + } + } + _ => None, + }, + _ => None, + } +} + +fn remove_equal_constrained_witness_columns(pil_file: &mut Analyzed) { + let poly_id_to_array_elem = build_poly_id_to_definition_name_lookup(pil_file); + let mut substitutions: BTreeMap<(String, PolyID), (String, PolyID)> = pil_file + .identities + .iter() + .filter_map(|id| { + if let Identity::Polynomial(PolynomialIdentity { expression, .. }) = id { + equal_constrained(expression, &poly_id_to_array_elem) + } else { + None + } + }) + .collect(); + + resolve_transitive_substitutions(&mut substitutions); + + let (subs_by_id, subs_by_name): (HashMap<_, _>, HashMap<_, _>) = substitutions + .iter() + .map(|(k, v)| ((k.1, v), (&k.0, v))) + .unzip(); + + pil_file.post_visit_expressions_in_identities_mut(&mut |e: &mut AlgebraicExpression<_>| { + if let AlgebraicExpression::Reference(ref mut reference) = e { + if let Some((replacement_name, replacement_id)) = subs_by_id.get(&reference.poly_id) { + reference.poly_id = *replacement_id; + reference.name = replacement_name.clone(); + } + } + }); + + pil_file.post_visit_expressions_mut(&mut |e: &mut Expression| { + if let Expression::Reference(_, Reference::Poly(reference)) = e { + if let Some((replacement_name, _)) = subs_by_name.get(&reference.name) { + reference.name = replacement_name.clone(); + } + } + }); +} + +fn resolve_transitive_substitutions(subs: &mut BTreeMap<(String, PolyID), (String, PolyID)>) { + let mut changed = true; + while changed { + changed = false; + let keys: Vec<_> = subs + .keys() + .map(|(name, id)| (name.to_string(), *id)) + .collect(); + + for key in keys { + let Some(target_key) = subs.get(&key) else { + continue; + }; + + let Some(new_target) = subs.get(target_key) else { + continue; + }; + + if subs.get(&key).unwrap() != new_target { + subs.insert(key, new_target.clone()); + changed = true; + } + } + } +} diff --git a/pilopt/tests/optimizer.rs b/pilopt/tests/optimizer.rs index 893f1c2f3c..6a39eb24db 100644 --- a/pilopt/tests/optimizer.rs +++ b/pilopt/tests/optimizer.rs @@ -19,12 +19,10 @@ fn replace_fixed() { "#; let expectation = r#"namespace N(65536); col witness X; - col witness Y; query |i| { let _: expr = 1_expr; }; - N::X = N::Y; - N::Y = 7 * N::X; + N::X = 7 * N::X; "#; let optimized = optimize(analyze_string::(input).unwrap()).to_string(); assert_eq!(optimized, expectation); @@ -116,12 +114,16 @@ fn intermediate() { let input = r#"namespace N(65536); col witness x; col intermediate = x; - intermediate = intermediate; + col int2 = intermediate * x; + col int3 = int2; + int3 = (3 * x) + x; "#; let expectation = r#"namespace N(65536); col witness x; col intermediate = N::x; - N::intermediate = N::intermediate; + col int2 = N::intermediate * N::x; + col int3 = N::int2; + N::int3 = 3 * N::x + N::x; "#; let optimized = optimize(analyze_string::(input).unwrap()).to_string(); assert_eq!(optimized, expectation); @@ -391,3 +393,75 @@ fn handle_array_references_in_prover_functions() { let optimized = optimize(analyze_string::(input).unwrap()).to_string(); assert_eq!(optimized, expectation); } + +#[test] +fn equal_constrained_array_elements_empty() { + let input = r#"namespace N(65536); + col witness w[20]; + w[4] = w[7]; + "#; + let expectation = r#"namespace N(65536); + col witness w[20]; + N::w[4] = N::w[7]; +"#; + let optimized = optimize(analyze_string::(input).unwrap()).to_string(); + assert_eq!(optimized, expectation); +} + +#[test] +fn equal_constrained_array_elements_query() { + let input = r#"namespace N(65536); + col witness w[20]; + w[4] = w[7]; + query |i| { + let _ = w[4] + w[7] - w[5]; + }; + "#; + let expectation = r#"namespace N(65536); + col witness w[20]; + N::w[4] = N::w[7]; + query |i| { + let _: expr = N::w[4_int] + N::w[7_int] - N::w[5_int]; + }; +"#; + let optimized = optimize(analyze_string::(input).unwrap()).to_string(); + assert_eq!(optimized, expectation); +} + +#[test] +fn equal_constrained_array_elements() { + let input = r#"namespace N(65536); + col witness w[20]; + col witness x; + w[4] = w[7]; + w[3] = w[5]; + x = w[3]; + w[7] + w[1] + x = 5; + "#; + let expectation = r#"namespace N(65536); + col witness w[20]; + N::w[4] = N::w[7]; + N::w[3] = N::w[5]; + N::w[7] + N::w[1] + N::w[3] = 5; +"#; + let optimized = optimize(analyze_string::(input).unwrap()).to_string(); + assert_eq!(optimized, expectation); +} + +#[test] +fn equal_constrained_transitive() { + let input = r#"namespace N(65536); + col witness a; + col witness b; + col witness c; + a = b; + b = c; + a + b + c = 5; + "#; + let expectation = r#"namespace N(65536); + col witness a; + N::a + N::a + N::a = 5; +"#; + let optimized = optimize(analyze_string::(input).unwrap()).to_string(); + assert_eq!(optimized, expectation); +} diff --git a/pipeline/tests/asm.rs b/pipeline/tests/asm.rs index 62b420e8d7..ca5d4b152d 100644 --- a/pipeline/tests/asm.rs +++ b/pipeline/tests/asm.rs @@ -372,12 +372,6 @@ fn full_pil_constant() { regular_test_all_fields(f, Default::default()); } -#[test] -fn intermediate() { - let f = "asm/intermediate.asm"; - regular_test_all_fields(f, Default::default()); -} - #[test] fn intermediate_nested() { let f = "asm/intermediate_nested.asm"; diff --git a/riscv-executor/src/lib.rs b/riscv-executor/src/lib.rs index 60d383d2ae..52c54cd02e 100644 --- a/riscv-executor/src/lib.rs +++ b/riscv-executor/src/lib.rs @@ -23,7 +23,7 @@ use builder::TraceBuilder; use itertools::Itertools; use powdr_ast::{ - analyzed::{Analyzed, Identity}, + analyzed::{AlgebraicExpression, Analyzed, Identity, LookupIdentity}, asm_analysis::{AnalysisASMFile, CallableSymbol, FunctionStatement, LabelStatement, Machine}, parsed::{ asm::{parse_absolute_path, AssignmentRegister, DebugDirective}, @@ -907,6 +907,14 @@ mod builder { } } + pub fn col_is_defined(&self, name: &str) -> bool { + if let ExecMode::Trace = self.mode { + self.trace.all_cols.contains(&name.to_string()) + } else { + false + } + } + pub fn push_row(&mut self, pc: u32) { if let ExecMode::Trace = self.mode { let new_len = self.trace.known_cols.len() + KnownWitnessCol::count(); @@ -1443,10 +1451,21 @@ impl Executor<'_, '_, F> { self.proc.backup_reg_mem(); - set_col!(X, get_fixed!(X_const)); - set_col!(Y, get_fixed!(Y_const)); - set_col!(Z, get_fixed!(Z_const)); - set_col!(W, get_fixed!(W_const)); + if self.proc.col_is_defined("main::X_const") { + set_col!(X, get_fixed!(X_const)); + } + + if self.proc.col_is_defined("main::Y_const") { + set_col!(Y, get_fixed!(Y_const)); + } + + if self.proc.col_is_defined("main::Z_const") { + set_col!(Z, get_fixed!(Z_const)); + } + + if self.proc.col_is_defined("main::W_const") { + set_col!(W, get_fixed!(W_const)); + } let instr = Instruction::from_name(name).expect("unknown instruction"); @@ -2905,23 +2924,36 @@ fn execute_inner( .unwrap_or_default(); // program columns to witness columns - let program_cols: HashMap<_, _> = if let Some(fixed) = &fixed { - fixed - .iter() - .filter_map(|(name, _col)| { - if !name.starts_with("main__rom::p_") { - return None; - } - let wit_name = format!("main::{}", name.strip_prefix("main__rom::p_").unwrap()); - if !witness_cols.contains(&wit_name) { - return None; - } - Some((name.clone(), wit_name)) - }) - .collect() - } else { - Default::default() - }; + let program_cols: HashMap<_, _> = opt_pil + .map(|pil| { + pil.identities + .iter() + .flat_map(|id| match id { + Identity::Lookup(LookupIdentity { left, right, .. }) => left + .expressions + .iter() + .zip(right.expressions.iter()) + .filter_map(|(l, r)| match (l, r) { + ( + AlgebraicExpression::Reference(l), + AlgebraicExpression::Reference(r), + ) => { + if r.name.starts_with("main__rom::p_") + && witness_cols.contains(&l.name) + { + Some((r.name.clone(), l.name.clone())) + } else { + None + } + } + _ => None, + }) + .collect::>(), + _ => vec![], + }) + .collect() + }) + .unwrap_or_default(); let proc = match TraceBuilder::<'_, F>::new( main_machine, diff --git a/test_data/asm/intermediate.asm b/test_data/asm/intermediate.asm deleted file mode 100644 index 4584503218..0000000000 --- a/test_data/asm/intermediate.asm +++ /dev/null @@ -1,11 +0,0 @@ -machine Intermediate with - latch: latch, - operation_id: operation_id, - degree: 8 -{ - col fixed latch = [1]*; - col fixed operation_id = [0]*; - col witness x; - col intermediate = x; - intermediate = intermediate; -} diff --git a/test_data/asm/set_hint.asm b/test_data/asm/set_hint.asm index cd25283d73..f52ee78412 100644 --- a/test_data/asm/set_hint.asm +++ b/test_data/asm/set_hint.asm @@ -7,5 +7,5 @@ let new_col_with_hint: -> expr = constr || { machine Main with degree: 4 { let x; let w = new_col_with_hint(); - x = w; + x = w + 1; } \ No newline at end of file