Skip to content

Commit

Permalink
add:test
Browse files Browse the repository at this point in the history
  • Loading branch information
SAKURA-CAT committed Oct 20, 2024
1 parent 0369ae3 commit a5a20de
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 9 deletions.
17 changes: 8 additions & 9 deletions swankit/callback/models/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def __init__(
key_id: str,
key_name: str,
key_class: Literal["CUSTOM", "SYSTEM"],
section: str,
section_name: str,
section_sort: int,
chart: ChartType,
chart_type: ChartType,
chart_reference: Literal["step", "time"],
error: Optional[ParseErrorInfo] = None,
config: Optional[Dict] = None,
Expand All @@ -37,9 +37,9 @@ def __init__(
:param key_id: 当前实验下,列的唯一id,与保存路径等信息有关
:param key_name: key的别名
:param key_class: 列的类型,CUSTOM为自定义列,SYSTEM为系统生成列
:param section: 列的组
:param section_name: 列的组名
:param section_sort: 列在section中的参考排序,不代表实际排序
:param chart: 列对应的图表类型
:param chart_type: 列对应的图表类型
:param chart_reference: 这个列对应图表的参考系,step为步数,time为时间
:param error: 列的类型错误信息
:param config: 列的额外配置信息
Expand All @@ -49,10 +49,10 @@ def __init__(
self.key_name = key_name
self.key_class = key_class

self.section = section
self.section_name = section_name
self.section_sort = section_sort

self.chart = chart
self.chart_type = chart_type
self.chart_reference = chart_reference

self.error = error
Expand Down Expand Up @@ -125,10 +125,9 @@ def __init__(
self.metric_step = metric_step
self.metric_epoch = metric_epoch
_id = self.column_info.key_id
self.metric_path = None if self.is_error else os.path.join(swanlab_logdir, _id, metric_file_name)
self.summary_path = None if self.is_error else os.path.join(swanlab_logdir, _id, self.__SUMMARY_NAME)
self.metric_file_path = None if self.is_error else os.path.join(swanlab_logdir, _id, metric_file_name)
self.summary_file_path = None if self.is_error else os.path.join(swanlab_logdir, _id, self.__SUMMARY_NAME)
self.swanlab_media_dir = swanlab_media_dir
self.metric_buffers = metric_buffers
# 写入文件名称,对应上传时的文件名称:{key}/{文件名称},文件夹名称为key
if self.metric_buffers is not None:
for i, buffer in enumerate(self.metric_buffers):
Expand Down
80 changes: 80 additions & 0 deletions test/unit/callback/models/test_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from swankit.callback.models import key as K
from swankit.core import ChartType


def test_column_info():
c = K.ColumnInfo(
key="a/1",
key_id="b",
key_name="c",
key_class="SYSTEM",
section_name="e",
section_sort=1,
chart_type=ChartType.TEXT,
chart_reference="step",
error=None,
config=None,
)
assert c.got is None
assert c.key == "a/1"
assert c.key_id == "b"
assert c.key_name == "c"
assert c.key_class == "SYSTEM"
assert c.section_name == "e"
assert c.section_sort == 1
assert c.chart_type == ChartType.TEXT
assert c.chart_reference == "step"
assert c.error is None
assert c.config == {}
assert c.key_encode == "a%2F1"


def test_metric_info():
c = K.ColumnInfo(
key="a/1",
key_id="b",
key_name="c",
key_class="SYSTEM",
section_name="e",
section_sort=1,
chart_type=ChartType.TEXT,
chart_reference="step",
error=None,
config=None,
)

m = K.MetricInfo(
column_info=c,
metric={"data": 1},
metric_buffers=None,
metric_summary={"data": 1},
metric_file_name="1.log",
metric_step=1,
metric_epoch=1,
swanlab_logdir=".",
swanlab_media_dir=".",
)
assert m.column_info.got is None
assert m.column_info.key == "a/1"
assert m.column_info.key_id == "b"
assert m.column_info.key_name == "c"
assert m.column_info.key_class == "SYSTEM"
assert m.column_info.section_name == "e"
assert m.column_info.section_sort == 1
assert m.column_info.chart_type == ChartType.TEXT
assert m.column_info.chart_reference == "step"
assert m.column_info.error is None
assert m.column_info.config == {}
assert m.column_info.key_encode == "a%2F1"
assert m.column_info.got is None
assert m.column_info.expected is None
assert m.column_info.key_encode == "a%2F1"
assert m.column_info.key == "a/1"
assert m.column_info.key_id == "b"
assert m.metric == {"data": 1}
assert m.metric_buffers is None
assert m.metric_summary == {"data": 1}
assert m.metric_step == 1
assert m.metric_epoch == 1
assert m.swanlab_media_dir == "."
assert m.metric_file_path == f"./{c.key_id}/1.log"

0 comments on commit a5a20de

Please sign in to comment.