Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pilopt: optimize until fixpoint #2223

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ impl Children<Expression> for NamedType {
}
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PublicDeclaration {
pub id: u64,
pub source: SourceRef,
Expand All @@ -921,7 +921,9 @@ impl PublicDeclaration {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct SelectedExpressions<T> {
pub selector: AlgebraicExpression<T>,
pub expressions: Vec<AlgebraicExpression<T>>,
Expand All @@ -947,7 +949,7 @@ impl<T> Children<AlgebraicExpression<T>> for SelectedExpressions<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PolynomialIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -964,7 +966,7 @@ impl<T> Children<AlgebraicExpression<T>> for PolynomialIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct LookupIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -986,7 +988,7 @@ impl<T> Children<AlgebraicExpression<T>> for LookupIdentity<T> {
///
/// This identity is used as a replacement for a lookup identity which has been turned into challenge-based polynomial identities.
/// This is ignored by the backend.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PhantomLookupIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand Down Expand Up @@ -1015,7 +1017,7 @@ impl<T> Children<AlgebraicExpression<T>> for PhantomLookupIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PermutationIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -1037,7 +1039,7 @@ impl<T> Children<AlgebraicExpression<T>> for PermutationIdentity<T> {
///
/// This identity is used as a replacement for a permutation identity which has been turned into challenge-based polynomial identities.
/// This is ignored by the backend.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PhantomPermutationIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -1055,7 +1057,7 @@ impl<T> Children<AlgebraicExpression<T>> for PhantomPermutationIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct ConnectIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -1073,7 +1075,9 @@ impl<T> Children<AlgebraicExpression<T>> for ConnectIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, PartialOrd, Ord)]
#[derive(
Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, PartialOrd, Ord, Hash,
)]
pub struct ExpressionList<T>(pub Vec<AlgebraicExpression<T>>);

impl<T> Children<AlgebraicExpression<T>> for ExpressionList<T> {
Expand All @@ -1085,7 +1089,7 @@ impl<T> Children<AlgebraicExpression<T>> for ExpressionList<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PhantomBusInteractionIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand Down Expand Up @@ -1114,6 +1118,7 @@ impl<T> Children<AlgebraicExpression<T>> for PhantomBusInteractionIdentity<T> {
derive_more::Display,
derive_more::From,
derive_more::TryInto,
Hash,
)]
pub enum Identity<T> {
Polynomial(PolynomialIdentity<T>),
Expand Down Expand Up @@ -1305,7 +1310,9 @@ impl Hash for AlgebraicReference {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum AlgebraicExpression<T> {
Reference(AlgebraicReference),
PublicReference(String),
Expand All @@ -1315,7 +1322,9 @@ pub enum AlgebraicExpression<T> {
UnaryOperation(AlgebraicUnaryOperation<T>),
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct AlgebraicBinaryOperation<T> {
pub left: Box<AlgebraicExpression<T>>,
pub op: AlgebraicBinaryOperator,
Expand All @@ -1341,7 +1350,9 @@ impl<T> From<AlgebraicBinaryOperation<T>> for AlgebraicExpression<T> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct AlgebraicUnaryOperation<T> {
pub op: AlgebraicUnaryOperator,
pub expr: Box<AlgebraicExpression<T>>,
Expand Down Expand Up @@ -1525,7 +1536,7 @@ impl<T> AlgebraicExpression<T> {
}

#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema,
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct Challenge {
/// Challenge ID
Expand All @@ -1534,7 +1545,7 @@ pub struct Challenge {
}

#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema,
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum AlgebraicBinaryOperator {
Add,
Expand Down Expand Up @@ -1572,7 +1583,7 @@ impl TryFrom<BinaryOperator> for AlgebraicBinaryOperator {
}

#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema,
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum AlgebraicUnaryOperator {
Minus,
Expand Down
24 changes: 1 addition & 23 deletions backend/src/plonky3/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ mod tests {
col witness x;
col witness y;
col witness z;
y - 1 = 0;
y - 1 = 1;
x = 0;
x + y = z;

Expand Down Expand Up @@ -465,28 +465,6 @@ mod tests {
run_test(content);
}

#[test]
fn two_tables() {
// This test is a bit contrived but witgen wouldn't allow a more direct example
let content = r#"
namespace Add(8);
col witness x;
col witness y;
col witness z;
x = 0;
y = 0;
x + y = z;
1 $ [ x, y, z ] in 1 $ [ Mul::x, Mul::y, Mul::z ];

namespace Mul(16);
col witness x;
col witness y;
col witness z;
x * y = z;
"#;
run_test(content);
}

#[test]
fn challenge() {
let content = r#"
Expand Down
60 changes: 51 additions & 9 deletions pilopt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use std::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::hash::{DefaultHasher, Hash, Hasher};

use itertools::Itertools;
use powdr_ast::analyzed::{
Expand All @@ -22,16 +23,20 @@ use referenced_symbols::{ReferencedSymbols, SymbolReference};

pub fn optimize<T: FieldElement>(mut pil_file: Analyzed<T>) -> Analyzed<T> {
let col_count_pre = (pil_file.commitment_count(), pil_file.constant_count());
let mut pil_hash = 0;
remove_unreferenced_definitions(&mut pil_file);
remove_constant_fixed_columns(&mut pil_file);
deduplicate_fixed_columns(&mut pil_file);
simplify_identities(&mut pil_file);
extract_constant_lookups(&mut pil_file);
remove_constant_witness_columns(&mut pil_file);
simplify_identities(&mut pil_file);
remove_trivial_identities(&mut pil_file);
remove_duplicate_identities(&mut pil_file);
remove_unreferenced_definitions(&mut pil_file);
while pil_hash != hash_pil_state(&pil_file) {
pil_hash = hash_pil_state(&pil_file);
remove_constant_fixed_columns(&mut pil_file);
deduplicate_fixed_columns(&mut pil_file);
simplify_identities(&mut pil_file);
extract_constant_lookups(&mut pil_file);
remove_constant_witness_columns(&mut pil_file);
simplify_identities(&mut pil_file);
remove_trivial_identities(&mut pil_file);
remove_duplicate_identities(&mut pil_file);
remove_unreferenced_definitions(&mut pil_file);
}
let col_count_post = (pil_file.commitment_count(), pil_file.constant_count());
log::info!(
"Removed {} witness and {} fixed columns. Total count now: {} witness and {} fixed columns.",
Expand All @@ -43,6 +48,43 @@ pub fn optimize<T: FieldElement>(mut pil_file: Analyzed<T>) -> Analyzed<T> {
pil_file
}

fn hash_pil_state<T: FieldElement>(pil: &Analyzed<T>) -> u64 {
let mut hasher = DefaultHasher::new();

for identity in &pil.identities {
identity.hash(&mut hasher);
}

let mut keys: Vec<_> = pil.definitions.keys().collect();
keys.sort();
for key in keys {
key.hash(&mut hasher);
if let Some(v) = &pil.definitions[key].1 {
v.hash(&mut hasher);
}
}

let mut keys: Vec<_> = pil.intermediate_columns.keys().collect();
keys.sort();
for key in keys {
key.hash(&mut hasher);
pil.intermediate_columns[key].1.hash(&mut hasher);
}

for pf in &pil.prover_functions {
pf.hash(&mut hasher);
}

let mut keys: Vec<_> = pil.public_declarations.keys().collect();
keys.sort();
for key in keys {
key.hash(&mut hasher);
pil.public_declarations[key].hash(&mut hasher);
}

hasher.finish()
}

/// Removes all definitions that are not referenced by an identity, public declaration
/// or witness column hint.
fn remove_unreferenced_definitions<T: FieldElement>(pil_file: &mut Analyzed<T>) {
Expand Down
Loading