Skip to content

Commit

Permalink
Save/Load Gp/Sgp surrogates in binary format (#213)
Browse files Browse the repository at this point in the history
* Implement save as binary file

* Enable feature rand_xoshiro/serde1 required for serializable

* Fix json save

* Refactor

* Add py test save binary

* Cleanup

* Ignore binary files

* Update notebook with binary save
  • Loading branch information
relf authored Nov 12, 2024
1 parent dce67e2 commit bda92f9
Show file tree
Hide file tree
Showing 13 changed files with 1,209 additions and 1,118 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# Project
**/*.json
**/*.bin
*.npy
input.txt
output.txt
Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2,117 changes: 1,066 additions & 1,051 deletions doc/Gpx_Tutorial.ipynb

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions ego/src/gpmix/mixint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use egobox_doe::{FullFactorial, Lhs, LhsKind, Random};
use egobox_gp::metrics::CrossValScore;
use egobox_gp::ThetaTuning;
use egobox_moe::{
Clustered, Clustering, CorrelationSpec, FullGpSurrogate, GpMixture, GpMixtureParams,
GpSurrogate, GpSurrogateExt, MixtureGpSurrogate, RegressionSpec,
Clustered, Clustering, CorrelationSpec, FullGpSurrogate, GpFileFormat, GpMixture,
GpMixtureParams, GpSurrogate, GpSurrogateExt, MixtureGpSurrogate, RegressionSpec,
};
use linfa::traits::{Fit, PredictInplace};
use linfa::{DatasetBase, Float, ParamGuard};
Expand Down Expand Up @@ -581,13 +581,15 @@ impl GpSurrogate for MixintGpMixture {

/// Save Moe model in given file.
#[cfg(feature = "persistent")]
fn save(&self, path: &str) -> egobox_moe::Result<()> {
fn save(&self, path: &str, format: GpFileFormat) -> egobox_moe::Result<()> {
use egobox_moe::GpFileFormat;

let mut file = fs::File::create(path).unwrap();
let bytes = match serde_json::to_string(self) {
Ok(b) => b,
Err(err) => return Err(MoeError::SaveError(err)),
let bytes = match format {
GpFileFormat::Json => serde_json::to_vec(self).map_err(MoeError::SaveJsonError)?,
GpFileFormat::Binary => bincode::serialize(self).map_err(MoeError::SaveBinaryError)?,
};
file.write_all(bytes.as_bytes())?;
file.write_all(&bytes)?;
Ok(())
}
}
Expand Down
2 changes: 1 addition & 1 deletion gp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ categories = ["algorithms", "mathematics", "science"]

default = []

serializable = ["serde", "typetag", "linfa/serde"]
serializable = ["serde", "typetag", "linfa/serde", "rand_xoshiro/serde1"]
persistent = ["serializable", "serde_json"]
blas = ["ndarray-linalg", "linfa/ndarray-linalg", "linfa-pls/blas"]

Expand Down
4 changes: 3 additions & 1 deletion moe/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ categories = ["algorithms", "mathematics", "science"]
[features]
default = []

persistent = ["serializable", "serde_json"]
persistent = ["serializable", "serde_json", "bincode"]
serializable = [
"serde",
"typetag",
Expand Down Expand Up @@ -51,6 +51,8 @@ thiserror = "1"

serde = { version = "1", features = ["derive"], optional = true }
serde_json = { version = "1", optional = true }
bincode = { version = "1.3.3", optional = true }

typetag = { version = "0.2", optional = true }

[dev-dependencies]
Expand Down
29 changes: 19 additions & 10 deletions moe/src/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,13 +483,15 @@ impl GpSurrogate for GpMixture {
}
/// Save Moe model in given file.
#[cfg(feature = "persistent")]
fn save(&self, path: &str) -> Result<()> {
fn save(&self, path: &str, format: GpFileFormat) -> Result<()> {
let mut file = fs::File::create(path).unwrap();
let bytes = match serde_json::to_string(self) {
Ok(b) => b,
Err(err) => return Err(MoeError::SaveError(err)),

let bytes = match format {
GpFileFormat::Json => serde_json::to_vec(self).map_err(MoeError::SaveJsonError)?,
GpFileFormat::Binary => bincode::serialize(self).map_err(MoeError::SaveBinaryError)?,
};
file.write_all(bytes.as_bytes())?;
file.write_all(&bytes)?;

Ok(())
}
}
Expand Down Expand Up @@ -878,9 +880,12 @@ impl GpMixture {

#[cfg(feature = "persistent")]
/// Load Moe from the given file, decoded according to the specified format (JSON or binary).
pub fn load(path: &str) -> Result<Box<GpMixture>> {
let data = fs::read_to_string(path)?;
let moe: GpMixture = serde_json::from_str(&data).unwrap();
pub fn load(path: &str, format: GpFileFormat) -> Result<Box<GpMixture>> {
let data = fs::read(path)?;
let moe = match format {
GpFileFormat::Json => serde_json::from_slice(&data).unwrap(),
GpFileFormat::Binary => bincode::deserialize(&data).unwrap(),
};
Ok(Box::new(moe))
}
}
Expand Down Expand Up @@ -1150,8 +1155,8 @@ mod tests {
let xtest = array![[0.6]];
let y_expected = moe.predict(&xtest).unwrap();
let filename = format!("{test_dir}/saved_moe.json");
moe.save(&filename).expect("MoE saving");
let new_moe = GpMixture::load(&filename).expect("MoE loading");
moe.save(&filename, GpFileFormat::Json).expect("MoE saving");
let new_moe = GpMixture::load(&filename, GpFileFormat::Json).expect("MoE loading");
assert_abs_diff_eq!(y_expected, new_moe.predict(&xtest).unwrap(), epsilon = 1e-6);
}

Expand Down Expand Up @@ -1358,6 +1363,10 @@ mod tests {
.fit(&Dataset::new(xt, yt))
.expect("GP fit error");

// To see file size : 100D => json ~ 1.2Mo, bin ~ 0.6Mo
// gp.save("griewank.json", GpFileFormat::Json).unwrap();
// gp.save("griewank.bin", GpFileFormat::Binary).unwrap();

let rng = Xoshiro256Plus::seed_from_u64(0);
let xtest = Lhs::new(&xlimits).with_rng(rng).sample(100);
let ytest = gp.predict(&xtest).expect("prediction error");
Expand Down
6 changes: 5 additions & 1 deletion moe/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ pub enum MoeError {
/// When error during saving in JSON format
#[cfg(feature = "persistent")]
#[error("Save error: {0}")]
SaveError(#[from] serde_json::Error),
SaveJsonError(#[from] serde_json::Error),
/// When error during saving in binary format
#[cfg(feature = "persistent")]
#[error("Save error: {0}")]
SaveBinaryError(#[from] bincode::Error),
/// When error during loading
#[error("Load IO error")]
LoadIoError(#[from] std::io::Error),
Expand Down
52 changes: 34 additions & 18 deletions moe/src/surrogates.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::errors::Result;
use crate::types::GpFileFormat;
use egobox_gp::{
correlation_models::*, mean_models::*, GaussianProcess, GpParams, SgpParams,
SparseGaussianProcess, SparseMethod, ThetaTuning,
Expand Down Expand Up @@ -54,7 +55,7 @@ pub trait GpSurrogate: std::fmt::Display + Sync + Send {
fn predict_var(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>>;
/// Save model in given file.
#[cfg(feature = "persistent")]
fn save(&self, path: &str) -> Result<()>;
fn save(&self, path: &str, format: GpFileFormat) -> Result<()>;
}

/// A trait for a GP surrogate with derivatives predictions and sampling
Expand Down Expand Up @@ -157,13 +158,17 @@ macro_rules! declare_surrogate {
}

#[cfg(feature = "persistent")]
fn save(&self, path: &str) -> Result<()> {
fn save(&self, path: &str, format: GpFileFormat) -> Result<()> {
let mut file = fs::File::create(path).unwrap();
let bytes = match serde_json::to_string(self as &dyn GpSurrogate) {
Ok(b) => b,
Err(err) => return Err(MoeError::SaveError(err))
let bytes = match format {
GpFileFormat::Json => serde_json::to_vec(self as &dyn GpSurrogate)
.map_err(MoeError::SaveJsonError)?,
GpFileFormat::Binary => {
bincode::serialize(self as &dyn GpSurrogate).map_err(MoeError::SaveBinaryError)?
}
};
file.write_all(bytes.as_bytes())?;
file.write_all(&bytes)?;

Ok(())
}

Expand Down Expand Up @@ -311,13 +316,16 @@ macro_rules! declare_sgp_surrogate {
}

#[cfg(feature = "persistent")]
fn save(&self, path: &str) -> Result<()> {
fn save(&self, path: &str, format: GpFileFormat) -> Result<()> {
let mut file = fs::File::create(path).unwrap();
let bytes = match serde_json::to_string(self as &dyn SgpSurrogate) {
Ok(b) => b,
Err(err) => return Err(MoeError::SaveError(err))
let bytes = match format {
GpFileFormat::Json => serde_json::to_vec(self as &dyn SgpSurrogate)
.map_err(MoeError::SaveJsonError)?,
GpFileFormat::Binary => {
bincode::serialize(self as &dyn SgpSurrogate).map_err(MoeError::SaveBinaryError)?
}
};
file.write_all(bytes.as_bytes())?;
file.write_all(&bytes)?;
Ok(())
}
}
Expand Down Expand Up @@ -382,10 +390,17 @@ declare_sgp_surrogate!(Matern52);

#[cfg(feature = "persistent")]
/// Load GP surrogate from the given file, decoded according to the specified format (JSON or binary).
pub fn load(path: &str) -> Result<Box<dyn GpSurrogate>> {
let data = fs::read_to_string(path)?;
let gp: Box<dyn GpSurrogate> = serde_json::from_str(&data).unwrap();
Ok(gp)
pub fn load(path: &str, format: GpFileFormat) -> Result<Box<dyn GpSurrogate>> {
let data = fs::read(path)?;
match format {
GpFileFormat::Json => {
serde_json::from_slice::<Box<dyn GpSurrogate>>(&data).map_err(|err| {
MoeError::LoadError(format!("Error while loading from {path}: ({err})"))
})
}
GpFileFormat::Binary => bincode::deserialize(&data)
.map_err(|err| MoeError::LoadError(format!("Error while loading from {path} ({err})"))),
}
}

#[doc(hidden)]
Expand Down Expand Up @@ -448,8 +463,9 @@ mod tests {
let gp = make_surrogate_params!(Constant, SquaredExponential)
.train(&xt.view(), &yt.view())
.expect("GP fit error");
gp.save("target/tests/save_gp.json").expect("GP not saved");
let gp = load("target/tests/save_gp.json").expect("GP not loaded");
gp.save("target/tests/save_gp.json", GpFileFormat::Json)
.expect("GP not saved");
let gp = load("target/tests/save_gp.json", GpFileFormat::Json).expect("GP not loaded");
let xv = Lhs::new(&xlimits).sample(20);
let yv = xsinx(&xv);
let ytest = gp.predict(&xv.view()).unwrap();
Expand All @@ -459,7 +475,7 @@ mod tests {

#[test]
fn test_load_fail() {
let gp = load("notfound.json");
let gp = load("notfound.json", GpFileFormat::Json);
assert!(gp.is_err());
}
}
10 changes: 10 additions & 0 deletions moe/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,13 @@ impl Clustering {
pub trait MixtureGpSurrogate: Clustered + GpSurrogate + GpSurrogateExt {
fn experts(&self) -> &Vec<Box<dyn FullGpSurrogate>>;
}

/// File formats available for persisting Gpx (GP/SGP mixture) models.
///
/// `Json` is the default, human-readable format; `Binary` is a compact
/// bincode encoding (see the size note in the Griewank test: binary files
/// are roughly half the JSON size for high-dimensional models).
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpFileFormat {
    /// Human readable JSON format
    #[default]
    Json,
    /// Optimized binary (bincode) format
    Binary,
}
35 changes: 19 additions & 16 deletions python/egobox/tests/test_gpmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,26 +55,29 @@ def test_gpx_kriging(self):

def test_gpx_save_load(self):
filename = "gpdump.json"

filename_bin = "gpdump.bin"
gpx = self.gpx

if os.path.exists(filename):
os.remove(filename)
gpx.save(filename)
gpx2 = egx.Gpx.load(filename)
os.remove(filename)
for file in [filename, filename_bin]:
if os.path.exists(file):
os.remove(file)
gpx.save(file)

# should interpolate
self.assertAlmostEqual(1.0, gpx2.predict(np.array([[1.0]])).item())
self.assertAlmostEqual(0.0, gpx2.predict_var(np.array([[1.0]])).item())
gpx2 = egx.Gpx.load(file)

# check a point not too far from a training point
self.assertAlmostEqual(
1.1163, gpx2.predict(np.array([[1.1]])).item(), delta=1e-3
)
self.assertAlmostEqual(
0.0, gpx2.predict_var(np.array([[1.1]])).item(), delta=1e-3
)
os.remove(file)

# should interpolate
self.assertAlmostEqual(1.0, gpx2.predict(np.array([[1.0]])).item())
self.assertAlmostEqual(0.0, gpx2.predict_var(np.array([[1.0]])).item())

# check a point not too far from a training point
self.assertAlmostEqual(
1.1163, gpx2.predict(np.array([[1.1]])).item(), delta=1e-3
)
self.assertAlmostEqual(
0.0, gpx2.predict_var(np.array([[1.1]])).item(), delta=1e-3
)

def test_training_params(self):
self.assertEqual(self.gpx.dims(), (1, 1))
Expand Down
28 changes: 21 additions & 7 deletions src/gp_mix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
//!
//! See the [tutorial notebook](https://github.com/relf/egobox/doc/Gpx_Tutorial.ipynb) for usage.
//!
use std::path::Path;

use crate::types::*;
use egobox_gp::metrics::CrossValScore;
use egobox_moe::{Clustered, MixtureGpSurrogate, ThetaTuning};
Expand Down Expand Up @@ -236,25 +238,37 @@ impl Gpx {
self.0.to_string()
}

/// Save Gaussian processes mixture in a json file.
/// Save Gaussian processes mixture in a file.
/// If the filename has a `.json` extension, the human-readable JSON format is used;
/// otherwise an optimized binary format is used.
///
/// Parameters
/// filename (string)
/// json file generated in the current directory
/// filename with .json or .bin extension (string)
/// file generated in the current directory
///
/// Returns True if the save succeeds, otherwise False
///
fn save(&self, filename: String) {
self.0.save(&filename).ok();
fn save(&self, filename: String) -> bool {
let format = match Path::new(&filename).extension().unwrap().to_str().unwrap() {
"json" => egobox_moe::GpFileFormat::Json,
_ => egobox_moe::GpFileFormat::Binary,
};
self.0.save(&filename, format).is_ok()
}

/// Load Gaussian processes mixture from a json file.
/// Load Gaussian processes mixture from file.
///
/// Parameters
/// filename (string)
/// json filepath generated by saving a trained Gaussian processes mixture
///
#[staticmethod]
fn load(filename: String) -> Gpx {
Gpx(GpMixture::load(&filename).unwrap())
let format = match Path::new(&filename).extension().unwrap().to_str().unwrap() {
"json" => egobox_moe::GpFileFormat::Json,
_ => egobox_moe::GpFileFormat::Binary,
};
Gpx(GpMixture::load(&filename, format).unwrap())
}

/// Predict output values at nsamples points.
Expand Down
Loading

0 comments on commit bda92f9

Please sign in to comment.