From feec718783f777fed4938c591226170d0c4012b1 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 15 Jan 2025 11:28:53 +0100 Subject: [PATCH] address review comments --- backend/src/mock/bus_checker.rs | 96 +++++++++++++++++++++------------ backend/src/mock/mod.rs | 6 +-- 2 files changed, 65 insertions(+), 37 deletions(-) diff --git a/backend/src/mock/bus_checker.rs b/backend/src/mock/bus_checker.rs index 8a5d8eb543..8339c1972d 100644 --- a/backend/src/mock/bus_checker.rs +++ b/backend/src/mock/bus_checker.rs @@ -1,23 +1,25 @@ use std::{cmp::Ordering, collections::BTreeMap, fmt}; +use itertools::Itertools; use powdr_ast::{ analyzed::{Analyzed, Identity, PhantomBusInteractionIdentity}, parsed::visitor::Children, }; use powdr_executor_utils::expression_evaluator::ExpressionEvaluator; use powdr_number::FieldElement; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use super::{localize, machine::Machine, unique_referenced_namespaces}; pub struct BusChecker<'a, F> { - connections: &'a [BusConnection], + connections: &'a [BusInteraction], machines: &'a BTreeMap>, } pub struct Error { tuple: Vec, - sends: BTreeMap, usize>, - receives: BTreeMap, usize>, + sends: BTreeMap, usize>, + receives: BTreeMap, usize>, } impl fmt::Display for Error { @@ -26,15 +28,10 @@ impl fmt::Display for Error { f, "Bus interaction {} failed for tuple {}:", self.tuple[0], - self.tuple - .iter() - .skip(1) - .map(|v| v.to_string()) - .collect::>() - .join(", ") + self.tuple.iter().skip(1).map(|v| v.to_string()).join(", ") )?; for ( - BusConnection { + BusInteraction { machine, interaction, }, @@ -46,7 +43,7 @@ impl fmt::Display for Error { writeln!(f, " in {machine}",)?; } for ( - BusConnection { + BusInteraction { machine, interaction, }, @@ -62,12 +59,12 @@ impl fmt::Display for Error { } #[derive(PartialEq, PartialOrd, Eq, Ord, Clone)] -pub struct BusConnection { +pub struct BusInteraction { pub machine: String, pub interaction: PhantomBusInteractionIdentity, } -impl BusConnection { +impl BusInteraction { /// Extracts all bus connections from the global PIL. pub fn get_all( global_pil: &Analyzed, @@ -85,10 +82,10 @@ impl BusConnection { } }) .map(|interaction| { - // Localize the interaction assuming a single namespace is accessed. TODO: This may break due to the latch. + // Localize the interaction assuming a single namespace is accessed. let machine = unique_referenced_namespaces(&interaction).unwrap(); let interaction = localize(interaction, global_pil, &machine_to_pil[&machine]); - BusConnection { + BusInteraction { machine, interaction, } @@ -99,7 +96,7 @@ impl BusConnection { impl<'a, F: FieldElement> BusChecker<'a, F> { pub fn new( - connections: &'a [BusConnection], + connections: &'a [BusInteraction], machines: &'a BTreeMap>, ) -> Self { Self { @@ -109,19 +106,19 @@ impl<'a, F: FieldElement> BusChecker<'a, F> { } pub fn check(&self) -> Result<(), Vec>> { - type BusState<'a, F> = BTreeMap< - Vec, - ( - BTreeMap<&'a BusConnection, usize>, - BTreeMap<&'a BusConnection, usize>, - ), - >; + #[derive(Default)] + struct TupleState<'a, F> { + sends: BTreeMap<&'a BusInteraction, usize>, + receives: BTreeMap<&'a BusInteraction, usize>, + } + + type BusState<'a, F> = BTreeMap, TupleState<'a, F>>; let bus_state: BusState = self .machines - .iter() + .into_par_iter() .flat_map(|(name, machine)| { - (0..machine.size).flat_map(|row_id| { + (0..machine.size).into_par_iter().flat_map(|row_id| { // create an evaluator for this row let mut evaluator = ExpressionEvaluator::new( machine.values.row(row_id), @@ -148,14 +145,13 @@ impl<'a, F: FieldElement> BusChecker<'a, F> { (bus_connection, tuple, multiplicity) }) + .collect::>() }) }) + // fold the interactions from a row into a state .fold( - Default::default(), - |mut counts, (bus_connection, tuple, multiplicity)| { - // update the counts - let (s, r) = counts.entry(tuple).or_default(); - + BusState::default, + |mut state, (bus_connection, tuple, multiplicity)| { let abs = multiplicity.unsigned_abs() as usize; match multiplicity.cmp(&0) { @@ -163,25 +159,57 @@ impl<'a, F: FieldElement> BusChecker<'a, F> { Ordering::Equal => {} // if the multiplicity is positive, we send Ordering::Greater => { - s.entry(bus_connection) + let TupleState { sends, .. } = state.entry(tuple).or_default(); + sends + .entry(bus_connection) .and_modify(|sends| *sends += abs) .or_insert(abs); } // if the multiplicity is negative, we receive Ordering::Less => { - r.entry(bus_connection) + let TupleState { receives, .. } = state.entry(tuple).or_default(); + receives + .entry(bus_connection) .and_modify(|receives| *receives += abs) .or_insert(abs); } } - counts + state + }, + ) + // combine all the states to one + .reduce( + BusState::default, + |mut a, b| { + for (tuple, TupleState { sends, receives }) in b { + let TupleState { + sends: sends_a, + receives: receives_a, + } = a.entry(tuple).or_default(); + + for (bus_connection, count) in sends { + sends_a + .entry(bus_connection) + .and_modify(|sends| *sends += count) + .or_insert(count); + } + + for (bus_connection, count) in receives { + receives_a + .entry(bus_connection) + .and_modify(|receives| *receives += count) + .or_insert(count); + } + } + + a }, ); let mut errors = vec![]; - for (tuple, (sends, receives)) in bus_state { + for (tuple, TupleState { sends, receives }) in bus_state { let send_count = sends.values().sum::(); let receive_count = receives.values().sum::(); if send_count != receive_count { diff --git a/backend/src/mock/mod.rs b/backend/src/mock/mod.rs index e9908603a4..fd443f0e4c 100644 --- a/backend/src/mock/mod.rs +++ b/backend/src/mock/mod.rs @@ -6,7 +6,7 @@ use std::{ sync::Arc, }; -use bus_checker::{BusChecker, BusConnection}; +use bus_checker::{BusChecker, BusInteraction}; use connection_constraint_checker::{Connection, ConnectionConstraintChecker}; use machine::Machine; use polynomial_constraint_checker::PolynomialConstraintChecker; @@ -50,7 +50,7 @@ impl BackendFactory for MockBackendFactory { } let machine_to_pil = powdr_backend_utils::split_pil(&pil); let connections = Connection::get_all(&pil, &machine_to_pil); - let bus_connections = BusConnection::get_all(&pil, &machine_to_pil); + let bus_connections = BusInteraction::get_all(&pil, &machine_to_pil); Ok(Box::new(MockBackend { machine_to_pil, @@ -69,7 +69,7 @@ pub(crate) struct MockBackend { machine_to_pil: BTreeMap>, fixed: Arc)>>, connections: Vec>, - bus_connections: Vec>, + bus_connections: Vec>, } impl Backend for MockBackend {