Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pilopt: equal-constrained witness columns removal #2224

Merged
merged 13 commits into from
Jan 16, 2025
170 changes: 150 additions & 20 deletions pilopt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub fn optimize<T: FieldElement>(mut pil_file: Analyzed<T>) -> Analyzed<T> {
extract_constant_lookups(&mut pil_file);
remove_constant_witness_columns(&mut pil_file);
simplify_identities(&mut pil_file);
remove_equal_constrained_witness_columns(&mut pil_file);
remove_trivial_identities(&mut pil_file);
remove_duplicate_identities(&mut pil_file);
remove_unreferenced_definitions(&mut pil_file);
Expand Down Expand Up @@ -85,7 +86,7 @@ fn remove_unreferenced_definitions<T: FieldElement>(pil_file: &mut Analyzed<T>)
Box::new(value.iter().flat_map(|v| {
v.all_children().flat_map(|e| {
if let AlgebraicExpression::Reference(AlgebraicReference { poly_id, .. }) = e {
Some(poly_id_to_definition_name[poly_id].into())
Some(poly_id_to_definition_name[poly_id].0.into())
} else {
None
}
Expand Down Expand Up @@ -120,33 +121,41 @@ fn remove_unreferenced_definitions<T: FieldElement>(pil_file: &mut Analyzed<T>)
pil_file.remove_trait_impls(&impls_to_remove);
}

/// Builds a lookup-table that can be used to turn array elements
/// (in form of their poly ids) into the names of the arrays.
/// Builds a lookup-table that can be used to turn all poly ids into the names of the symbols that define them.
/// For array elements, this contains the array name and the index of the element in the array.
fn build_poly_id_to_definition_name_lookup(
pil_file: &Analyzed<impl FieldElement>,
) -> BTreeMap<PolyID, &String> {
) -> BTreeMap<PolyID, (&String, Option<usize>)> {
let mut poly_id_to_definition_name = BTreeMap::new();
#[allow(clippy::iter_over_hash_type)]
for (name, (symbol, _)) in &pil_file.definitions {
if matches!(symbol.kind, SymbolKind::Poly(_)) {
symbol.array_elements().for_each(|(_, id)| {
poly_id_to_definition_name.insert(id, name);
});
symbol
.array_elements()
.enumerate()
.for_each(|(idx, (_, id))| {
let array_pos = if symbol.is_array() { Some(idx) } else { None };
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
poly_id_to_definition_name.insert(id, (name, array_pos));
});
}
}
#[allow(clippy::iter_over_hash_type)]
for (name, (symbol, _)) in &pil_file.intermediate_columns {
symbol.array_elements().for_each(|(_, id)| {
poly_id_to_definition_name.insert(id, name);
});
symbol
.array_elements()
.enumerate()
.for_each(|(idx, (_, id))| {
let array_pos = if symbol.is_array() { Some(idx) } else { None };
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
poly_id_to_definition_name.insert(id, (name, array_pos));
});
}
poly_id_to_definition_name
}

/// Collect all names that are referenced in identities and public declarations.
fn collect_required_symbols<'a, T: FieldElement>(
pil_file: &'a Analyzed<T>,
poly_id_to_definition_name: &BTreeMap<PolyID, &'a String>,
poly_id_to_definition_name: &BTreeMap<PolyID, (&'a String, Option<usize>)>,
) -> HashSet<SymbolReference<'a>> {
let mut required_names: HashSet<SymbolReference<'a>> = Default::default();
required_names.extend(
Expand All @@ -165,7 +174,7 @@ fn collect_required_symbols<'a, T: FieldElement>(
for id in &pil_file.identities {
id.pre_visit_expressions(&mut |e: &AlgebraicExpression<T>| {
if let AlgebraicExpression::Reference(AlgebraicReference { poly_id, .. }) = e {
required_names.insert(poly_id_to_definition_name[poly_id].into());
required_names.insert(poly_id_to_definition_name[poly_id].0.into());
}
});
}
Expand Down Expand Up @@ -533,23 +542,44 @@ fn constrained_to_constant<T: FieldElement>(
None
}

/// Removes identities that evaluate to zero and lookups with empty columns.
/// Removes identities that evaluate to zero (including constraints of the form "X = X") and lookups with empty columns.
fn remove_trivial_identities<T: FieldElement>(pil_file: &mut Analyzed<T>) {
let to_remove = pil_file
.identities
.iter()
.enumerate()
.filter_map(|(index, identity)| match identity {
Identity::Polynomial(PolynomialIdentity { expression, .. }) => {
if let AlgebraicExpression::Number(n) = expression {
Identity::Polynomial(PolynomialIdentity { expression, .. }) => match expression {
AlgebraicExpression::Number(n) => {
if *n == 0.into() {
return Some(index);
Some(index)
} else {
// Otherwise the constraint is not satisfiable,
// but better to get the error elsewhere.
None
}
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
// Otherwise the constraint is not satisfiable,
// but better to get the error elsewhere.
}
None
}
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation {
left,
op: AlgebraicBinaryOperator::Sub,
right,
}) => {
if let (
AlgebraicExpression::Reference(left),
AlgebraicExpression::Reference(right),
) = (left.as_ref(), right.as_ref())
{
if !left.next && !right.next && left == right {
Some(index)
} else {
None
}
} else {
None
}
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
}
_ => None,
},
Identity::Lookup(LookupIdentity { left, right, .. })
| Identity::Permutation(PermutationIdentity { left, right, .. })
| Identity::PhantomLookup(PhantomLookupIdentity { left, right, .. })
Expand Down Expand Up @@ -688,3 +718,103 @@ fn remove_duplicate_identities<T: FieldElement>(pil_file: &mut Analyzed<T>) {
.collect();
pil_file.remove_identities(&to_remove);
}

/// Identifies witness columns that are directly constrained to be equal to other witness columns
/// through polynomial identities of the form "x = y" and returns a tuple ((name, id), (name, id))
/// for each pair of identified columns
fn equal_constrained<T: FieldElement>(
expression: &AlgebraicExpression<T>,
poly_id_to_array_elem: &BTreeMap<PolyID, (&String, Option<usize>)>,
) -> Option<((String, PolyID), (String, PolyID))> {
match expression {
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation {
left,
op: AlgebraicBinaryOperator::Sub,
right,
}) => match (left.as_ref(), right.as_ref()) {
(AlgebraicExpression::Reference(l), AlgebraicExpression::Reference(r)) => {
let is_valid = |x: &AlgebraicReference| {
x.is_witness()
&& !x.next
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
&& poly_id_to_array_elem.get(&x.poly_id).unwrap().1.is_none()
};

if is_valid(l) && is_valid(r) {
Some(if l.poly_id > r.poly_id {
((l.name.clone(), l.poly_id), (r.name.clone(), r.poly_id))
} else {
((r.name.clone(), r.poly_id), (l.name.clone(), l.poly_id))
})
} else {
None
}
}
_ => None,
},
_ => None,
}
}

fn remove_equal_constrained_witness_columns<T: FieldElement>(pil_file: &mut Analyzed<T>) {
let poly_id_to_array_elem = build_poly_id_to_definition_name_lookup(pil_file);
let substitutions: Vec<_> = pil_file
.identities
.iter()
.filter_map(|id| {
if let Identity::Polynomial(PolynomialIdentity { expression, .. }) = id {
equal_constrained(expression, &poly_id_to_array_elem)
} else {
None
}
})
.collect();

let substitutions = resolve_transitive_substitutions(substitutions);

let (subs_by_id, subs_by_name): (HashMap<_, _>, HashMap<_, _>) = substitutions
.iter()
.map(|((name, id), to_keep)| ((id, to_keep), (name, to_keep)))
.unzip();

pil_file.post_visit_expressions_in_identities_mut(&mut |e: &mut AlgebraicExpression<_>| {
if let AlgebraicExpression::Reference(ref mut reference) = e {
if let Some((replacement_name, replacement_id)) = subs_by_id.get(&reference.poly_id) {
reference.poly_id = *replacement_id;
reference.name = replacement_name.clone();
}
}
});

pil_file.post_visit_expressions_mut(&mut |e: &mut Expression| {
if let Expression::Reference(_, Reference::Poly(reference)) = e {
if let Some((replacement_name, _)) = subs_by_name.get(&reference.name) {
reference.name = replacement_name.clone();
}
}
});
}

fn resolve_transitive_substitutions(
subs: Vec<((String, PolyID), (String, PolyID))>,
) -> Vec<((String, PolyID), (String, PolyID))> {
let mut result = subs.clone();
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
let mut changed = true;

while changed {
changed = false;
for i in 0..result.len() {
let (_, target1) = &result[i].1;
if let Some(j) = result
.iter()
.position(|((_, source2), _)| source2 == target1)
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
{
let ((name1, source1), _) = &result[i];
let (_, (name3, target2)) = &result[j];
result[i] = ((name1.clone(), *source1), (name3.clone(), *target2));
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
changed = true;
}
}
}

result
}
82 changes: 77 additions & 5 deletions pilopt/tests/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@ fn replace_fixed() {
"#;
let expectation = r#"namespace N(65536);
col witness X;
col witness Y;
query |i| {
let _: expr = 1_expr;
};
N::X = N::Y;
N::Y = 7 * N::X;
N::X = 7 * N::X;
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
Expand Down Expand Up @@ -100,12 +98,16 @@ fn intermediate() {
let input = r#"namespace N(65536);
col witness x;
col intermediate = x;
intermediate = intermediate;
col int2 = intermediate * x;
col int3 = int2;
int3 = (3 * x) + x;
"#;
let expectation = r#"namespace N(65536);
col witness x;
col intermediate = N::x;
N::intermediate = N::intermediate;
col int2 = N::intermediate * N::x;
col int3 = N::int2;
N::int3 = 3 * N::x + N::x;
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
Expand Down Expand Up @@ -375,3 +377,73 @@ fn handle_array_references_in_prover_functions() {
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
}

#[test]
fn equal_constrained_array_elements_empty() {
let input = r#"namespace N(65536);
col witness w[20];
w[4] = w[7];
"#;
let expectation = r#"namespace N(65536);
col witness w[20];
N::w[4] = N::w[7];
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
}

#[test]
fn equal_constrained_array_elements_query() {
let input = r#"namespace N(65536);
col witness w[20];
w[4] = w[7];
query |i| {
let _ = w[4] + w[7] - w[5];
};
"#;
let expectation = r#"namespace N(65536);
col witness w[20];
N::w[4] = N::w[7];
query |i| {
let _: expr = N::w[4_int] + N::w[7_int] - N::w[5_int];
};
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
}

#[test]
fn equal_constrained_array_elements() {
let input = r#"namespace N(65536);
col witness w[20];
w[4] = w[7];
w[3] = w[5];
w[7] + w[1] = 5;
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
"#;
let expectation = r#"namespace N(65536);
col witness w[20];
N::w[4] = N::w[7];
N::w[3] = N::w[5];
N::w[7] + N::w[1] = 5;
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
}

#[test]
fn equal_constrained_transitive() {
let input = r#"namespace N(65536);
col witness a;
col witness b;
col witness c;
a = b;
b = c;
a + b + c = 5;
"#;
let expectation = r#"namespace N(65536);
col witness a;
N::a + N::a + N::a = 5;
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
}
6 changes: 0 additions & 6 deletions pipeline/tests/asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,12 +366,6 @@ fn full_pil_constant() {
regular_test_all_fields(f, Default::default());
}

#[test]
fn intermediate() {
let f = "asm/intermediate.asm";
regular_test_all_fields(f, Default::default());
}

#[test]
fn intermediate_nested() {
let f = "asm/intermediate_nested.asm";
Expand Down
Loading
Loading