Skip to content

Commit

Permalink
[planning] Give access to intermediate solutions in parallel solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
arbimo committed Sep 29, 2022
1 parent e645fa6 commit cccb43c
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 43 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion examples/jobshop/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,11 @@ fn main() {
};
let mut solver = get_solver(solver, opt.search, est_brancher);

let result = solver.minimize(makespan).unwrap();
let result = solver
.minimize_with(makespan, |assignment| {
println!("New solution with makespan: {}", assignment.var_domain(makespan).lb)
})
.unwrap();

if let Some((optimum, solution)) = result {
println!("Found optimal solution with makespan: {}", optimum);
Expand Down
2 changes: 1 addition & 1 deletion model/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = { version = "1.0.35" }
streaming-iterator = "0.1.5"
aries_core = { path = "../core" }
aries_backtrack = { path = "../backtrack" }
Expand Down
1 change: 1 addition & 0 deletions solver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ aries_collections = { path = "../collections" }
aries_backtrack = { path = "../backtrack" }
aries_core = { path = "../core" }
aries_model = { path = "../model" }
crossbeam-channel = "0.5"
env_param = { path = "../env_param" }
itertools = { default-features = false, version = "0.10.0" }
num-traits = { default-features = false, version = "0.2.14" }
Expand Down
137 changes: 99 additions & 38 deletions solver/src/parallel_solver.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::signals::{InputSignal, InputStream, OutputSignal, SolverOutput, ThreadID};
use crate::solver::{Exit, Solver};
use aries_core::IntCst;
use aries_model::extensions::{SavedAssignment, Shaped};
use aries_model::extensions::{AssignmentExt, SavedAssignment, Shaped};
use aries_model::lang::IAtom;
use aries_model::{Label, ModelShape};
use std::sync::mpsc::{channel, Receiver, Sender};
use crossbeam_channel::{select, Receiver, Sender};
use std::sync::Arc;
use std::thread;

Expand Down Expand Up @@ -66,8 +66,10 @@ impl<Lbl: Label> ParSolver<Lbl> {
}

/// Sets the output of all solvers to a particular channel and return its receiving end.
///
/// Assumes that no worker is currently running.
fn plug_solvers_output(&mut self) -> Receiver<SolverOutput> {
let (snd, rcv) = std::sync::mpsc::channel();
let (snd, rcv) = crossbeam_channel::unbounded();
for x in &mut self.solvers {
if let Worker::Idle(solver) = x {
solver.set_solver_output(snd.clone());
Expand All @@ -80,27 +82,60 @@ impl<Lbl: Label> ParSolver<Lbl> {

/// Solve the problem that was given on initialization using all available solvers.
///
/// Returns the first final result produced by any worker solver.
/// Intermediate solutions found along the way are discarded (a no-op
/// callback is passed to `race_solvers`).
pub fn solve(&mut self) -> Result<Option<Arc<SavedAssignment>>, Exit> {
    self.race_solvers(|s| s.solve(), |_| {})
}

/// Minimize the value of the given expression.
///
/// Returns the optimal objective value and the corresponding assignment,
/// or `None` if the problem is unsatisfiable. Intermediate improving
/// solutions are discarded; use `minimize_with` to observe them as they
/// are found.
pub fn minimize(&mut self, objective: impl Into<IAtom>) -> Result<Option<(IntCst, Arc<SavedAssignment>)>, Exit> {
    let objective = objective.into();
    self.race_solvers(move |s| s.minimize(objective), |_| {})
}

/// Minimize the value of the given expression.
///
/// Each time a solution with a strictly better objective value is found,
/// the corresponding assignment is passed to `on_improved_solution`.
pub fn minimize_with(
    &mut self,
    objective: impl Into<IAtom>,
    on_improved_solution: impl Fn(Arc<SavedAssignment>),
) -> Result<Option<(IntCst, Arc<SavedAssignment>)>, Exit> {
    let objective = objective.into();

    // Objective value of the best solution reported so far (None until the first one).
    let mut best_cost: Option<IntCst> = None;

    // Wraps the user-provided callback: only forward a solution when its
    // objective value strictly improves on the best one seen so far.
    let on_new_sol = |solution: Arc<SavedAssignment>| {
        let cost = solution.var_domain(objective).lb;
        if best_cost.map_or(true, |prev| cost < prev) {
            on_improved_solution(solution);
            best_cost = Some(cost);
        }
    };
    self.race_solvers(move |s| s.minimize(objective), on_new_sol)
}

/// Generic function to run a lambda in parallel on all available solvers and return the result of the
/// first finishing one.
///
/// This function also setups inter-solver communication to enable clause/solution sharing.
/// Once a first result is found, it sends an interruption message to all other workers and wait for them to yield.
fn race_solvers<O, F>(&mut self, run: F) -> Result<O, Exit>
fn race_solvers<O, F, G>(&mut self, run: F, mut on_new_sol: G) -> Result<O, Exit>
where
O: Send + 'static,
F: Fn(&mut Solver<Lbl>) -> Result<O, Exit> + Send + 'static + Copy,
G: FnMut(Arc<SavedAssignment>),
{
// a receiver that will collect all intermediates results (incumbent solution and learned clauses)
// from the solvers
let solvers_output = self.plug_solvers_output();
let (result_snd, result_rcv) = channel();

// channel that is used to get the final results of the solvers.
let (result_snd, result_rcv) = crossbeam_channel::unbounded();

// lambda used to start a thread and run a solver on it.
let spawn = |id: usize, mut solver: Box<Solver<Lbl>>, result_snd: Sender<WorkerResult<O, Lbl>>| {
Expand All @@ -120,49 +155,68 @@ impl<Lbl: Label> ParSolver<Lbl> {
spawn(i, solver, result_snd.clone());
}

// start a new thread whose role is to send learnt clauses to other solvers
thread::spawn(move || {
while let Ok(x) = solvers_output.recv() {
// resend message to all other solvers. Note that a solver might have exited already
// and thus would not be able to receive the message
match x.msg {
OutputSignal::LearntClause(cl) => {
for input in &solvers_inputs {
if input.id != x.emitter {
let _ = input.sender.send(InputSignal::LearnedClause(cl.clone()));
let mut status = SolverStatus::Pending;

while self.is_worker_running() {
select! {
recv(result_rcv) -> res => {
let WorkerResult {
id: worker_id,
output: result,
solver,
} = res.unwrap();
self.solvers[worker_id] = Worker::Idle(solver);
if !matches!(status, SolverStatus::Final(_)) {
// this is the first result we got, store it and stop other solvers
status = SolverStatus::Final(result);
for s in &self.solvers {
if let Worker::Running(input) = s {
input.sender.send(InputSignal::Interrupt).unwrap();
}
}
}
OutputSignal::SolutionFound(assignment) => {
for input in &solvers_inputs {
if input.id != x.emitter {
let _ = input.sender.send(InputSignal::SolutionFound(assignment.clone()));
}
recv(solvers_output) -> msg => {
if let Ok(msg) = msg {
self.share_among_solvers(&msg);
if !matches!(status, SolverStatus::Final(_)) {
if let OutputSignal::SolutionFound(assignment) = msg.msg {
on_new_sol(assignment)
}
}
}
}
}
});

let WorkerResult {
id: first_id,
output: first_result,
solver,
} = result_rcv.recv().unwrap();
self.solvers[first_id] = Worker::Idle(solver);

for s in &self.solvers {
if let Worker::Running(input) = s {
input.sender.send(InputSignal::Interrupt).unwrap();
}
}

for _ in 0..(self.solvers.len() - 1) {
let result = result_rcv.recv().unwrap();
self.solvers[result.id] = Worker::Idle(result.solver);
match status {
SolverStatus::Final(res) => res,
_ => unreachable!(),
}
}

first_result
/// Returns true if there is at least one worker that is currently running.
fn is_worker_running(&self) -> bool {
    for worker in &self.solvers {
        if let Worker::Running(_) = worker {
            return true;
        }
    }
    false
}

/// Share an intermediate result with other running solvers that might be interested.
///
/// The message is forwarded to every running worker except its emitter.
/// Send errors are deliberately ignored: a worker may have exited already
/// and be unable to receive the message.
fn share_among_solvers(&self, signal: &SolverOutput) {
    for worker in &self.solvers {
        if let Worker::Running(input) = worker {
            // Do not echo the message back to the solver that produced it.
            if input.id == signal.emitter {
                continue;
            }
            match &signal.msg {
                OutputSignal::LearntClause(clause) => {
                    let _ = input.sender.send(InputSignal::LearnedClause(clause.clone()));
                }
                OutputSignal::SolutionFound(assignment) => {
                    let _ = input.sender.send(InputSignal::SolutionFound(assignment.clone()));
                }
            }
        }
    }
}

/// Prints the statistics of all solvers.
Expand All @@ -183,3 +237,10 @@ impl<Lbl: Label> Shaped<Lbl> for ParSolver<Lbl> {
&self.base_model
}
}

/// Progress of a `race_solvers` run, where `O` is the output type of the
/// closure being raced across the worker solvers.
enum SolverStatus<O> {
    /// Still waiting for a final result.
    Pending,
    /// A final result was provided by at least one solver.
    Final(Result<O, Exit>),
}
4 changes: 2 additions & 2 deletions solver/src/signals.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use aries_core::literals::Disjunction;
use aries_model::extensions::SavedAssignment;
use crossbeam_channel::{Receiver, Sender};
use env_param::EnvParam;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{Receiver, Sender};
use std::sync::Arc;

/// The maximum size of a clause that can be shared with other threads.
Expand Down Expand Up @@ -55,7 +55,7 @@ pub struct Synchro {

impl Synchro {
pub fn new() -> Self {
let (snd, rcv) = std::sync::mpsc::channel();
let (snd, rcv) = crossbeam_channel::unbounded();
Synchro {
id: get_next_thread_id(),
sender: snd,
Expand Down
2 changes: 1 addition & 1 deletion solver/src/solver.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crossbeam_channel::Sender;
use itertools::Itertools;
use std::fmt::Formatter;
use std::sync::mpsc::Sender;
use std::sync::Arc;
use std::time::Instant;

Expand Down

0 comments on commit cccb43c

Please sign in to comment.