diff --git a/solver/src/reasoners/cp/mod.rs b/solver/src/reasoners/cp/mod.rs index 58b149ae..94517fca 100644 --- a/solver/src/reasoners/cp/mod.rs +++ b/solver/src/reasoners/cp/mod.rs @@ -11,7 +11,7 @@ use crate::reasoners::{Contradiction, ReasonerId, Theory}; use anyhow::Context; use num_integer::{div_ceil, div_floor}; use std::cmp::Ordering; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; // =========== Sum =========== @@ -68,17 +68,16 @@ impl std::fmt::Display for LinearSumLeq { } impl LinearSumLeq { - fn get_lower_bound(&self, elem: SumElem, domains: &Domains) -> IntCst { + fn get_lower_bound(&self, elem: SumElem, domains: &Domains) -> i64 { let var_is_present = domains.present(elem.var) == Some(true); debug_assert!(!domains.entails(elem.lit) || var_is_present); let int_part = match elem.factor.cmp(&0) { - Ordering::Less => domains.ub(elem.var), + Ordering::Less => domains.ub(elem.var) as i64, Ordering::Equal => 0, - Ordering::Greater => domains.lb(elem.var), + Ordering::Greater => domains.lb(elem.var) as i64, } - .saturating_mul(elem.factor) - .clamp(INT_CST_MIN, INT_CST_MAX); + .saturating_mul(elem.factor as i64); match domains.value(elem.lit) { Some(true) => int_part, @@ -86,17 +85,16 @@ impl LinearSumLeq { None => 0.min(int_part), } } - fn get_upper_bound(&self, elem: SumElem, domains: &Domains) -> IntCst { + fn get_upper_bound(&self, elem: SumElem, domains: &Domains) -> i64 { let var_is_present = domains.present(elem.var) == Some(true); debug_assert!(!domains.entails(elem.lit) || var_is_present); let int_part = match elem.factor.cmp(&0) { - Ordering::Less => domains.lb(elem.var), + Ordering::Less => domains.lb(elem.var) as i64, Ordering::Equal => 0, - Ordering::Greater => domains.ub(elem.var), + Ordering::Greater => domains.ub(elem.var) as i64, } - .saturating_mul(elem.factor) - .clamp(INT_CST_MIN, INT_CST_MAX); + .saturating_mul(elem.factor as i64); match domains.value(elem.lit) { Some(true) => int_part, @@ -104,15 +102,11 @@ impl LinearSumLeq { None => 0.max(int_part), } } - fn set_ub(&self, elem: SumElem, ub: IntCst, domains: &mut Domains, cause: Cause) -> Result { + fn set_ub(&self, elem: SumElem, ub: i64, domains: &mut Domains, cause: Cause) -> Result { + // convert back to i32 (which is used for domains). Clamping should prevent any conversion error. + let ub = ub.clamp(INT_CST_MIN as i64, INT_CST_MAX as i64) as i32; debug_assert!(!domains.entails(elem.lit) || domains.present(elem.var) == Some(true)); - // println!( - // " {:?} : [{}, {}] ub: {ub} -> {}", - // var, - // domains.lb(var), - // domains.ub(var), - // ub / elem.factor, - // ); + match elem.factor.cmp(&0) { Ordering::Less => domains.set_lb(elem.var, div_ceil(ub, elem.factor), cause), Ordering::Equal => unreachable!(), @@ -137,18 +131,13 @@ impl LinearSumLeq { impl Propagator for LinearSumLeq { fn setup(&self, id: PropagatorId, context: &mut Watches) { - // println!("SET UP"); - + context.add_watch(self.active.variable(), id); for e in &self.elements { if !e.is_constant() { - match e.factor.cmp(&0) { - Ordering::Less => context.add_watch(SignedVar::plus(e.var), id), - Ordering::Equal => {} - Ordering::Greater => context.add_watch(SignedVar::minus(e.var), id), - } + context.add_watch(e.var, id); } if e.lit != Lit::TRUE { - context.add_watch(e.lit.svar(), id); + context.add_watch(e.lit.svar().variable(), id); } } } @@ -160,28 +149,25 @@ impl Propagator for LinearSumLeq { .iter() .copied() .filter(|e| !domains.entails(!e.lit)) - .map(|e| self.get_lower_bound(e, domains) as i64) + .map(|e| self.get_lower_bound(e, domains)) .sum(); let f = (self.ub as i64) - sum_lb; - // println!("Propagation : {} <= {}", sum_lb, self.ub); - // self.print(domains); + if f < 0 { - // println!("INCONSISTENT"); + // INCONSISTENT let mut expl = Explanation::new(); self.explain(Lit::FALSE, domains, &mut expl); return Err(Contradiction::Explanation(expl)); } for &e in &self.elements { - let lb = self.get_lower_bound(e, domains) as i64; - let ub = self.get_upper_bound(e, domains) as i64; + let lb = self.get_lower_bound(e, domains); + let ub = self.get_upper_bound(e, domains); debug_assert!(lb <= ub); if ub - lb > f { - // println!(" problem on: {e:?} {lb} {ub}"); - // NOTE: Conversion from i64 to i32 should not fail due to the clamp between two i32 values. - let new_ub = (f + lb).clamp(INT_CST_MIN as i64, INT_CST_MAX as i64) as i32; + let new_ub = f + lb; match self.set_ub(e, new_ub, domains, cause) { - Ok(true) => {} // println!(" propagated: {e:?} <= {}", f + lb), - Ok(false) => {} + Ok(true) => {} // domain updated + Ok(false) => {} // no-op Err(err) => { // If the update is invalid, a solution could be to force the element to not be present. if !domains.entails(e.lit) { @@ -199,9 +185,6 @@ impl Propagator for LinearSumLeq { } } } - // println!("AFTER PROP"); - // self.print(domains); - // println!(); Ok(()) } @@ -228,8 +211,6 @@ impl Propagator for LinearSumLeq { } } } - // dbg!(&self); - // dbg!(&out_explanation.lits); } fn clone_box(&self) -> Box { @@ -277,19 +258,19 @@ impl From for DynPropagator { #[derive(Clone, Default)] pub struct Watches { - propagations: HashMap>, + propagations: HashMap>, empty: [PropagatorId; 0], } impl Watches { - fn add_watch(&mut self, watched: SignedVar, propagator_id: PropagatorId) { + fn add_watch(&mut self, watched: VarRef, propagator_id: PropagatorId) { self.propagations .entry(watched) .or_insert_with(|| Vec::with_capacity(4)) .push(propagator_id) } - fn get(&self, var_bound: SignedVar) -> &[PropagatorId] { + fn get(&self, var_bound: VarRef) -> &[PropagatorId] { self.propagations .get(&var_bound) .map(|v| v.as_slice()) @@ -355,21 +336,30 @@ impl Theory for Cp { } fn propagate(&mut self, domains: &mut Domains) -> Result<(), Contradiction> { - // TODO: at this point, all propagators are invoked regardless of watches - // if self.saved == DecLvl::ROOT { - for (id, p) in self.constraints.entries() { - let cause = self.id.cause(id); - p.constraint.propagate(domains, cause)?; + // list of all propagators to trigger + let mut to_propagate = HashSet::with_capacity(64); + + // in first propagation, mark everything for propagation + // NOte: this is might actually be trigger multiple times when going back to the root + if self.saved == DecLvl::ROOT { + for (id, p) in self.constraints.entries() { + to_propagate.insert(id); + } + } + + // add any propagator that watched a changed variable since last propagation + while let Some(event) = self.model_events.pop(domains.trail()).copied() { + let watchers = self.watches.get(event.affected_bound.variable()); + for &watcher in watchers { + to_propagate.insert(watcher); + } + } + + for propagator in to_propagate { + let constraint = self.constraints[propagator].constraint.as_ref(); + let cause = self.id.cause(propagator); + constraint.propagate(domains, cause)?; } - // } - // while let Some(event) = self.model_events.pop(domains.trail()).copied() { - // let watchers = self.watches.get(event.affected_bound); - // for &watcher in watchers { - // let constraint = self.constraints[watcher].constraint.as_ref(); - // let cause = self.id.cause(watcher); - // constraint.propagate(&event, domains, cause)?; - // } - // } Ok(()) } @@ -403,40 +393,6 @@ impl Backtrack for Cp { } } -// impl BindSplit for Cp { -// fn enforce_true(&mut self, expr: &Expr, _doms: &mut Domains) -> BindingResult { -// if let Some(leq) = downcast::(expr) { -// let elements = leq -// .sum -// .iter() -// .map(|e| SumElem { -// factor: e.factor, -// var: e.var, -// or_zero: e.or_zero, -// }) -// .collect(); -// let propagator = LinearSumLeq { -// elements, -// ub: leq.upper_bound, -// }; -// self.add_propagator(propagator); -// BindingResult::Enforced -// } else { -// BindingResult::Unsupported -// } -// } -// -// fn enforce_false(&mut self, _expr: &Expr, _doms: &mut Domains) -> BindingResult { -// // TODO -// BindingResult::Unsupported -// } -// -// fn enforce_eq(&mut self, _literal: Lit, _expr: &Expr, _doms: &mut Domains) -> BindingResult { -// // TODO -// BindingResult::Unsupported -// } -// } - /* ========================================================================== */ /* Tests */ /* ========================================================================== */ @@ -469,8 +425,8 @@ mod tests { /* =============================== Helpers ============================== */ fn check_bounds(s: &LinearSumLeq, e: SumElem, d: &Domains, lb: IntCst, ub: IntCst) { - assert_eq!(s.get_lower_bound(e, d), lb); - assert_eq!(s.get_upper_bound(e, d), ub); + assert_eq!(s.get_lower_bound(e, d), lb.into()); + assert_eq!(s.get_upper_bound(e, d), ub.into()); } fn check_bounds_var(v: VarRef, d: &Domains, lb: IntCst, ub: IntCst) {