Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cp): Improve efficiency of linear constraints propagation #114

Merged
merged 7 commits into from
Nov 24, 2023
148 changes: 52 additions & 96 deletions solver/src/reasoners/cp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::reasoners::{Contradiction, ReasonerId, Theory};
use anyhow::Context;
use num_integer::{div_ceil, div_floor};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

// =========== Sum ===========

Expand Down Expand Up @@ -68,51 +68,45 @@ impl std::fmt::Display for LinearSumLeq {
}

impl LinearSumLeq {
fn get_lower_bound(&self, elem: SumElem, domains: &Domains) -> IntCst {
fn get_lower_bound(&self, elem: SumElem, domains: &Domains) -> i64 {
let var_is_present = domains.present(elem.var) == Some(true);
debug_assert!(!domains.entails(elem.lit) || var_is_present);

let int_part = match elem.factor.cmp(&0) {
Ordering::Less => domains.ub(elem.var),
Ordering::Less => domains.ub(elem.var) as i64,
Ordering::Equal => 0,
Ordering::Greater => domains.lb(elem.var),
Ordering::Greater => domains.lb(elem.var) as i64,
}
.saturating_mul(elem.factor)
.clamp(INT_CST_MIN, INT_CST_MAX);
.saturating_mul(elem.factor as i64);

match domains.value(elem.lit) {
Some(true) => int_part,
Some(false) => 0,
None => 0.min(int_part),
}
}
fn get_upper_bound(&self, elem: SumElem, domains: &Domains) -> IntCst {
fn get_upper_bound(&self, elem: SumElem, domains: &Domains) -> i64 {
let var_is_present = domains.present(elem.var) == Some(true);
debug_assert!(!domains.entails(elem.lit) || var_is_present);

let int_part = match elem.factor.cmp(&0) {
Ordering::Less => domains.lb(elem.var),
Ordering::Less => domains.lb(elem.var) as i64,
Ordering::Equal => 0,
Ordering::Greater => domains.ub(elem.var),
Ordering::Greater => domains.ub(elem.var) as i64,
}
.saturating_mul(elem.factor)
.clamp(INT_CST_MIN, INT_CST_MAX);
.saturating_mul(elem.factor as i64);

match domains.value(elem.lit) {
Some(true) => int_part,
Some(false) => 0,
None => 0.max(int_part),
}
}
fn set_ub(&self, elem: SumElem, ub: IntCst, domains: &mut Domains, cause: Cause) -> Result<bool, InvalidUpdate> {
fn set_ub(&self, elem: SumElem, ub: i64, domains: &mut Domains, cause: Cause) -> Result<bool, InvalidUpdate> {
// convert back to i32 (which is used for domains). Clamping should prevent any conversion error.
let ub = ub.clamp(INT_CST_MIN as i64, INT_CST_MAX as i64) as i32;
debug_assert!(!domains.entails(elem.lit) || domains.present(elem.var) == Some(true));
// println!(
// " {:?} : [{}, {}] ub: {ub} -> {}",
// var,
// domains.lb(var),
// domains.ub(var),
// ub / elem.factor,
// );

match elem.factor.cmp(&0) {
Ordering::Less => domains.set_lb(elem.var, div_ceil(ub, elem.factor), cause),
Ordering::Equal => unreachable!(),
Expand All @@ -137,18 +131,13 @@ impl LinearSumLeq {

impl Propagator for LinearSumLeq {
fn setup(&self, id: PropagatorId, context: &mut Watches) {
// println!("SET UP");

context.add_watch(self.active.variable(), id);
for e in &self.elements {
if !e.is_constant() {
match e.factor.cmp(&0) {
Ordering::Less => context.add_watch(SignedVar::plus(e.var), id),
Ordering::Equal => {}
Ordering::Greater => context.add_watch(SignedVar::minus(e.var), id),
}
context.add_watch(e.var, id);
}
if e.lit != Lit::TRUE {
context.add_watch(e.lit.svar(), id);
context.add_watch(e.lit.svar().variable(), id);
}
}
}
Expand All @@ -160,28 +149,25 @@ impl Propagator for LinearSumLeq {
.iter()
.copied()
.filter(|e| !domains.entails(!e.lit))
.map(|e| self.get_lower_bound(e, domains) as i64)
.map(|e| self.get_lower_bound(e, domains))
.sum();
let f = (self.ub as i64) - sum_lb;
// println!("Propagation : {} <= {}", sum_lb, self.ub);
// self.print(domains);

if f < 0 {
// println!("INCONSISTENT");
// INCONSISTENT
let mut expl = Explanation::new();
self.explain(Lit::FALSE, domains, &mut expl);
return Err(Contradiction::Explanation(expl));
}
for &e in &self.elements {
let lb = self.get_lower_bound(e, domains) as i64;
let ub = self.get_upper_bound(e, domains) as i64;
let lb = self.get_lower_bound(e, domains);
let ub = self.get_upper_bound(e, domains);
debug_assert!(lb <= ub);
if ub - lb > f {
// println!(" problem on: {e:?} {lb} {ub}");
// NOTE: Conversion from i64 to i32 should not fail due to the clamp between two i32 values.
let new_ub = (f + lb).clamp(INT_CST_MIN as i64, INT_CST_MAX as i64) as i32;
let new_ub = f + lb;
match self.set_ub(e, new_ub, domains, cause) {
Ok(true) => {} // println!(" propagated: {e:?} <= {}", f + lb),
Ok(false) => {}
Ok(true) => {} // domain updated
Ok(false) => {} // no-op
Err(err) => {
// If the update is invalid, a solution could be to force the element to not be present.
if !domains.entails(e.lit) {
Expand All @@ -199,9 +185,6 @@ impl Propagator for LinearSumLeq {
}
}
}
// println!("AFTER PROP");
// self.print(domains);
// println!();
Ok(())
}

Expand All @@ -228,8 +211,6 @@ impl Propagator for LinearSumLeq {
}
}
}
// dbg!(&self);
// dbg!(&out_explanation.lits);
}

fn clone_box(&self) -> Box<dyn Propagator> {
Expand Down Expand Up @@ -277,19 +258,19 @@ impl<T: Propagator + 'static> From<T> for DynPropagator {

#[derive(Clone, Default)]
pub struct Watches {
propagations: HashMap<SignedVar, Vec<PropagatorId>>,
propagations: HashMap<VarRef, Vec<PropagatorId>>,
empty: [PropagatorId; 0],
}

impl Watches {
fn add_watch(&mut self, watched: SignedVar, propagator_id: PropagatorId) {
fn add_watch(&mut self, watched: VarRef, propagator_id: PropagatorId) {
self.propagations
.entry(watched)
.or_insert_with(|| Vec::with_capacity(4))
.push(propagator_id)
}

fn get(&self, var_bound: SignedVar) -> &[PropagatorId] {
fn get(&self, var_bound: VarRef) -> &[PropagatorId] {
self.propagations
.get(&var_bound)
.map(|v| v.as_slice())
Expand Down Expand Up @@ -355,21 +336,30 @@ impl Theory for Cp {
}

fn propagate(&mut self, domains: &mut Domains) -> Result<(), Contradiction> {
// TODO: at this point, all propagators are invoked regardless of watches
// if self.saved == DecLvl::ROOT {
for (id, p) in self.constraints.entries() {
let cause = self.id.cause(id);
p.constraint.propagate(domains, cause)?;
// list of all propagators to trigger
let mut to_propagate = HashSet::with_capacity(64);

// in first propagation, mark everything for propagation
// NOte: this is might actually be trigger multiple times when going back to the root
if self.saved == DecLvl::ROOT {
for (id, p) in self.constraints.entries() {
to_propagate.insert(id);
}
}

// add any propagator that watched a changed variable since last propagation
while let Some(event) = self.model_events.pop(domains.trail()).copied() {
let watchers = self.watches.get(event.affected_bound.variable());
for &watcher in watchers {
to_propagate.insert(watcher);
}
}

for propagator in to_propagate {
let constraint = self.constraints[propagator].constraint.as_ref();
let cause = self.id.cause(propagator);
constraint.propagate(domains, cause)?;
}
// }
// while let Some(event) = self.model_events.pop(domains.trail()).copied() {
// let watchers = self.watches.get(event.affected_bound);
// for &watcher in watchers {
// let constraint = self.constraints[watcher].constraint.as_ref();
// let cause = self.id.cause(watcher);
// constraint.propagate(&event, domains, cause)?;
// }
// }
Ok(())
}

Expand Down Expand Up @@ -403,40 +393,6 @@ impl Backtrack for Cp {
}
}

// impl BindSplit for Cp {
// fn enforce_true(&mut self, expr: &Expr, _doms: &mut Domains) -> BindingResult {
// if let Some(leq) = downcast::<NFLinearLeq>(expr) {
// let elements = leq
// .sum
// .iter()
// .map(|e| SumElem {
// factor: e.factor,
// var: e.var,
// or_zero: e.or_zero,
// })
// .collect();
// let propagator = LinearSumLeq {
// elements,
// ub: leq.upper_bound,
// };
// self.add_propagator(propagator);
// BindingResult::Enforced
// } else {
// BindingResult::Unsupported
// }
// }
//
// fn enforce_false(&mut self, _expr: &Expr, _doms: &mut Domains) -> BindingResult {
// // TODO
// BindingResult::Unsupported
// }
//
// fn enforce_eq(&mut self, _literal: Lit, _expr: &Expr, _doms: &mut Domains) -> BindingResult {
// // TODO
// BindingResult::Unsupported
// }
// }

/* ========================================================================== */
/* Tests */
/* ========================================================================== */
Expand Down Expand Up @@ -469,8 +425,8 @@ mod tests {
/* =============================== Helpers ============================== */

fn check_bounds(s: &LinearSumLeq, e: SumElem, d: &Domains, lb: IntCst, ub: IntCst) {
assert_eq!(s.get_lower_bound(e, d), lb);
assert_eq!(s.get_upper_bound(e, d), ub);
assert_eq!(s.get_lower_bound(e, d), lb.into());
assert_eq!(s.get_upper_bound(e, d), ub.into());
}

fn check_bounds_var(v: VarRef, d: &Domains, lb: IntCst, ub: IntCst) {
Expand Down
Loading