diff --git a/rubicon_ml/schema/logger.py b/rubicon_ml/schema/logger.py index f8333aff..203038e1 100644 --- a/rubicon_ml/schema/logger.py +++ b/rubicon_ml/schema/logger.py @@ -184,7 +184,11 @@ def log_with_schema( if "self" in artifact: logging_func_name = artifact["self"] logging_func = getattr(experiment, logging_func_name) - logging_func(obj) + + # Get remaining artifact logging function parameters and run with func + logging_func( + obj, **dict((k, v) for k, v in artifact.items() if k != "self") + ) # key-values in rest of dictionary are passed as arguments else: data_object = _get_data_object(obj, artifact) diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 73fe4575..2d3c5b61 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -3,6 +3,7 @@ import pytest from h2o import H2OFrame from h2o.estimators.gbm import H2OGradientBoostingEstimator +from h2o.estimators.generic import H2OGenericEstimator from h2o.estimators.glm import H2OGeneralizedLinearEstimator from h2o.estimators.random_forest import H2ORandomForestEstimator from h2o.estimators.targetencoder import H2OTargetEncoderEstimator @@ -12,6 +13,8 @@ from xgboost import XGBClassifier, XGBRegressor from xgboost.dask import DaskXGBClassifier, DaskXGBRegressor +from rubicon_ml.schema import registry + PANDAS_SCHEMA_CLS = [ LGBMClassifier, LGBMRegressor, @@ -40,12 +43,15 @@ def _fit_and_log(X, y, schema_cls, rubicon_project): rubicon_project.log_with_schema(model) -def _train_and_log(X, y, schema_cls, rubicon_project): +def _train_and_log(X, y, schema_cls, rubicon_project, schema=None): target_name = "target" training_frame = pd.concat([X, pd.Series(y)], axis=1) training_frame.columns = [*X.columns, target_name] training_frame_h2o = H2OFrame(training_frame) + if schema: + rubicon_project.set_schema(schema) + model = schema_cls() model.train( training_frame=training_frame_h2o, @@ -97,8 +103,10 @@ def test_estimator_schema_fit_dask_df( @pytest.mark.integration @pytest.mark.parametrize("schema_cls", H2O_SCHEMA_CLS) +@pytest.mark.parametrize("extended_schema", [True, False]) def test_estimator_h2o_schema_train( schema_cls, + extended_schema, make_classification_df, rubicon_local_filesystem_client_with_project, ): @@ -107,10 +115,42 @@ def test_estimator_h2o_schema_train( X, y = make_classification_df y = y > y.mean() - experiment = _train_and_log(X, y, schema_cls, project) - model_artifact = experiment.artifact(name=schema_cls.__name__) - - assert len(project.schema_["parameters"]) == len(experiment.parameters()) + # H2OTargetEncoderEstimator does not support MOJO + if not extended_schema or schema_cls == H2OTargetEncoderEstimator: + use_mojo = False + deserialize_method = "h2o_binary" + artifact_name = schema_cls.__name__ + else: + use_mojo = True + deserialize_method = "h2o_mojo" + artifact_name = H2OGenericEstimator.__name__ + + if extended_schema: + schema = { + "name": f"h2o__{schema_cls.__name__}__ext", + "extends": f"h2o__{schema_cls.__name__}", + "artifacts": [ + { + "self": "log_h2o_model", + "artifact_name": artifact_name, + "export_cross_validation_predictions": True, + "use_mojo": use_mojo, + }, + ], + } + else: + schema = None + + experiment = _train_and_log(X, y, schema_cls, project, schema) + model_artifact = experiment.artifact(name=artifact_name) + + if extended_schema: + # Make sure the extended schema parameters are set properly with the schema from registry + assert len(registry.get_schema(f"h2o__{schema_cls.__name__}")["parameters"]) == len( + experiment.parameters() + ) + else: + assert len(project.schema_["parameters"]) == len(experiment.parameters()) assert ( - model_artifact.get_data(deserialize="h2o_binary").__class__.__name__ == schema_cls.__name__ + model_artifact.get_data(deserialize=deserialize_method).__class__.__name__ == artifact_name ) diff --git a/tests/unit/schema/test_schema_logger.py b/tests/unit/schema/test_schema_logger.py index e7b94cd4..7df3066e 100644 --- a/tests/unit/schema/test_schema_logger.py +++ b/tests/unit/schema/test_schema_logger.py @@ -110,10 +110,10 @@ def test_log_artifacts_with_schema(objects_to_log, rubicon_project, artifact_sch object_b.__class__, ) - def custom_logging_func(self, obj): + def custom_logging_func(self, obj, test_param): self.custom_logging_func_called = True - artifact_schema["artifacts"].append({"self": "custom_logging_func"}) + artifact_schema["artifacts"].append({"self": "custom_logging_func", "test_param": "test"}) with mock.patch.object( rubicon_ml.client.experiment.Experiment,