Skip to content

Commit

Permalink
[GSoC] New Interface report_metrics in Python SDK (kubeflow#2371)
Browse files Browse the repository at this point in the history
* chore: add report_metrics.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: modify the code according to the first review.

Signed-off-by: Electronic-Waste <[email protected]>

* chore: add validation for metrics value & rename katib_report_metrics.py to report_metrics.py.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: update import path in __init__.py.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: delete blank line.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: update RuntimeError doc string & correct spelling error & add new line.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: delete blank in the last line.

Signed-off-by: Electronic-Waste <[email protected]>

---------

Signed-off-by: Electronic-Waste <[email protected]>
  • Loading branch information
Electronic-Waste authored and shashank-iitbhu committed Jul 25, 2024
1 parent e678ab8 commit f94e20e
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 1 deletion.
2 changes: 2 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@

# Import Katib API client.
from kubeflow.katib.api.katib_client import KatibClient
# Import Katib report metrics functions
from kubeflow.katib.api.report_metrics import report_metrics
# Import Katib helper functions.
import kubeflow.katib.api.search as search
# Import Katib helper constants.
Expand Down
85 changes: 85 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2024 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from datetime import datetime, timezone
from typing import Any, Dict

import grpc
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
from kubeflow.katib.constants import constants
from kubeflow.katib.utils import utils

def report_metrics(
metrics: Dict[str, Any],
db_manager_address: str = constants.DEFAULT_DB_MANAGER_ADDRESS,
timeout: int = constants.DEFAULT_TIMEOUT,
):
"""Push Metrics Directly to Katib DB
Katib always passes Trial name as env variable `KATIB_TRIAL_NAME` to the training container.
Args:
metrics: Dict of metrics pushed to Katib DB.
For examle, `metrics = {"loss": 0.01, "accuracy": 0.99}`.
db-manager-address: Address for the Katib DB Manager in this format: `ip-address:port`.
timeout: Optional, gRPC API Server timeout in seconds to report metrics.
Raises:
ValueError: The Trial name is not passed to environment variables.
RuntimeError: Unable to push Trial metrics to Katib DB or
metrics value has incorrect format (cannot be converted to type `float`).
"""

# Get Trial's namespace and name
namespace = utils.get_current_k8s_namespace()
name = os.getenv("KATIB_TRIAL_NAME")
if name is None:
raise ValueError(
"The Trial name is not passed to environment variables"
)

# Get channel for grpc call to db manager
db_manager_address = db_manager_address.split(":")
channel = grpc.beta.implementations.insecure_channel(
db_manager_address[0], int(db_manager_address[1])
)

# Validate metrics value in dict
for value in metrics.values():
utils.validate_metrics_value(value)

# Dial katib db manager to report metrics
with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
try:
timestamp = datetime.now(timezone.utc).strftime(constants.RFC3339_FORMAT)
client.ReportObservationLog(
request=katib_api_pb2.ReportObservationLogRequest(
trial_name=name,
observation_logs=katib_api_pb2.ObservationLog(
metric_logs=[
katib_api_pb2.MetricLog(
time_stamp=timestamp,
metric=katib_api_pb2.Metric(name=name,value=str(value))
)
for name, value in metrics.items()
]
)
),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to push metrics to Katib DB for Trial {namespace}/{name}. Exception: {e}"
)
3 changes: 3 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/constants/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# How long to wait in seconds for requests to the Kubernetes or gRPC API Server.
DEFAULT_TIMEOUT = 120

# RFC3339 time format
RFC3339_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

# Global CRD version
KATIB_VERSION = os.environ.get("EXPERIMENT_VERSION", "v1beta1")

Expand Down
12 changes: 11 additions & 1 deletion sdk/python/v1beta1/kubeflow/katib/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import json
import os
import textwrap
from typing import Callable
from typing import Callable, Any
import inspect

from kubeflow.katib import models
Expand Down Expand Up @@ -72,6 +72,16 @@ def print_experiment_status(experiment: models.V1beta1Experiment):
print(f"Current Optimal Trial:\n {experiment.status.current_optimal_trial}")
print(f"Experiment conditions:\n {experiment.status.conditions}")

def validate_metrics_value(value: Any):
"""Validate if the metrics value can be converted to type `float`."""
try:
float(value)
except Exception:
raise ValueError(
f"Invalid value {value} for metrics value. "
"The metrics value should have or can be converted to type `float`. "
)


def validate_objective_function(objective: Callable):

Expand Down

0 comments on commit f94e20e

Please sign in to comment.