From 5a3e23da6ef334a771a51c7a2a7c6c7f7b125602 Mon Sep 17 00:00:00 2001 From: thebrianbn Date: Thu, 26 Sep 2024 14:01:43 -0400 Subject: [PATCH 01/10] feat: h2o model logging supports MOJO --- rubicon_ml/client/mixin.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index cf7074c2..7969a145 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -244,6 +244,7 @@ def log_h2o_model( h2o_model, artifact_name: Optional[str] = None, export_cross_validation_predictions: bool = False, + use_mojo: bool = False, **log_artifact_kwargs, ) -> Artifact: """Log an `h2o` model as an artifact using `h2o.save_model`. @@ -256,6 +257,9 @@ def log_h2o_model( The name of the artifact. Defaults to None, using `h2o_model`'s class name. export_cross_validation_predictions: bool, optional (default False) Passed directly to `h2o.save_model`. + use_mojo: bool, optional (default False) + Whether to log the model in MOJO format. If False, the model will be + logged in binary format. log_artifact_kwargs : dict Additional kwargs to be passed directly to `self.log_artifact`. 
""" @@ -268,12 +272,16 @@ def log_h2o_model( artifact_name = h2o_model.__class__.__name__ with tempfile.TemporaryDirectory() as temp_dir_name: - model_data_path = h2o.save_model( - h2o_model, - export_cross_validation_predictions=export_cross_validation_predictions, - filename=artifact_name, - path=temp_dir_name, - ) + if use_mojo: + model_data_path = f"{temp_dir_name}/{artifact_name}.zip" + h2o_model.save_mojo(path=model_data_path) + else: + model_data_path = h2o.save_model( + h2o_model, + export_cross_validation_predictions=export_cross_validation_predictions, + filename=artifact_name, + path=temp_dir_name, + ) artifact = self.log_artifact( name=artifact_name, From 8227c45c102ed71dc26ba2e2377b0ccaacdb69ae Mon Sep 17 00:00:00 2001 From: thebrianbn Date: Thu, 26 Sep 2024 14:01:51 -0400 Subject: [PATCH 02/10] feat: add support for deserializing of h2o MOJO artifacts --- rubicon_ml/client/artifact.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index 8ffcbdde..16a18a5b 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -71,7 +71,7 @@ def _get_data(self): @failsafe def get_data( self, - deserialize: Optional[Literal["h2o", "pickle", "xgboost"]] = None, + deserialize: Optional[Literal["h2o_binary", "h2o_mojo", "pickle", "xgboost"]] = None, unpickle: bool = False, # TODO: deprecate & move to `deserialize` ): """Loads the data associated with this artifact and @@ -82,7 +82,8 @@ def get_data( deseralize : str, optional Method to use to deseralize this artifact's data. * None to disable deseralization and return the raw data. - * "h2o" to use `h2o.load_model` to load the data. + * "h2o_binary" to use `h2o.load_model` to load the data. + * "h2o_mojo" to use `h2o.import_mojo` to load the data. * "pickle" to use pickles to load the data. * "xgboost" to use xgboost's JSON loader to load the data as a fitted model. Defaults to None. 
@@ -119,12 +120,18 @@ def get_data( except Exception as err: return_err = err else: - if deserialize == "h2o": + if deserialize == "h2o_binary": import h2o data = h2o.load_model( repo._get_artifact_data_path(project_name, experiment_id, self.id) ) + elif deserialize == "h2o_mojo": + import h2o + + data = h2o.import_mojo( + repo._get_artifact_data_path(project_name, experiment_id, self.id) + ) elif deserialize == "pickle": data = pickle.loads(data) From 240b2543e55a9c6c4e0d0fa647a40f7f9850c9a6 Mon Sep 17 00:00:00 2001 From: thebrianbn Date: Thu, 26 Sep 2024 14:01:59 -0400 Subject: [PATCH 03/10] fix: log MOJO model as directory --- rubicon_ml/client/mixin.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index 7969a145..f71efc96 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -275,6 +275,12 @@ def log_h2o_model( if use_mojo: model_data_path = f"{temp_dir_name}/{artifact_name}.zip" h2o_model.save_mojo(path=model_data_path) + + artifact = self.log_artifact( + name=artifact_name, + data_directory=model_data_path, + **log_artifact_kwargs, + ) else: model_data_path = h2o.save_model( h2o_model, @@ -283,11 +289,11 @@ def log_h2o_model( path=temp_dir_name, ) - artifact = self.log_artifact( - name=artifact_name, - data_path=model_data_path, - **log_artifact_kwargs, - ) + artifact = self.log_artifact( + name=artifact_name, + data_path=model_data_path, + **log_artifact_kwargs, + ) return artifact From 385696ea14f3b5297efa4b90a85e4254f930a97f Mon Sep 17 00:00:00 2001 From: thebrianbn Date: Thu, 26 Sep 2024 14:02:07 -0400 Subject: [PATCH 04/10] test: update unit test to test for binary vs MOJO --- tests/unit/client/test_mixin_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/client/test_mixin_client.py b/tests/unit/client/test_mixin_client.py index fdcddcca..f08cfc01 100644 --- a/tests/unit/client/test_mixin_client.py 
+++ b/tests/unit/client/test_mixin_client.py @@ -201,8 +201,8 @@ def test_log_json(project_client): assert artifact_a.id in [a.id for a in artifacts] assert artifact_b.id in [a.id for a in artifacts] - -def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_with_project): +@pytest.mark.parametrize("use_mojo", [False, True]) +def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_with_project, use_mojo): """Test logging `h2o` model data.""" _, project = rubicon_local_filesystem_client_with_project X, y = make_classification_df @@ -222,7 +222,7 @@ def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_w y=target_name, ) - artifact = project.log_h2o_model(h2o_model, tags=["h2o"]) + artifact = project.log_h2o_model(h2o_model, use_mojo=use_mojo, tags=["h2o"]) read_artifact = project.artifact(name=artifact.name) assert artifact.id == read_artifact.id From c95139ee2dc50255d8d8f15dd3d4b9276ac54410 Mon Sep 17 00:00:00 2001 From: thebrianbn Date: Thu, 26 Sep 2024 14:02:15 -0400 Subject: [PATCH 05/10] test: update h2o schema integration test to make sure binary model is used --- tests/integration/test_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 6aa82054..5120f5a7 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -111,4 +111,4 @@ def test_estimator_h2o_schema_train( model_artifact = experiment.artifact(name=schema_cls.__name__) assert len(project.schema_["parameters"]) == len(experiment.parameters()) - assert model_artifact.get_data(deserialize="h2o").__class__.__name__ == schema_cls.__name__ + assert model_artifact.get_data(deserialize="h2o_binary").__class__.__name__ == schema_cls.__name__ From dd54552c2c41389403f9f37bff7fa4263723cd1c Mon Sep 17 00:00:00 2001 From: thebrianbn Date: Thu, 26 Sep 2024 14:02:22 -0400 Subject: [PATCH 06/10] fix: download_mojo instead 
of save_mojo --- rubicon_ml/client/mixin.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index f71efc96..fb362162 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -274,13 +274,7 @@ def log_h2o_model( with tempfile.TemporaryDirectory() as temp_dir_name: if use_mojo: model_data_path = f"{temp_dir_name}/{artifact_name}.zip" - h2o_model.save_mojo(path=model_data_path) - - artifact = self.log_artifact( - name=artifact_name, - data_directory=model_data_path, - **log_artifact_kwargs, - ) + h2o_model.download_mojo(path=model_data_path) else: model_data_path = h2o.save_model( h2o_model, @@ -289,11 +283,11 @@ def log_h2o_model( path=temp_dir_name, ) - artifact = self.log_artifact( - name=artifact_name, - data_path=model_data_path, - **log_artifact_kwargs, - ) + artifact = self.log_artifact( + name=artifact_name, + data_path=model_data_path, + **log_artifact_kwargs, + ) return artifact From d12f0378b88a5b1d8aff54f2a3b5934d81804b40 Mon Sep 17 00:00:00 2001 From: thebrianbn Date: Thu, 26 Sep 2024 14:02:30 -0400 Subject: [PATCH 07/10] test: add unit test for retrieval of MOJO model --- tests/unit/client/test_artifact_client.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/unit/client/test_artifact_client.py b/tests/unit/client/test_artifact_client.py index 10135cfd..da7d42b1 100644 --- a/tests/unit/client/test_artifact_client.py +++ b/tests/unit/client/test_artifact_client.py @@ -9,6 +9,7 @@ import xgboost from h2o import H2OFrame from h2o.estimators.random_forest import H2ORandomForestEstimator +from h2o.estimators.generic import H2OGenericEstimator from rubicon_ml import domain from rubicon_ml.client import Artifact, Rubicon @@ -159,8 +160,15 @@ def test_download_location(mock_open, project_client): mock_file().write.assert_called_once_with(data) +@pytest.mark.parametrize( + ["use_mojo", 
"deserialization_method"], + [ + (False, "h2o_binary"), + (True, "h2o_mojo"), + ], +) def test_get_data_deserialize_h2o( - make_classification_df, rubicon_local_filesystem_client_with_project + make_classification_df, rubicon_local_filesystem_client_with_project, use_mojo, deserialization_method ): """Test logging `h2o` model data.""" _, project = rubicon_local_filesystem_client_with_project @@ -181,10 +189,13 @@ def test_get_data_deserialize_h2o( y=target_name, ) - artifact = project.log_h2o_model(h2o_model) - artifact_data = artifact.get_data(deserialize="h2o") + artifact = project.log_h2o_model(h2o_model, use_mojo=use_mojo) + artifact_data = artifact.get_data(deserialize=deserialization_method) - assert artifact_data.__class__ == h2o_model.__class__ + if use_mojo: + assert isinstance(artifact_data, H2OGenericEstimator) + else: + assert artifact_data.__class__ == h2o_model.__class__ def test_get_data_deserialize_xgboost( From 417453442c2124d5404614022aff5baf2b64074c Mon Sep 17 00:00:00 2001 From: thebrianbn Date: Mon, 30 Sep 2024 10:20:21 -0400 Subject: [PATCH 08/10] chore: pre-commit --- tests/integration/test_schema.py | 4 +++- tests/unit/client/test_artifact_client.py | 7 +++++-- tests/unit/client/test_mixin_client.py | 5 ++++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 5120f5a7..73fe4575 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -111,4 +111,6 @@ def test_estimator_h2o_schema_train( model_artifact = experiment.artifact(name=schema_cls.__name__) assert len(project.schema_["parameters"]) == len(experiment.parameters()) - assert model_artifact.get_data(deserialize="h2o_binary").__class__.__name__ == schema_cls.__name__ + assert ( + model_artifact.get_data(deserialize="h2o_binary").__class__.__name__ == schema_cls.__name__ + ) diff --git a/tests/unit/client/test_artifact_client.py b/tests/unit/client/test_artifact_client.py 
index da7d42b1..32642492 100644 --- a/tests/unit/client/test_artifact_client.py +++ b/tests/unit/client/test_artifact_client.py @@ -8,8 +8,8 @@ import pytest import xgboost from h2o import H2OFrame -from h2o.estimators.random_forest import H2ORandomForestEstimator from h2o.estimators.generic import H2OGenericEstimator +from h2o.estimators.random_forest import H2ORandomForestEstimator from rubicon_ml import domain from rubicon_ml.client import Artifact, Rubicon @@ -168,7 +168,10 @@ def test_download_location(mock_open, project_client): ], ) def test_get_data_deserialize_h2o( - make_classification_df, rubicon_local_filesystem_client_with_project, use_mojo, deserialization_method + make_classification_df, + rubicon_local_filesystem_client_with_project, + use_mojo, + deserialization_method, ): """Test logging `h2o` model data.""" _, project = rubicon_local_filesystem_client_with_project diff --git a/tests/unit/client/test_mixin_client.py b/tests/unit/client/test_mixin_client.py index f08cfc01..aad16bf7 100644 --- a/tests/unit/client/test_mixin_client.py +++ b/tests/unit/client/test_mixin_client.py @@ -201,8 +201,11 @@ def test_log_json(project_client): assert artifact_a.id in [a.id for a in artifacts] assert artifact_b.id in [a.id for a in artifacts] + @pytest.mark.parametrize("use_mojo", [False, True]) -def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_with_project, use_mojo): +def test_log_h2o_model( + make_classification_df, rubicon_local_filesystem_client_with_project, use_mojo +): """Test logging `h2o` model data.""" _, project = rubicon_local_filesystem_client_with_project X, y = make_classification_df From 9232332e0b0fc70db6d51b268ed0c18fcd19ab05 Mon Sep 17 00:00:00 2001 From: thebrianbn Date: Mon, 30 Sep 2024 11:02:11 -0400 Subject: [PATCH 09/10] chore: backwards compatibility for h2o --- rubicon_ml/client/artifact.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/rubicon_ml/client/artifact.py
b/rubicon_ml/client/artifact.py index 16a18a5b..147217e4 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -82,7 +82,7 @@ def get_data( deseralize : str, optional Method to use to deseralize this artifact's data. * None to disable deseralization and return the raw data. - * "h2o_binary" to use `h2o.load_model` to load the data. + * "h2o" or "h2o_binary" to use `h2o.load_model` to load the data. * "h2o_mojo" to use `h2o.import_mojo` to load the data. * "pickle" to use pickles to load the data. * "xgboost" to use xgboost's JSON loader to load the data as a fitted model. @@ -102,6 +102,13 @@ def get_data( ) deserialize = "pickle" + if deserialize == "h2o": + warnings.warn( + "'deserialize' method 'h2o' will be deprecated in a future release," + " please use 'h2o_binary' instead.", + DeprecationWarning, + ) + for repo in self.repositories or []: try: if deserialize == "xgboost": @@ -120,7 +127,10 @@ def get_data( except Exception as err: return_err = err else: - if deserialize == "h2o_binary": + if deserialize in [ + "h2o", + "h2o_binary", + ]: # "h2o" will be deprecated in a future release import h2o data = h2o.load_model( From 416d09f8fe0dba1f6ce286dd25f7a9d553d1c4eb Mon Sep 17 00:00:00 2001 From: thebrianbn Date: Mon, 30 Sep 2024 11:42:45 -0400 Subject: [PATCH 10/10] chore: update literals --- rubicon_ml/client/artifact.py | 2 +- tests/unit/client/test_artifact_client.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index 147217e4..56d3a8c1 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -71,7 +71,7 @@ def _get_data(self): @failsafe def get_data( self, - deserialize: Optional[Literal["h2o_binary", "h2o_mojo", "pickle", "xgboost"]] = None, + deserialize: Optional[Literal["h2o", "h2o_binary", "h2o_mojo", "pickle", "xgboost"]] = None, unpickle: bool = False, # TODO: deprecate & move to `deserialize` ): """Loads the data 
associated with this artifact and diff --git a/tests/unit/client/test_artifact_client.py b/tests/unit/client/test_artifact_client.py index 32642492..7b1c6f72 100644 --- a/tests/unit/client/test_artifact_client.py +++ b/tests/unit/client/test_artifact_client.py @@ -163,6 +163,7 @@ def test_download_location(mock_open, project_client): @pytest.mark.parametrize( ["use_mojo", "deserialization_method"], [ + (False, "h2o"), (False, "h2o_binary"), (True, "h2o_mojo"), ],