Skip to content

Commit

Permalink
improved typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Bslabe123 committed Jan 22, 2025
1 parent 4a506c1 commit 3cf004c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
4 changes: 2 additions & 2 deletions jetstream/core/config_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import dataclasses
import functools
from typing import Any, Callable, List, Tuple, Type
from typing import Any, Callable, List, Optional, Tuple, Type
from numpy import uint16

from jetstream.engine import engine_api
Expand Down Expand Up @@ -56,7 +56,7 @@ class InstantiatedEngines:
@dataclasses.dataclass
class MetricsServerConfig:
port: uint16
model_name: str
model_name: Optional[str]


# ▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼#
Expand Down
20 changes: 11 additions & 9 deletions jetstream/core/metrics/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import os
import re
from typing import Optional
import shortuuid
from prometheus_client import Counter, Gauge, Histogram
from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS
Expand All @@ -28,20 +29,21 @@ class JetstreamMetricsCollector:
"id": os.getenv("HOSTNAME", shortuuid.uuid())
}

def __new__(cls, model_name: str):
def __new__(cls, model_name: Optional[str] = None):
if not hasattr(cls, "instance"):
cls.instance = super(JetstreamMetricsCollector, cls).__new__(cls)
return cls.instance

def __init__(self, model_name: str):
def __init__(self, model_name: Optional[str] = None):
# '-'s are common in model names but invalid in prometheus labels, these are replaced with '_'s
sanitized_model_name=model_name.replace("-", "_")
if sanitized_model_name == "":
print("No model name provided, omitting from metrics labels")
elif not bool(re.match(r'^[a-zA-Z_:][a-zA-Z0-9_:]*$', sanitized_model_name)):
print("Provided model name cannot be used to label prometheus metrics (does not match ^[a-zA-Z_:][a-zA-Z0-9_:]*$), omitting from metrics labels")
else:
self.universal_labels["model_name"]=sanitized_model_name
if model_name is not None:
sanitized_model_name=model_name.replace("-", "_")
if sanitized_model_name == "":
print("No model name provided, omitting from metrics labels")
elif not bool(re.match(r'^[a-zA-Z_:][a-zA-Z0-9_:]*$', sanitized_model_name)):
print("Provided model name cannot be used to label prometheus metrics (does not match ^[a-zA-Z_:][a-zA-Z0-9_:]*$), omitting from metrics labels")
else:
self.universal_labels["model_name"]=sanitized_model_name
universal_label_names = list(self.universal_labels.keys())

# Metric definitions
Expand Down
2 changes: 1 addition & 1 deletion jetstream/entrypoints/http/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def server(argv: Sequence[str]):
if flags.FLAGS.prometheus_port != 0:
metrics_server_config = config_lib.MetricsServerConfig(
port=flags.FLAGS.prometheus_port
model_name="some_model"
model_name=None
)
logging.info(
"Starting Prometheus server on port %d", metrics_server_config.port
Expand Down

0 comments on commit 3cf004c

Please sign in to comment.