From a524f33830e02189476efaf6d9045cbd2ce605f0 Mon Sep 17 00:00:00 2001 From: Shao Wang <77665902+Electronic-Waste@users.noreply.github.com> Date: Fri, 23 Aug 2024 19:12:58 +0800 Subject: [PATCH] [SDK] fix grpc related bugs in Python SDK (#2398) * fix: fix bugs in report_metrics. Signed-off-by: Electronic-Waste <2690692950@qq.com> * fix: fix bugs in tune. Signed-off-by: Electronic-Waste <2690692950@qq.com> * fix: fix bugs in get_trial_metrics. Signed-off-by: Electronic-Waste <2690692950@qq.com> * fix: update .gitignore and setup.py. Signed-off-by: Electronic-Waste <2690692950@qq.com> * fix: update Makefile. Signed-off-by: Electronic-Waste <2690692950@qq.com> * feat: add report_metrics_test.py. Signed-off-by: Electronic-Waste <2690692950@qq.com> * fix: fix lint error. Signed-off-by: Electronic-Waste <2690692950@qq.com> * feat: add UTs for get_trial_metrics. Signed-off-by: Electronic-Waste <2690692950@qq.com> * fix: update post_gen.py. Signed-off-by: Electronic-Waste <2690692950@qq.com> * refactor: rebase to master. Signed-off-by: Electronic-Waste <2690692950@qq.com> * test(sdk): use single katib_client. Signed-off-by: Electronic-Waste <2690692950@qq.com> * fix(sdk): add TODO for import rewrite. Signed-off-by: Electronic-Waste <2690692950@qq.com> * fix(sdk): fix lint error with black. Signed-off-by: Electronic-Waste <2690692950@qq.com> * fix(sdk): fix lint error with isort. Signed-off-by: Electronic-Waste <2690692950@qq.com> * fix(sdk): reformat import in katib_client_test.py. Signed-off-by: Electronic-Waste <2690692950@qq.com> --------- Signed-off-by: Electronic-Waste <2690692950@qq.com> --- Makefile | 6 +- hack/gen-python-sdk/post_gen.py | 4 +- sdk/python/v1beta1/.gitignore | 1 + sdk/python/v1beta1/kubeflow/katib/__init__.py | 4 +- .../kubeflow/katib/api/katib_client.py | 30 +++-- .../kubeflow/katib/api/katib_client_test.py | 71 +++++++++++- .../kubeflow/katib/api/report_metrics.py | 54 +++++---- .../kubeflow/katib/api/report_metrics_test.py | 104 ++++++++++++++++++ sdk/python/v1beta1/setup.py | 16 +++ 9 files changed, 240 insertions(+), 50 deletions(-) create mode 100644 sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py diff --git a/Makefile b/Makefile index 3edfc9ece13..a6708de7f5b 100755 --- a/Makefile +++ b/Makefile @@ -166,13 +166,17 @@ ifeq ("$(wildcard $(TEST_TENSORFLOW_EVENT_FILE_PATH))", "") python examples/v1beta1/trial-images/tf-mnist-with-summaries/mnist.py --epochs 5 --batch-size 200 --log-path $(TEST_TENSORFLOW_EVENT_FILE_PATH) endif +# TODO(Electronic-Waste): Remove the import rewrite when protobuf supports `python_package` option. +# REF: https://github.com/protocolbuffers/protobuf/issues/7061 pytest: prepare-pytest prepare-pytest-testdata pytest ./test/unit/v1beta1/suggestion --ignore=./test/unit/v1beta1/suggestion/test_skopt_service.py pytest ./test/unit/v1beta1/earlystopping pytest ./test/unit/v1beta1/metricscollector cp ./pkg/apis/manager/v1beta1/python/api_pb2.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py + cp ./pkg/apis/manager/v1beta1/python/api_pb2_grpc.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py + sed -i "s/api_pb2/kubeflow\.katib\.katib_api_pb2/g" ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py pytest ./sdk/python/v1beta1/kubeflow/katib - rm ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py + rm ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py # The skopt service doesn't work appropriately with Python 3.11. # So, we need to run the test with Python 3.9. diff --git a/hack/gen-python-sdk/post_gen.py b/hack/gen-python-sdk/post_gen.py index b56d259b3ac..70eab3a2595 100644 --- a/hack/gen-python-sdk/post_gen.py +++ b/hack/gen-python-sdk/post_gen.py @@ -41,8 +41,8 @@ def _rewrite_helper(input_file, output_file, rewrite_rules): if output_file == "sdk/python/v1beta1/kubeflow/katib/__init__.py": lines.append("# Import Katib API client.\n") lines.append("from kubeflow.katib.api.katib_client import KatibClient\n") - lines.append("# Import Katib report metrics functions") - lines.append("from kubeflow.katib.api.report_metrics import report_metrics") + lines.append("# Import Katib report metrics functions\n") + lines.append("from kubeflow.katib.api.report_metrics import report_metrics\n") lines.append("# Import Katib helper functions.\n") lines.append("import kubeflow.katib.api.search as search\n") lines.append("# Import Katib helper constants.\n") diff --git a/sdk/python/v1beta1/.gitignore b/sdk/python/v1beta1/.gitignore index 81f90cfca9b..64aa80409b8 100644 --- a/sdk/python/v1beta1/.gitignore +++ b/sdk/python/v1beta1/.gitignore @@ -3,3 +3,4 @@ dist/ # Katib gRPC APIs kubeflow/katib/katib_api_pb2.py +kubeflow/katib/katib_api_pb2_grpc.py diff --git a/sdk/python/v1beta1/kubeflow/katib/__init__.py b/sdk/python/v1beta1/kubeflow/katib/__init__.py index 0ea206ea7ca..7aef4c9897d 100644 --- a/sdk/python/v1beta1/kubeflow/katib/__init__.py +++ b/sdk/python/v1beta1/kubeflow/katib/__init__.py @@ -71,7 +71,9 @@ # Import Katib API client. from kubeflow.katib.api.katib_client import KatibClient -# Import Katib report metrics functionsfrom kubeflow.katib.api.report_metrics import report_metrics# Import Katib helper functions. +# Import Katib report metrics functions +from kubeflow.katib.api.report_metrics import report_metrics +# Import Katib helper functions. import kubeflow.katib.api.search as search # Import Katib helper constants. from kubeflow.katib.constants.constants import BASE_IMAGE_TENSORFLOW diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py index 18bb0bd26a9..78808d17f05 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py @@ -21,6 +21,7 @@ import grpc import kubeflow.katib.katib_api_pb2 as katib_api_pb2 +import kubeflow.katib.katib_api_pb2_grpc as katib_api_pb2_grpc from kubeflow.katib import models from kubeflow.katib.api_client import ApiClient from kubeflow.katib.constants import constants @@ -1305,21 +1306,18 @@ def get_trial_metrics( namespace = namespace or self.namespace - db_manager_address = db_manager_address.split(":") - channel = grpc.beta.implementations.insecure_channel( - db_manager_address[0], int(db_manager_address[1]) - ) + channel = grpc.insecure_channel(db_manager_address) - with katib_api_pb2.beta_create_DBManager_stub(channel) as client: - try: - # When metric name is empty, we select all logs from the Katib DB. - observation_logs = client.GetObservationLog( - katib_api_pb2.GetObservationLogRequest(trial_name=name), - timeout=timeout, - ) - except Exception as e: - raise RuntimeError( - f"Unable to get metrics for Trial {namespace}/{name}. Exception: {e}" - ) + client = katib_api_pb2_grpc.DBManagerStub(channel) + try: + # When metric name is empty, we select all logs from the Katib DB. + observation_logs = client.GetObservationLog( + katib_api_pb2.GetObservationLogRequest(trial_name=name), + timeout=timeout, + ) + except Exception as e: + raise RuntimeError( + f"Unable to get metrics for Trial {namespace}/{name}. Exception: {e}" + ) - return observation_logs.observation_log.metric_logs + return observation_logs.observation_log.metric_logs diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py index f02728f4413..fef18adfa0f 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py @@ -2,6 +2,7 @@ from typing import List, Optional from unittest.mock import Mock, patch +import kubeflow.katib.katib_api_pb2 as katib_api_pb2 import pytest from kubeflow.katib import ( KatibClient, @@ -38,6 +39,24 @@ def create_namespaced_custom_object_response(*args, **kwargs): return {"metadata": {"name": "12345-experiment-mnist-ci-test"}} +def get_observation_log_response(*args, **kwargs): + if kwargs.get("timeout") == 0: + raise TimeoutError + elif args[0].trial_name == "invalid": + raise RuntimeError + else: + return katib_api_pb2.GetObservationLogReply( + observation_log=katib_api_pb2.ObservationLog( + metric_logs=[ + katib_api_pb2.MetricLog( + time_stamp="2024-07-29T15:09:08Z", + metric=katib_api_pb2.Metric(name="result", value="0.99"), + ) + ] + ) + ) + + def generate_trial_template() -> V1beta1TrialTemplate: trial_spec = { "apiVersion": "batch/v1", @@ -223,6 +242,34 @@ def create_experiment( ] +test_get_trial_metrics_data = [ + ( + "valid trial name", + {"name": "example", "namespace": "valid", "timeout": constants.DEFAULT_TIMEOUT}, + [ + katib_api_pb2.MetricLog( + time_stamp="2024-07-29T15:09:08Z", + metric=katib_api_pb2.Metric(name="result", value="0.99"), + ) + ], + ), + ( + "invalid trial name", + { + "name": "invalid", + "namespace": "invalid", + "timeout": constants.DEFAULT_TIMEOUT, + }, + RuntimeError, + ), + ( + "GetObservationLog timeout error", + {"name": "example", "namespace": "valid", "timeout": 0}, + RuntimeError, + ), +] + + @pytest.fixture def katib_client(): with patch( @@ -232,7 +279,12 @@ def katib_client(): side_effect=create_namespaced_custom_object_response ) ), - ), patch("kubernetes.config.load_kube_config", return_value=Mock()): + ), patch("kubernetes.config.load_kube_config", return_value=Mock()), patch( + "kubeflow.katib.katib_api_pb2_grpc.DBManagerStub", + return_value=Mock( + GetObservationLog=Mock(side_effect=get_observation_log_response) + ), + ): client = KatibClient() yield client @@ -251,3 +303,20 @@ def test_create_experiment(katib_client, test_name, kwargs, expected_output): except Exception as e: assert type(e) is expected_output print("test execution complete") + + +@pytest.mark.parametrize( + "test_name,kwargs,expected_output", test_get_trial_metrics_data +) +def test_get_trial_metrics(katib_client, test_name, kwargs, expected_output): + """ + test get_trial_metrics function of katib client + """ + print("\n\nExecuting test:", test_name) + try: + metrics = katib_client.get_trial_metrics(**kwargs) + for i in range(len(metrics)): + assert metrics[i] == expected_output[i] + except Exception as e: + assert type(e) is expected_output + print("test execution complete") diff --git a/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py index 5e5f2996f5d..aec62d3f6a7 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py @@ -18,6 +18,7 @@ import grpc import kubeflow.katib.katib_api_pb2 as katib_api_pb2 +import kubeflow.katib.katib_api_pb2_grpc as katib_api_pb2_grpc from kubeflow.katib.constants import constants from kubeflow.katib.utils import utils @@ -38,9 +39,9 @@ def report_metrics( timeout: Optional, gRPC API Server timeout in seconds to report metrics. Raises: - ValueError: The Trial name is not passed to environment variables. - RuntimeError: Unable to push Trial metrics to Katib DB or + ValueError: The Trial name is not passed to environment variables or metrics value has incorrect format (cannot be converted to type `float`). + RuntimeError: Unable to push Trial metrics to Katib DB. """ # Get Trial's namespace and name @@ -50,37 +51,32 @@ def report_metrics( raise ValueError("The Trial name is not passed to environment variables") # Get channel for grpc call to db manager - db_manager_address = db_manager_address.split(":") - channel = grpc.beta.implementations.insecure_channel( - db_manager_address[0], int(db_manager_address[1]) - ) + channel = grpc.insecure_channel(db_manager_address) # Validate metrics value in dict for value in metrics.values(): utils.validate_metrics_value(value) # Dial katib db manager to report metrics - with katib_api_pb2.beta_create_DBManager_stub(channel) as client: - try: - timestamp = datetime.now(timezone.utc).strftime(constants.RFC3339_FORMAT) - client.ReportObservationLog( - request=katib_api_pb2.ReportObservationLogRequest( - trial_name=name, - observation_logs=katib_api_pb2.ObservationLog( - metric_logs=[ - katib_api_pb2.MetricLog( - time_stamp=timestamp, - metric=katib_api_pb2.Metric( - name=name, value=str(value) - ), - ) - for name, value in metrics.items() - ] - ), + client = katib_api_pb2_grpc.DBManagerStub(channel) + try: + timestamp = datetime.now(timezone.utc).strftime(constants.RFC3339_FORMAT) + client.ReportObservationLog( + request=katib_api_pb2.ReportObservationLogRequest( + trial_name=name, + observation_log=katib_api_pb2.ObservationLog( + metric_logs=[ + katib_api_pb2.MetricLog( + time_stamp=timestamp, + metric=katib_api_pb2.Metric(name=name, value=str(value)), + ) + for name, value in metrics.items() + ] ), - timeout=timeout, - ) - except Exception as e: - raise RuntimeError( - f"Unable to push metrics to Katib DB for Trial {namespace}/{name}. Exception: {e}" - ) + ), + timeout=timeout, + ) + except Exception as e: + raise RuntimeError( + f"Unable to push metrics to Katib DB for Trial {namespace}/{name}. Exception: {e}" + ) diff --git a/sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py new file mode 100644 index 00000000000..4ceba92ef2a --- /dev/null +++ b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py @@ -0,0 +1,104 @@ +from unittest.mock import patch + +import pytest +from kubeflow.katib import report_metrics +from kubeflow.katib.constants import constants + +TEST_RESULT_SUCCESS = "success" +ENV_VARIABLE_EMPTY = True +ENV_VARIABLE_NOT_EMPTY = False + + +def report_observation_log_response(*args, **kwargs): + if kwargs.get("timeout") == 0: + raise TimeoutError + + +test_report_metrics_data = [ + ( + "valid metrics with float type", + {"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT}, + TEST_RESULT_SUCCESS, + ENV_VARIABLE_NOT_EMPTY, + ), + ( + "valid metrics with string type", + {"metrics": {"result": "0.99"}, "timeout": constants.DEFAULT_TIMEOUT}, + TEST_RESULT_SUCCESS, + ENV_VARIABLE_NOT_EMPTY, + ), + ( + "valid metrics with int type", + {"metrics": {"result": 1}, "timeout": constants.DEFAULT_TIMEOUT}, + TEST_RESULT_SUCCESS, + ENV_VARIABLE_NOT_EMPTY, + ), + ( + "ReportObservationLog timeout error", + {"metrics": {"result": 0.99}, "timeout": 0}, + RuntimeError, + ENV_VARIABLE_NOT_EMPTY, + ), + ( + "invalid metrics with type string", + {"metrics": {"result": "abc"}, "timeout": constants.DEFAULT_TIMEOUT}, + ValueError, + ENV_VARIABLE_NOT_EMPTY, + ), + ( + "Trial name is not passed to env variables", + {"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT}, + ValueError, + ENV_VARIABLE_EMPTY, + ), +] + + +@pytest.fixture +def mock_getenv(request): + with patch("os.getenv") as mock: + if request.param is ENV_VARIABLE_EMPTY: + mock.side_effect = ValueError + else: + mock.return_value = "example" + yield mock + + +@pytest.fixture +def mock_get_current_k8s_namespace(): + with patch("kubeflow.katib.utils.utils.get_current_k8s_namespace") as mock: + mock.return_value = "test" + yield mock + + +@pytest.fixture +def mock_report_observation_log(): + with patch("kubeflow.katib.katib_api_pb2_grpc.DBManagerStub") as mock: + mock_instance = mock.return_value + mock_instance.ReportObservationLog.side_effect = report_observation_log_response + yield mock_instance + + +@pytest.mark.parametrize( + "test_name,kwargs,expected_output,mock_getenv", + test_report_metrics_data, + indirect=["mock_getenv"], +) +def test_report_metrics( + test_name, + kwargs, + expected_output, + mock_getenv, + mock_get_current_k8s_namespace, + mock_report_observation_log, +): + """ + test report_metrics function + """ + print("\n\nExecuting test:", test_name) + try: + report_metrics(**kwargs) + assert expected_output == TEST_RESULT_SUCCESS + except Exception as e: + assert type(e) is expected_output + print("test execution complete") diff --git a/sdk/python/v1beta1/setup.py b/sdk/python/v1beta1/setup.py index 6b9c152f2d2..49c689a235c 100644 --- a/sdk/python/v1beta1/setup.py +++ b/sdk/python/v1beta1/setup.py @@ -28,6 +28,7 @@ ] katib_grpc_api_file = "../../../pkg/apis/manager/v1beta1/python/api_pb2.py" +katib_grpc_svc_file = "../../../pkg/apis/manager/v1beta1/python/api_pb2_grpc.py" # Copy Katib gRPC Python APIs to use it in the Katib SDK Client. # We need to always copy this file only on the SDK building stage, not on SDK installation stage. @@ -37,6 +38,21 @@ "kubeflow/katib/katib_api_pb2.py", ) +# TODO(Electronic-Waste): Remove the import rewrite when protobuf supports `python_package` option. +# REF: https://github.com/protocolbuffers/protobuf/issues/7061 +if os.path.exists(katib_grpc_svc_file): + shutil.copy( + katib_grpc_svc_file, + "kubeflow/katib/katib_api_pb2_grpc.py", + ) + + with open("kubeflow/katib/katib_api_pb2_grpc.py", "r+") as file: + content = file.read() + new_content = content.replace("api_pb2", "kubeflow.katib.katib_api_pb2") + file.seek(0) + file.write(new_content) + file.truncate() + setuptools.setup( name="kubeflow-katib", version="0.17.0",