Skip to content

Commit

Permalink
include in loggers
Browse files Browse the repository at this point in the history
  • Loading branch information
horheynm committed Mar 26, 2024
1 parent 4b7097c commit 91246fe
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 3 deletions.
102 changes: 102 additions & 0 deletions src/sparseml/pytorch/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,21 @@
wandb = None
wandb_err = err


try:
from clearml import Task

clearml_err = None
except Exception as err:
clearml = None
clearml_err = err

from sparseml.utils import ALL_TOKEN, create_dirs


__all__ = [
"BaseLogger",
"ClearMLLogger",
"LambdaLogger",
"PythonLogger",
"TensorBoardLogger",
Expand Down Expand Up @@ -628,6 +638,98 @@ def save(
return True


class ClearMLLogger(LambdaLogger):
@staticmethod
def available() -> bool:
"""
:return: True if wandb is available and installed, False, otherwise
"""
return not clearml_err

def __init__(
self,
name: str = "clearml",
enabled: bool = True,
project_name: str = "sparseml",
task_name: str = "",
):
if task_name == "":
now = datetime.now()
task_name = now.strftime("%d-%m-%Y_%H.%M.%S")

self.task = Task.init(project_name=project_name, task_name=task_name)

super().__init__(
lambda_func=self.log_scalar,
name=name,
enabled=enabled,
)

def log_hyperparams(
self,
params: Dict,
level: Optional[int] = None,
) -> bool:
"""
:param params: Each key-value pair in the dictionary is the name of the
hyper parameter and it's corresponding value.
:return: True if logged, False otherwise.
"""
if not self.enabled:
return False

self.task.connect(params)
return True

def log_scalar(
self,
tag: str,
value: float,
step: Optional[int] = None,
wall_time: Optional[float] = None,
level: Optional[int] = None,
) -> bool:
"""
:param tag: identifying tag to log the value with
:param value: value to save
:param step: global step for when the value was taken
:param wall_time: global wall time for when the value was taken,
defaults to time.time()
:param kwargs: additional logging arguments to support Python and custom loggers
:return: True if logged, False otherwise.
"""
logger = self.task.get_logger()
logger.report_single_value(name=tag, value=value)
return True

def log_scalars(
self,
tag: str,
values: Dict[str, float],
step: Optional[int] = None,
wall_time: Optional[float] = None,
level: Optional[int] = None,
) -> bool:
"""
:param tag: identifying tag to log the values with
:param values: values to save
:param step: global step for when the values were taken
:param wall_time: global wall time for when the values were taken,
defaults to time.time()
:param kwargs: additional logging arguments to support Python and custom loggers
:return: True if logged, False otherwise.
"""
for k, v in values.items():
self.log_scalar(
tag=f"{tag}.{k}",
value=v,
step=step,
wall_time=wall_time,
level=level,
)
return True


class SparsificationGroupLogger(BaseLogger):
"""
Modifier logger that handles outputting values to other supported systems.
Expand Down
8 changes: 5 additions & 3 deletions tests/sparseml/pytorch/utils/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest

from sparseml.pytorch.utils import (
ClearMLLogger,
LambdaLogger,
LoggerManager,
PythonLogger,
Expand All @@ -45,6 +46,7 @@
or True
),
*([WANDBLogger()] if WANDBLogger.available() else []),
*([ClearMLLogger()] if ClearMLLogger.available() else []),
SparsificationGroupLogger(
lambda_func=lambda tag, value, values, step, wall_time, level: logging.info(
f"{tag}, {value}, {values}, {step}, {wall_time}, {level}"
Expand Down Expand Up @@ -79,12 +81,12 @@ def test_log_scalar(self, logger):

def test_log_scalars(self, logger):
logger.log_scalars("test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0})
logger.log_scalars("test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0}, 1)
logger.log_scalars("test-scalars-tag2", {"scalar1": 0.0, "scalar2": 1.0}, 1)
logger.log_scalars(
"test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0}, 2, time.time() - 1
"test-scalars-tag3", {"scalar1": 0.0, "scalar2": 1.0}, 2, time.time() - 1
)
logger.log_scalars(
"test-scalars-tag",
"test-scalars-tag4",
{"scalar1": 0.0, "scalar2": 1.0},
2,
time.time() - 1,
Expand Down

0 comments on commit 91246fe

Please sign in to comment.