diff --git a/swanlab/data/run/exp.py b/swanlab/data/run/exp.py index 14d17e6e..28d64a95 100644 --- a/swanlab/data/run/exp.py +++ b/swanlab/data/run/exp.py @@ -1,6 +1,6 @@ from swanlab.data.modules import DataWrapper, Line from swanlab.log import swanlog -from typing import Dict, Optional +from typing import Dict, Optional, Callable, Any from swankit.env import create_time from swankit.callback import MetricInfo, ColumnInfo from .helper import SwanLabRunOperator @@ -31,8 +31,8 @@ def __init__(self, settings: SwanLabSharedSettings, operator: SwanLabRunOperator # TODO 操作员不传递给实验 self.__operator = operator - def add(self, key: str, data: DataWrapper, step: int = None) -> MetricInfo: - """记录一条新的tag数据 + def __add(self, key: str, data: DataWrapper, step: int = None) -> MetricInfo: + """记录一条新的key数据 Parameters ---------- @@ -85,6 +85,23 @@ def add(self, key: str, data: DataWrapper, step: int = None) -> MetricInfo: key_info.media_dir = self.settings.media_dir return key_info + def add(self, key: str, data: DataWrapper, step: int = None) -> MetricInfo: + """记录一条新的key数据 + Parameters + ---------- + key : str + key名称 + data : DataWrapper + 包装后的数据,用于数据解析 + + step : int, optional + 步数,如果不传则默认当前步数为'已添加数据数量+1' + 在log函数中已经做了处理,此处不需要考虑数值类型等情况 + """ + m = self.__add(key, data, step) + self.__operator.on_metric_create(m) + return m + def warn_type_error(self, key: str): """警告类型错误 执行此方法时需保证key已经存在 diff --git a/swanlab/data/run/main.py b/swanlab/data/run/main.py index b6792c90..8f48e1df 100644 --- a/swanlab/data/run/main.py +++ b/swanlab/data/run/main.py @@ -323,7 +323,6 @@ def log(self, data: dict, step: int = None): # 数据类型的检查将在创建chart配置的时候完成,因为数据类型错误并不会影响实验进行 metric_info = self.__exp.add(key=k, data=v, step=step) print(f'Corr exp = {self.__exp}') - self.__operator.on_metric_create(metric_info) log_return[metric_info.key] = metric_info return log_return