diff --git a/swankit/callback/models/key.py b/swankit/callback/models/key.py index 0805ed3..01be7cd 100644 --- a/swankit/callback/models/key.py +++ b/swankit/callback/models/key.py @@ -93,21 +93,20 @@ class MetricInfo: def __init__( self, - error: Optional[ParseErrorInfo], column_info: ColumnInfo, - metric: Union[Dict, None] = None, - metric_buffers: List[MediaBuffer] = None, - metric_summary: Union[Dict, None] = None, - metric_step: int = None, - metric_epoch: int = None, - metric_file_name: str = None, - swanlab_logdir: str = None, - swanlab_media_dir: str = None, + metric: Optional[Dict], + metric_buffers: Optional[List[MediaBuffer]], + metric_summary: Optional[Dict], + metric_step: Optional[int], + metric_epoch: Optional[int], + metric_file_name: Optional[str], + swanlab_logdir: Optional[str], + swanlab_media_dir: Optional[str], + error: Optional[ParseErrorInfo] = None, ): """ 生成的指标信息对象 :param column_info: 此指标对应的列信息 - :param error: 创建此指标时的错误信息 :param metric: 此指标的数据 :param metric_buffers: 此指标的媒体数据,如果为None,表示没有媒体数据 :param metric_summary: 此指标的摘要信息 @@ -116,6 +115,7 @@ def __init__( :param metric_file_name: 此指标的文件名 :param swanlab_logdir: swanlab在本次实验的log文件夹路径 :param swanlab_media_dir: swanlab在本次实验的media文件夹路径 + :param error: 创建此指标时的错误信息 """ self.error = error self.column_info = column_info @@ -159,3 +159,24 @@ def data(self) -> Union[Dict, None]: if self.is_error: return None return self.metric["data"] + + +class ErrorMetricInfo(MetricInfo): + def __init__(self, column_info: ColumnInfo, error: ParseErrorInfo): + """ + 错误的指标信息,简化输入参数 + :param column_info: 此指标对应的列信息 + :param error: 创建此指标时的错误信息 + """ + super().__init__( + column_info, + None, + None, + None, + None, + None, + None, + None, + None, + error, + )