Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add solution handler #31

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

cyconer
Copy link

@cyconer cyconer commented Aug 10, 2023

Partial of #9

This adds a solution handler that is called on newly encountered solutions (or improvements in optimization).
The search can not yet be stopped (more details below).

Parallelism

The reference implementation https://github.com/google/or-tools/blob/stable/ortools/sat/docs/solver.md?plain=1#L472 states

Please note that it does not work in parallel
(i. e. parameter num_search_workers > 1).

As mentioned in the ticket #9, it might be desirable to use Fn instead of FnMut. As I interpret the above that the search does not work in parallel anyway, it does not seem to matter whether we use Fn or FnMut. The implementation right now uses FnMut, but we can easily change that.

ControlFlow

Now it gets weird. I have played around with this based on the reference implementation, but something very strange seems to happen with the C++ part.

Diff of potential approach for Control Flow
diff --git a/src/builder.rs b/src/builder.rs
index 2fb392b..17d2a18 100644
--- a/src/builder.rs
+++ b/src/builder.rs
@@ -786,6 +786,7 @@ impl CpModelBuilder {
     /// # use std::rc::Rc;
     /// # use cp_sat::builder::CpModelBuilder;
     /// # use cp_sat::proto::{SatParameters, CpSolverResponse};
+    /// # use std::ops::ControlFlow;
     /// let mut model = CpModelBuilder::default();
     /// // linear constraint will only allow a = 2, a = 3 and a = 4
     /// let a = model.new_int_var([(2, 7)]);
@@ -793,10 +794,11 @@ impl CpModelBuilder {
     /// let mut params = SatParameters::default();
     /// params.enumerate_all_solutions = Some(true);
     ///
-    /// let memory = Rc::new(RefCell::new(Vec::new()));
+    /// let memory: Rc<RefCell<Vec<CpSolverResponse>>> = Rc::new(RefCell::new(Vec::new()));
     /// let memory2 = memory.clone();
     /// let handler = move |response: CpSolverResponse| {
     ///     memory2.borrow_mut().push(response);
+    ///     ControlFlow::Continue(())
     /// };
     ///
     /// let _response = model.solve_with_parameters_and_handler(&params, handler);
@@ -805,7 +807,7 @@ impl CpModelBuilder {
     pub fn solve_with_parameters_and_handler(
         &self,
         params: &proto::SatParameters,
-        handler: impl FnMut(proto::CpSolverResponse) + 'static,
+        handler: impl FnMut(proto::CpSolverResponse) -> std::ops::ControlFlow<()> + 'static,
     ) -> proto::CpSolverResponse {
         ffi::solve_with_parameters_and_handler(self.proto(), params, Box::new(handler))
     }
diff --git a/src/cp_sat_wrapper.cpp b/src/cp_sat_wrapper.cpp
index c2a1126..eb92d38 100644
--- a/src/cp_sat_wrapper.cpp
+++ b/src/cp_sat_wrapper.cpp
@@ -2,6 +2,7 @@
 
 #include <ortools/sat/cp_model.h>
 #include <ortools/sat/cp_model_checker.h>
+#include <ortools/util/time_limit.h>
 
 namespace sat = operations_research::sat;
 
@@ -58,8 +59,10 @@ cp_sat_wrapper_solve(
  * - serialized buffer of a CpSolverResponse
  * - length of the buffer
  * - additional data passed from the outside
+ *
+ * Returns true if the search should be aborted.
  */
-typedef void (*solution_handler)(unsigned char*, size_t, void*);
+typedef bool (*solution_handler)(unsigned char*, size_t, void*);
 
 /**
  * Similar to cp_sat_wrapper_solve_with_parameters, but with a callback function
@@ -89,6 +92,10 @@ cp_sat_wrapper_solve_with_parameters_and_handler(
 
     extra_model.Add(sat::NewSatParameters(params));
 
+    // Atomic Boolean that will be periodically checked by the limit.
+    std::atomic<bool> stopped(false);
+    extra_model.GetOrCreate<operations_research::TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped);
+
     // local function that serializes the CpSolverResponse for the provided solution handler
     auto wrapped_handler = [&](const operations_research::sat::CpSolverResponse& curr_response) {
         // serialize CpSolverResponse
@@ -97,7 +104,10 @@ cp_sat_wrapper_solve_with_parameters_and_handler(
         bool curr_res = curr_response.SerializeToArray(response_buf, response_size);
         assert(curr_res);
 
-        handler(response_buf, response_size, handler_data);
+        bool abort = handler(response_buf, response_size, handler_data);
+        if (abort) {
+            stopped = true;
+        }
     };
     extra_model.Add(sat::NewFeasibleSolutionObserver(wrapped_handler));
 
diff --git a/src/ffi.rs b/src/ffi.rs
index 7284c6c..27c07c8 100644
--- a/src/ffi.rs
+++ b/src/ffi.rs
@@ -3,6 +3,7 @@ use libc::c_char;
 use prost::Message;
 use std::ffi::CStr;
 use std::ffi::c_void;
+use std::ops::ControlFlow;
 
 extern "C" {
     fn cp_sat_wrapper_solve(
@@ -22,7 +23,7 @@ extern "C" {
         model_size: usize,
         params_buf: *const u8,
         params_size: usize,
-        handler_caller: extern "C" fn(*const u8, usize, *mut c_void),
+        handler_caller: extern "C" fn(*const u8, usize, *mut c_void) -> bool,
         handler: *mut c_void,
         out_size: &mut usize,
     ) -> *mut u8;
@@ -83,7 +84,8 @@ pub fn solve_with_parameters(
 }
 
 /// User provided solution handler that is called with feasible solutions.
-pub type SolutionHandler = Box<dyn FnMut(proto::CpSolverResponse)>;
+/// The control flow can be used to abort the search.
+pub type SolutionHandler = Box<dyn FnMut(proto::CpSolverResponse) -> ControlFlow<()>>;
 
 /// Solves the given [CpModelProto][crate::proto::CpModelProto] with
 /// the given parameters,
@@ -129,16 +131,23 @@ pub fn solve_with_parameters_and_handler(
 /// - `response_buf` and `response_size`: buffer and size of a [proto::CpSolverResponse]
 /// - `handler`: a user provided solution handler [SolutionHandler] that accepts a
 ///     [proto::CpSolverResponse]
-extern "C" fn solution_handler_caller(response_buf: *const u8, response_size: usize, handler: *mut c_void) {
+///
+/// Returns `true` if the search should be aborted.
+extern "C" fn solution_handler_caller(response_buf: *const u8, response_size: usize, handler: *mut c_void) -> bool {
     let response_slice = unsafe {
         std::slice::from_raw_parts(response_buf, response_size)
     };
     let response = proto::CpSolverResponse::decode(response_slice).unwrap();
     unsafe { libc::free(response_buf as _) };
 
-    unsafe {
+    let control_flow = unsafe {
         let tmp = handler as *mut SolutionHandler;
-        (*tmp)(response);
+        (*tmp)(response)
+    };
+
+    match control_flow {
+        ControlFlow::Continue(_) => false,
+        ControlFlow::Break(_) => true,
     }
 }
 
diff --git a/tests/solution_handler.rs b/tests/solution_handler.rs
index 959541f..3caf620 100644
--- a/tests/solution_handler.rs
+++ b/tests/solution_handler.rs
@@ -14,10 +14,11 @@ fn enumeration_solution_handler() {
   let mut params = SatParameters::default();
   params.enumerate_all_solutions = Some(true);
 
-  let memory = Rc::new(RefCell::new(Vec::new()));
+  let memory: Rc<RefCell<Vec<CpSolverResponse>>> = Rc::new(RefCell::new(Vec::new()));
   let memory2 = memory.clone();
   let handler = move |response: CpSolverResponse| {
     memory2.borrow_mut().push(response);
+    std::ops::ControlFlow::Continue(())
   };
 
   let _response = model.solve_with_parameters_and_handler(&params, handler);
@@ -45,10 +46,11 @@ fn optimization_solution_handler() {
   let mut params = SatParameters::default();
   params.enumerate_all_solutions = Some(true);
 
-  let memory = Rc::new(RefCell::new(Vec::new()));
+  let memory: Rc<RefCell<Vec<CpSolverResponse>>> = Rc::new(RefCell::new(Vec::new()));
   let memory2 = memory.clone();
   let handler = move |response: CpSolverResponse| {
     memory2.borrow_mut().push(response);
+    std::ops::ControlFlow::Continue(())
   };
 
   let response = model.solve_with_parameters_and_handler(&params, handler);
@@ -61,3 +63,31 @@ fn optimization_solution_handler() {
   // improvement.
   assert!(memory.borrow().len() >= 1);
 }
+
+/// It should be possible to stop the search from the callback.
+#[test]
+fn stop_solution_handler() {
+  let mut model = CpModelBuilder::default();
+  // linear constraint will only allow a = 2, a = 3 and a = 4
+  let a = model.new_int_var([(2, 7)]);
+  model.add_linear_constraint([(3, a)], [(0, 13)]);
+  let mut params = SatParameters::default();
+  params.enumerate_all_solutions = Some(true);
+
+  let memory: Rc<RefCell<Vec<CpSolverResponse>>> = Rc::new(RefCell::new(Vec::new()));
+  let memory2 = memory.clone();
+  let handler = move |response: CpSolverResponse| {
+    memory2.borrow_mut().push(response);
+
+    if memory2.borrow().len() < 2 {
+      std::ops::ControlFlow::Continue(())
+    } else {
+      std::ops::ControlFlow::Break(())
+    }
+  };
+
+  let _response = model.solve_with_parameters_and_handler(&params, handler);
+
+  // Instead of the 3 feasible solution the search was aborted after 2.
+  assert_eq!(2, memory.borrow().len());
+}

Problem: completely unrelated code starts to exibit SIGSEGV when extra_model.GetOrCreate<operations_research::TimeLimit>() is included.

Minimal failing example that completely baffles me
diff --git a/src/cp_sat_wrapper.cpp b/src/cp_sat_wrapper.cpp
index c2a1126..94c6661 100644
--- a/src/cp_sat_wrapper.cpp
+++ b/src/cp_sat_wrapper.cpp
@@ -2,6 +2,7 @@
 
 #include <ortools/sat/cp_model.h>
 #include <ortools/sat/cp_model_checker.h>
+#include <ortools/util/time_limit.h>
 
 namespace sat = operations_research::sat;
 
@@ -89,6 +90,12 @@ cp_sat_wrapper_solve_with_parameters_and_handler(
 
     extra_model.Add(sat::NewSatParameters(params));
 
+    bool this_is_never_reached = false;
+    if (this_is_never_reached) {
+        // Including this line leads to SIGSEGV of e.g. cp_sat_wrapper_solve???
+        extra_model.GetOrCreate<operations_research::TimeLimit>();
+    }
+
     // local function that serializes the CpSolverResponse for the provided solution handler
     auto wrapped_handler = [&](const operations_research::sat::CpSolverResponse& curr_response) {
         // serialize CpSolverResponse

E.g. the tests/bool_cst.rs starts to exhibit a SIGSEGV which does not use the new cp_sat_wrapper_solve_with_parameters_and_handler but the unchanged cp_sat_wrapper_solve. It seems to fail in the absl library in absl/container/internal/raw_hash_set.h in line 1562. It does not fail when using e.g. GetOrCreate<bool>(), but fails using TimeLimit as the generic.

I have no idea what is going on, how unrelated code can fail just because a (never used) function call is included in another function. Maybe some C++-magic of generics that affects the Model globally... Any ideas are appreciated.

@cyconer cyconer mentioned this pull request Aug 10, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant