Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement BridgeStan download and module compilation on Rust #212

Merged
merged 32 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5c4808e
implement BridgeStan download and module compilation on Rust
randommm Feb 15, 2024
6132585
rust: add feature compile-stan-model
randommm Feb 17, 2024
3680a39
rust: allow user defined stanc_args and make_args when compiling model
randommm Feb 17, 2024
c9c3e8c
rust: add model_compiling test
randommm Feb 17, 2024
d9f1cd7
rust: updating readme
randommm Feb 17, 2024
e9970cd
rust: update documentation
randommm Feb 17, 2024
c06599c
rust: fix tests
randommm Feb 18, 2024
065644b
rust: skip model_compiling() test on windows
randommm Feb 18, 2024
30c814e
rust: fix race condition in tests
randommm Feb 18, 2024
c9d1d95
rust: make example path portable
randommm Feb 18, 2024
6fb0bf1
rust: fix windows absolute path resolution
randommm Feb 18, 2024
a660a81
Delete rust/.vscode/settings.json
randommm Feb 28, 2024
61c4860
rust: mark model_compiling test as ignored
randommm Feb 28, 2024
2035556
rust: use mingw32-make to compile model on windows
randommm Feb 28, 2024
0c48c02
rust: change println! to info!
randommm Feb 29, 2024
54f9ecd
Update README.md
randommm Feb 29, 2024
2a8db36
Update Cargo.toml
randommm Mar 5, 2024
86e9ee8
rust: single compile error message
randommm Mar 5, 2024
31ae1d1
rust: run tests without feature compile-stan-model
randommm Mar 5, 2024
3cab3e4
rust: adding comments about std::fs::canonicalize
randommm Mar 7, 2024
ec957f7
rust: fix --include-paths to point to model dir
randommm Mar 7, 2024
c9517c8
rust: disable enum variant feature gating
randommm Mar 7, 2024
8a64472
rust: fix macos build
randommm Mar 7, 2024
d6851b8
rust: make bridgestan src download more explicit
randommm Mar 9, 2024
837fc04
rust: only bridgestan_download_src is to be feature gated
randommm Mar 23, 2024
1948a89
unify .gitignore
randommm Mar 23, 2024
fb29536
test improvements
randommm Apr 6, 2024
08a9215
remove asref generic
randommm Apr 6, 2024
3dd1911
Merge remote-tracking branch 'origin/main'
randommm Apr 6, 2024
3de69b9
fix tests
randommm Apr 6, 2024
39f9108
Update model.rs
randommm Apr 13, 2024
f836e89
Clean up Rust doc, tests
WardBrian May 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ homepage = "https://roualdes.github.io/bridgestan/latest/"
[dependencies]
libloading = "0.8.0"
thiserror = "1.0.40"
ureq = "2.7"
tar = "0.4"
flate2 = "1.0"
dirs = "5.0"

[build-dependencies]
bindgen = "0.69.1"
Expand Down
7 changes: 5 additions & 2 deletions rust/examples/example.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use bridgestan::{open_library, BridgeStanError, Model};
use bridgestan::{compile_model, open_library, BridgeStanError, Model};
use std::ffi::CString;
use std::path::Path;

Expand All @@ -8,7 +8,10 @@ fn main() {
let path = Path::new(env!["CARGO_MANIFEST_DIR"])
.parent()
.unwrap()
.join("test_models/simple/simple_model.so");
.join("test_models/simple/simple.stan");

let path = compile_model(path, None).expect("Could not compile Stan model.");
println!("Compiled model: {:?}", path);

let lib = open_library(path).expect("Could not load compiled Stan model.");

Expand Down
9 changes: 8 additions & 1 deletion rust/src/bs_safe.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::download_compile::VERSION;
use crate::ffi;
use std::borrow::Borrow;
use std::collections::hash_map::DefaultHasher;
Expand Down Expand Up @@ -101,9 +102,15 @@ pub enum BridgeStanError {
/// Setting a print-callback failed.
#[error("Failed to set a print-callback: {0}")]
SetCallbackFailed(String),
/// Setting a compile Stan model failed.
#[error("Failed to compile Stan model: {0}")]
ModelCompilingFailed(String),
/// Setting a download BridgeStan failed.
#[error("Failed to download BridgeStan {VERSION} from github.com: {0}")]
DownloadFailed(String),
}

type Result<T> = std::result::Result<T, BridgeStanError>;
pub(crate) type Result<T> = std::result::Result<T, BridgeStanError>;

/// Open a compiled Stan library.
///
Expand Down
108 changes: 108 additions & 0 deletions rust/src/download_compile.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use crate::bs_safe::{BridgeStanError, Result};
use flate2::read::GzDecoder;
use std::{env::temp_dir, fs, path::PathBuf};
use tar::Archive;

pub(crate) const VERSION: &str = env!("CARGO_PKG_VERSION");

/// Download and unzip the BridgeStan source distribution for this version
/// to ~/.bridgestan/bridgestan-version
pub fn get_bridgestan_src() -> Result<PathBuf> {
let homedir = dirs::home_dir().unwrap_or(temp_dir());

let bs_path_download_temp = homedir.join(".bridgestan_tmp_dir");
let bs_path_download = homedir.join(".bridgestan");

let bs_path_download_temp_join_version =
bs_path_download_temp.join(format!("bridgestan-{VERSION}"));
let bs_path_download_join_version = bs_path_download.join(format!("bridgestan-{VERSION}"));

if !bs_path_download_join_version.exists() {
println!("Downloading BridgeStan");

fs::remove_dir_all(&bs_path_download_temp).unwrap_or_default();
fs::create_dir(&bs_path_download_temp).unwrap_or_default();
fs::create_dir(&bs_path_download).unwrap_or_default();

let url = "https://github.com/roualdes/bridgestan/releases/download/".to_owned()
+ format!("v{VERSION}/bridgestan-{VERSION}.tar.gz").as_str();

let response = ureq::get(url.as_str())
.call()
.map_err(|e| BridgeStanError::DownloadFailed(e.to_string()))?;
let len = response
.header("Content-Length")
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(50_000_000);

let mut bytes: Vec<u8> = Vec::with_capacity(len);
response
.into_reader()
.read_to_end(&mut bytes)
.map_err(|e| BridgeStanError::DownloadFailed(e.to_string()))?;

let tar = GzDecoder::new(bytes.as_slice());
let mut archive = Archive::new(tar);
archive
.unpack(&bs_path_download_temp)
.map_err(|e| BridgeStanError::DownloadFailed(e.to_string()))?;

fs::rename(
bs_path_download_temp_join_version,
&bs_path_download_join_version,
)
.map_err(|e| BridgeStanError::DownloadFailed(e.to_string()))?;

fs::remove_dir(bs_path_download_temp).unwrap_or_default();

println!("Finished downloading BridgeStan");
}

Ok(bs_path_download_join_version)
}

/// Compile a Stan Model given a stan_file and the path to BridgeStan
/// if None, then calls get_bridgestan_src() to download BridgeStan
pub fn compile_model(stan_file: PathBuf, bs_path: Option<PathBuf>) -> Result<PathBuf> {
let bs_path = match bs_path {
Some(path) => path,
None => get_bridgestan_src()?,
};

let stan_file = fs::canonicalize(stan_file)
.map_err(|e| BridgeStanError::ModelCompilingFailed(e.to_string()))?;

if stan_file.extension().unwrap_or_default() != "stan" {
return Err(BridgeStanError::ModelCompilingFailed(
"File must be a .stan file".to_owned(),
));
}

// add _model suffix and change extension to .so
let output = stan_file.with_extension("");
let output = output.with_file_name(format!(
"{}_model",
output.file_name().unwrap_or_default().to_string_lossy()
));
let output = output.with_extension("so");

let cmd = vec![output.to_str().unwrap_or_default().to_owned()];

println!("Compiling model");
let proc = std::process::Command::new("make")
.args(cmd)
.current_dir(bs_path)
.env("STAN_THREADS", "true")
.output()
.map_err(|e| BridgeStanError::ModelCompilingFailed(e.to_string()))?;
println!("Finished compiling model");

if !proc.status.success() {
return Err(BridgeStanError::ModelCompilingFailed(format!(
"{} {}",
String::from_utf8_lossy(proc.stdout.as_slice()).into_owned(),
String::from_utf8_lossy(proc.stderr.as_slice()).into_owned(),
)));
}
Ok(output)
}
2 changes: 2 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#![doc = include_str!("../README.md")]

mod bs_safe;
mod download_compile;
pub(crate) mod ffi;

pub use bs_safe::{open_library, BridgeStanError, Model, Rng, StanLibrary};
pub use download_compile::{compile_model, get_bridgestan_src};
12 changes: 6 additions & 6 deletions rust/tests/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ fn throw_data() {
fn bad_arglength() {
let (lib, data) = get_model("stdnormal");
let model = Model::new(&lib, data, 42).unwrap();
let theta = vec![];
let mut grad = vec![];
let theta = [];
let mut grad = [];
let _ = model.log_density_gradient(&theta[..], true, true, &mut grad[..]);
}

#[test]
fn logp_gradient() {
let (lib, data) = get_model("stdnormal");
let model = Model::new(&lib, data, 42).unwrap();
let theta = vec![1f64];
let mut grad = vec![0f64];
let theta = [1f64];
let mut grad = [0f64];
let logp = model
.log_density_gradient(&theta[..], false, true, &mut grad[..])
.unwrap();
Expand All @@ -47,8 +47,8 @@ fn logp_gradient() {
fn logp_hessian() {
let (lib, data) = get_model("stdnormal");
let model = Model::new(&lib, data, 42).unwrap();
let theta = vec![1f64];
let mut grad = vec![0f64];
let theta = [1f64];
let mut grad = [0f64];
let mut hessian = vec![0f64];
let logp = model
.log_density_hessian(&theta[..], false, true, &mut grad[..], &mut hessian)
Expand Down
Loading