diff --git a/src/sparseml/pytorch/utils/logger.py b/src/sparseml/pytorch/utils/logger.py
index 0e9a5bc0ff6..5fc5eb3adcb 100644
--- a/src/sparseml/pytorch/utils/logger.py
+++ b/src/sparseml/pytorch/utils/logger.py
@@ -45,11 +45,21 @@
     wandb = None
     wandb_err = err
 
+
+try:
+    from clearml import Task
+
+    clearml_err = None
+except Exception as err:
+    Task = None
+    clearml_err = err
+
 from sparseml.utils import ALL_TOKEN, create_dirs
 
 
 __all__ = [
     "BaseLogger",
+    "ClearMLLogger",
     "LambdaLogger",
     "PythonLogger",
     "TensorBoardLogger",
@@ -628,6 +638,113 @@ def save(
         return True
 
 
+class ClearMLLogger(LambdaLogger):
+    """
+    Modifier logger that handles outputting values to ClearML.
+
+    :param name: name given to the logger, used for identification;
+        defaults to clearml
+    :param enabled: True to log, False otherwise
+    :param project_name: name of the ClearML project to log the task under,
+        defaults to sparseml
+    :param task_name: name of the ClearML task, defaults to a timestamp
+    """
+
+    @staticmethod
+    def available() -> bool:
+        """
+        :return: True if clearml is available and installed, False otherwise
+        """
+        return not clearml_err
+
+    def __init__(
+        self,
+        name: str = "clearml",
+        enabled: bool = True,
+        project_name: str = "sparseml",
+        task_name: str = "",
+    ):
+        if task_name == "":
+            now = datetime.now()
+            task_name = now.strftime("%d-%m-%Y_%H.%M.%S")
+
+        self.task = Task.init(project_name=project_name, task_name=task_name)
+
+        super().__init__(
+            lambda_func=self.log_scalar,
+            name=name,
+            enabled=enabled,
+        )
+
+    def log_hyperparams(
+        self,
+        params: Dict,
+        level: Optional[int] = None,
+    ) -> bool:
+        """
+        :param params: each key-value pair in the dictionary is the name of a
+            hyperparameter and its corresponding value
+        :param level: level to log at, defaults to None
+        :return: True if logged, False otherwise
+        """
+        if not self.enabled:
+            return False
+
+        self.task.connect(params)
+        return True
+
+    def log_scalar(
+        self,
+        tag: str,
+        value: float,
+        step: Optional[int] = None,
+        wall_time: Optional[float] = None,
+        level: Optional[int] = None,
+    ) -> bool:
+        """
+        :param tag: identifying tag to log the value with
+        :param value: value to save
+        :param step: global step for when the value was taken
+        :param wall_time: global wall time for when the value was taken,
+            defaults to time.time()
+        :param level: level to log at, defaults to None
+        :return: True if logged, False otherwise
+        """
+        # report_single_value tracks the latest value per tag; step and
+        # wall_time are accepted for API compatibility but not forwarded
+        logger = self.task.get_logger()
+        logger.report_single_value(name=tag, value=value)
+        return True
+
+    def log_scalars(
+        self,
+        tag: str,
+        values: Dict[str, float],
+        step: Optional[int] = None,
+        wall_time: Optional[float] = None,
+        level: Optional[int] = None,
+    ) -> bool:
+        """
+        :param tag: identifying tag to log the values with
+        :param values: values to save
+        :param step: global step for when the values were taken
+        :param wall_time: global wall time for when the values were taken,
+            defaults to time.time()
+        :param level: level to log at, defaults to None
+        :return: True if logged, False otherwise
+        """
+        # prefix each key with the shared tag so related values group together
+        for k, v in values.items():
+            self.log_scalar(
+                tag=f"{tag}.{k}",
+                value=v,
+                step=step,
+                wall_time=wall_time,
+                level=level,
+            )
+        return True
+
+
 class SparsificationGroupLogger(BaseLogger):
     """
     Modifier logger that handles outputting values to other supported systems.
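A minimal usage sketch of the new logger (illustrative, not part of the diff; assumes the clearml package is installed and credentials are configured via clearml.conf or CLEARML_* environment variables; the project and task names below are made up):

    from sparseml.pytorch.utils import ClearMLLogger

    # constructing the logger calls Task.init(), which registers a task
    # under the given project in the ClearML UI
    logger = ClearMLLogger(project_name="sparseml", task_name="demo-run")

    # hyperparameters are attached to the task via Task.connect()
    logger.log_hyperparams({"lr": 0.01, "epochs": 10})

    # single values, and grouped values reported as "<tag>.<key>"
    logger.log_scalar("loss", 0.42, step=0)
    logger.log_scalars("eval", {"top1": 0.76, "top5": 0.93}, step=0)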
diff --git a/tests/sparseml/pytorch/utils/test_logger.py b/tests/sparseml/pytorch/utils/test_logger.py
index 7cceeff3017..82510aea47a 100644
--- a/tests/sparseml/pytorch/utils/test_logger.py
+++ b/tests/sparseml/pytorch/utils/test_logger.py
@@ -20,6 +20,7 @@
 import pytest
 
 from sparseml.pytorch.utils import (
+    ClearMLLogger,
     LambdaLogger,
     LoggerManager,
     PythonLogger,
@@ -45,6 +46,7 @@
         or True
     ),
     *([WANDBLogger()] if WANDBLogger.available() else []),
+    *([ClearMLLogger()] if ClearMLLogger.available() else []),
     SparsificationGroupLogger(
         lambda_func=lambda tag, value, values, step, wall_time, level: logging.info(
             f"{tag}, {value}, {values}, {step}, {wall_time}, {level}"
@@ -79,12 +81,12 @@ def test_log_scalar(self, logger):
 
     def test_log_scalars(self, logger):
         logger.log_scalars("test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0})
-        logger.log_scalars("test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0}, 1)
+        logger.log_scalars("test-scalars-tag2", {"scalar1": 0.0, "scalar2": 1.0}, 1)
         logger.log_scalars(
-            "test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0}, 2, time.time() - 1
+            "test-scalars-tag3", {"scalar1": 0.0, "scalar2": 1.0}, 2, time.time() - 1
         )
         logger.log_scalars(
-            "test-scalars-tag",
+            "test-scalars-tag4",
             {"scalar1": 0.0, "scalar2": 1.0},
             2,
             time.time() - 1,
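Note for local verification (illustrative, not part of the diff): ClearMLLogger only joins the parametrized LOGGERS list when ClearMLLogger.available() returns True, so the suite still passes on machines without clearml installed. A hypothetical runner using the standard pytest.main API:

    import pytest

    # -k narrows the run to the scalar-logging tests touched above; -v shows
    # which logger instances were actually parametrized in
    pytest.main(["tests/sparseml/pytorch/utils/test_logger.py", "-k", "scalar", "-v"])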