Skip to content

Commit

Permalink
refactor(expl): change names and use iter().cloned() instead of `.c…
Browse files Browse the repository at this point in the history
…lone()`
  • Loading branch information
nrealus committed Jan 3, 2025
1 parent 18948aa commit 6af6bd8
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 63 deletions.
18 changes: 7 additions & 11 deletions explainability/src/explain/explanation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@ use std::sync::Arc;
use aries::core::Lit;
use aries::model::{Label, Model};

// "Essence" vs "Counterfactual" ? "Premise" ?
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ExplEssence(pub BTreeSet<Lit>, pub BTreeSet<Lit>);
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Essence(pub BTreeSet<Lit>, pub BTreeSet<Lit>);

// support (best alternative) ? justification ? argument ? cause ?
// "contradiction" vs "modelling" ? (but a counterexample could also be seen as one ?)
// just "example" vs "counterexample", maybe ?
#[derive(Debug, PartialEq, Eq, Hash)]
pub enum ExplSubstance {
Modelling(BTreeSet<Lit>),
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Substance {
ModelConstraints(BTreeSet<Lit>),
CounterExample(BTreeSet<Lit>),
}

Expand All @@ -37,8 +33,8 @@ impl ExplanationFilter {

pub struct Explanation<Lbl: Label> {
pub models: Vec<Arc<Model<Lbl>>>,
pub essences: Vec<ExplEssence>,
pub substances: Vec<ExplSubstance>,
pub essences: Vec<Essence>,
pub substances: Vec<Substance>,
pub table: BTreeMap<(EssenceIndex, SubstanceIndex), BTreeSet<ModelIndex>>,
pub filter: ExplanationFilter,
}
71 changes: 39 additions & 32 deletions explainability/src/explain/presupposition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,36 @@ pub fn check_presupposition<Lbl: Label>(
cached_solver: Option<&mut Solver<Lbl>>,
) -> Result<(), UnmetPresupposition<Lbl>> {
let solver = if let Some(s) = cached_solver {
// If we (the caller of the function) have supplied a cached solver to use, then use it.
s
} else {
&mut create_solver((*presupposition.model).clone())
// If no cached solver has been supplied, then create one and use it.
&mut {
let model = (*presupposition.model).clone();
let stn_config = StnConfig {
theory_propagation: TheoryPropagationLevel::Full,
..Default::default()
};
let mut solver = Solver::<Lbl>::new(model);
solver.reasoners.diff.config = stn_config;
solver
}
};
if skip_model_situ_sat_check {

if !skip_model_situ_sat_check {
// We need to make sure `model` /\ `situ` is indeed SAT.
match solver.solve_with_assumptions(presupposition.situ.iter().cloned()) {
Ok(_) => solver.restore(DecLvl::from(presupposition.situ.len())),
Err(_) => {
return Err(UnmetPresupposition {
presupposition,
cause: UnmetPresuppositionCause::ModelSituUnsat,
})
}
}
} else {
// If we (the caller of the function) want to skip checking `model` /\ `situ` is SAT
// (because we know that it's already the case), we only do the initial propagation and assumptions.
debug_assert!(solver.current_decision_level() == DecLvl::ROOT);
match solver.propagate_and_backtrack_to_consistent(solver.current_decision_level()) {
Ok(_) => (), // expected,
Expand All @@ -69,27 +94,18 @@ pub fn check_presupposition<Lbl: Label>(
Ok(_) => (), // expected
Err(_) => debug_assert!(false),
}
}
} else {
match solver.solve_with_assumptions(presupposition.situ.clone()) {
Ok(_) => solver.restore(DecLvl::from(presupposition.situ.len())),
Err(_) => {
return Err(UnmetPresupposition {
presupposition,
cause: UnmetPresuppositionCause::ModelSituUnsat,
})
}
}
}
}

// Remember, `situ` is already assumed (we backtracked to the latest assumption).
// !!! Remember, at this point `situ` is already assumed
// !!! so, we will just use `query` (or `query_neg`)
// !!! in `solve_with_assumptions` calls below (incremental solving).
debug_assert!(solver.current_decision_level() == DecLvl::from(presupposition.situ.len()));
// And so, we will just use `query` (or `query_neg`) in `solve_with_assumptions` calls below (incremental solving).

let res = match presupposition.kind {
PresuppositionKind::ModelSituUnsatWithQuery => {
match solver
.solve_with_assumptions(presupposition.query.clone())
.solve_with_assumptions(presupposition.query.iter().cloned())
.expect("Solver interrupted.")
{
Ok(_) => Err(UnmetPresupposition {
Expand All @@ -101,7 +117,7 @@ pub fn check_presupposition<Lbl: Label>(
}
PresuppositionKind::ModelSituSatWithQuery => {
match solver
.solve_with_assumptions(presupposition.query.clone())
.solve_with_assumptions(presupposition.query.iter().cloned())
.expect("Solver interrupted.")
{
Ok(_) => Ok(()),
Expand All @@ -114,12 +130,12 @@ pub fn check_presupposition<Lbl: Label>(
PresuppositionKind::ModelSituNotEntailQuery => {
let dl = DecLvl::from(presupposition.query.len());
match solver
.solve_with_assumptions(presupposition.query.clone())
.solve_with_assumptions(presupposition.query.iter().cloned())
.expect("Solver interrupted.")
{
Ok(_) => {
solver.restore(dl);
let query_neg = presupposition.query.iter().map(|&l| !l).collect_vec();
let query_neg = presupposition.query.iter().map(|&l| !l);
match solver.solve_with_assumptions(query_neg).expect("Solver interrupted.") {
Ok(_) => Err(UnmetPresupposition {
presupposition,
Expand All @@ -135,8 +151,8 @@ pub fn check_presupposition<Lbl: Label>(
}
}
PresuppositionKind::ModelSituEntailQuery => {
let neg_query = presupposition.query.iter().map(|&l| !l).collect_vec();
match solver.solve_with_assumptions(neg_query).expect("Solver interrupted.") {
let query_neg = presupposition.query.iter().map(|&l| !l);
match solver.solve_with_assumptions(query_neg).expect("Solver interrupted.") {
Ok(_) => Ok(()),
Err(_) => Err(UnmetPresupposition {
presupposition,
Expand All @@ -145,17 +161,8 @@ pub fn check_presupposition<Lbl: Label>(
}
}
};
// necessary if the solver was a cached one (given as parameter), to ensure it can be safely reused somewhere else.
// necessary if the solver was a cached one (given as parameter),
// to ensure it can be safely reused somewhere else.
solver.reset();
res
}

fn create_solver<Lbl: Label>(model: Model<Lbl>) -> Solver<Lbl> {
let stn_config = StnConfig {
theory_propagation: TheoryPropagationLevel::Full,
..Default::default()
};
let mut solver = Solver::<Lbl>::new(model);
solver.reasoners.diff.config = stn_config;
solver
}
40 changes: 20 additions & 20 deletions explainability/src/explain/why/unsat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use aries::core::Lit;
use aries::model::{Label, Model};

use crate::explain::explanation::{
EssenceIndex, ExplEssence, ExplSubstance, Explanation, ExplanationFilter, ModelIndex, SubstanceIndex,
EssenceIndex, Essence, Substance, Explanation, ExplanationFilter, ModelIndex, SubstanceIndex,
};
use crate::explain::presupposition::{check_presupposition, Presupposition, PresuppositionKind, UnmetPresupposition};
use crate::explain::{Query, Question, Situation, Vocab};
Expand Down Expand Up @@ -51,8 +51,8 @@ impl<Lbl: Label> Question<Lbl> for QwhyUnsat<Lbl> {
);
let muses = simple_marco.run().muses.unwrap();

let mut essences = Vec::<ExplEssence>::new();
let mut substances = Vec::<ExplSubstance>::new();
let mut essences = Vec::<Essence>::new();
let mut substances = Vec::<Substance>::new();
let mut table = BTreeMap::<(EssenceIndex, SubstanceIndex), BTreeSet<ModelIndex>>::new();
let filter = ExplanationFilter {
map: None,
Expand All @@ -62,7 +62,7 @@ impl<Lbl: Label> Question<Lbl> for QwhyUnsat<Lbl> {
let _situ_set = BTreeSet::from_iter(self.situ.iter().cloned());

for (mus_idx, mus) in muses.into_iter().enumerate() {
essences.push(ExplEssence(
essences.push(Essence(
mus.difference(&_situ_set).cloned().collect::<BTreeSet<Lit>>(),
mus.intersection(&_situ_set).cloned().collect::<BTreeSet<Lit>>(),
));
Expand All @@ -78,7 +78,7 @@ impl<Lbl: Label> Question<Lbl> for QwhyUnsat<Lbl> {
);
let mcses = simple_marco.run().mcses.unwrap();
for mcs in mcses {
let sub = ExplSubstance::Modelling(mcs);
let sub = Substance::ModelConstraints(mcs);
let sub_idx = substances.iter().position(|s| s == &sub);
match sub_idx {
Some(i) => table.insert((mus_idx, i), BTreeSet::from_iter([0])),
Expand Down Expand Up @@ -110,7 +110,7 @@ mod tests {
use aries::model::lang::expr::{and, implies};
use aries::model::lang::linear::LinearSum;

use crate::explain::explanation::{ExplEssence, ExplSubstance};
use crate::explain::explanation::{Essence, Substance};

use super::Question;

Expand Down Expand Up @@ -167,32 +167,32 @@ mod tests {

let expl = question.try_answer().unwrap();

let essences: HashSet<ExplEssence> = expl.essences.into_iter().collect::<HashSet<_>>();
let essences: HashSet<Essence> = expl.essences.iter().cloned().collect::<HashSet<_>>();
debug_assert_eq!(
essences,
HashSet::from_iter([
ExplEssence(BTreeSet::from_iter([p_a, p_b]), BTreeSet::from_iter([p_d])),
ExplEssence(BTreeSet::from_iter([p_b, p_c]), BTreeSet::from_iter([p_d])),
Essence(BTreeSet::from_iter([p_a, p_b]), BTreeSet::from_iter([p_d])),
Essence(BTreeSet::from_iter([p_b, p_c]), BTreeSet::from_iter([p_d])),
]),
);

let substances = expl.substances.into_iter().collect::<HashSet<_>>();
let substances = expl.substances.iter().cloned().collect::<HashSet<_>>();
debug_assert_eq!(
substances,
HashSet::from_iter([
ExplSubstance::Modelling(BTreeSet::from_iter([voc[0]])),
ExplSubstance::Modelling(BTreeSet::from_iter([voc[1]])),
ExplSubstance::Modelling(BTreeSet::from_iter([voc[2]])),
ExplSubstance::Modelling(BTreeSet::from_iter([voc[4]])),
Substance::ModelConstraints(BTreeSet::from_iter([voc[0]])),
Substance::ModelConstraints(BTreeSet::from_iter([voc[1]])),
Substance::ModelConstraints(BTreeSet::from_iter([voc[2]])),
Substance::ModelConstraints(BTreeSet::from_iter([voc[4]])),
]),
);

let idxe0 = essences.iter().position(|e| *e == ExplEssence(BTreeSet::from_iter([p_a, p_b]), BTreeSet::from_iter([p_d]))).unwrap();
let idxe1 = essences.iter().position(|e| *e == ExplEssence(BTreeSet::from_iter([p_b, p_c]), BTreeSet::from_iter([p_d]))).unwrap();
let idxs0 = substances.iter().position(|s| *s == ExplSubstance::Modelling(BTreeSet::from_iter([voc[0]]))).unwrap();
let idxs1 = substances.iter().position(|s| *s == ExplSubstance::Modelling(BTreeSet::from_iter([voc[1]]))).unwrap();
let idxs2 = substances.iter().position(|s| *s == ExplSubstance::Modelling(BTreeSet::from_iter([voc[2]]))).unwrap();
let idxs3 = substances.iter().position(|s| *s == ExplSubstance::Modelling(BTreeSet::from_iter([voc[4]]))).unwrap();
let idxe0 = expl.essences.iter().position(|e| *e == Essence(BTreeSet::from_iter([p_a, p_b]), BTreeSet::from_iter([p_d]))).unwrap();
let idxe1 = expl.essences.iter().position(|e| *e == Essence(BTreeSet::from_iter([p_b, p_c]), BTreeSet::from_iter([p_d]))).unwrap();
let idxs0 = expl.substances.iter().position(|s| *s == Substance::ModelConstraints(BTreeSet::from_iter([voc[0]]))).unwrap();
let idxs1 = expl.substances.iter().position(|s| *s == Substance::ModelConstraints(BTreeSet::from_iter([voc[1]]))).unwrap();
let idxs2 = expl.substances.iter().position(|s| *s == Substance::ModelConstraints(BTreeSet::from_iter([voc[2]]))).unwrap();
let idxs3 = expl.substances.iter().position(|s| *s == Substance::ModelConstraints(BTreeSet::from_iter([voc[4]]))).unwrap();

let table = expl.table;
debug_assert_eq!(
Expand Down

0 comments on commit 6af6bd8

Please sign in to comment.