Skip to content

Commit

Permalink
bugfix/config_json_serializable (#671)
Browse files Browse the repository at this point in the history
* fix swanlab config
  • Loading branch information
Zeyi-Lin authored Aug 13, 2024
1 parent 36701be commit 62d15f9
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 59 deletions.
9 changes: 8 additions & 1 deletion swanlab/data/run/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def json_serializable(obj):
# 对于可变映射,递归调用此函数处理值,并将key转换为字典
elif isinstance(obj, MutableMapping):
return {str(key): json_serializable(value) for key, value in obj.items()}
raise TypeError(f"Object {obj} is not JSON serializable")

try:
return str(obj)
except Exception:
raise TypeError(f"Object: {obj} is not JSON serializable")


def third_party_config_process(data) -> dict:
Expand Down Expand Up @@ -84,16 +88,19 @@ def parse(config) -> dict:
"""
if config is None:
return {}

# 1. 第三方配置类型判断与转换
try:
return third_party_config_process(config)
except TypeError:
pass

# 2. 将config转换为可被json序列化的字典
try:
return json_serializable(config)
except TypeError: # noqa
pass

# 3. 尝试序列化,序列化成功直接返回
try:
return json.loads(json.dumps(config))
Expand Down
52 changes: 29 additions & 23 deletions swanlab/data/run/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ class SwanLabRun:
"""

def __init__(
self,
project_name: str = None,
experiment_name: str = None,
description: str = None,
run_config=None,
log_level: str = None,
suffix: str = None,
exp_num: int = None,
operator: SwanLabRunOperator = SwanLabRunOperator(),
self,
project_name: str = None,
experiment_name: str = None,
description: str = None,
run_config=None,
log_level: str = None,
suffix: str = None,
exp_num: int = None,
operator: SwanLabRunOperator = SwanLabRunOperator(),
):
"""
Initializing the SwanLabRun class involves configuring the settings and initiating other logging processes.
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(
# ---------------------------------- 初始化类内参数 ----------------------------------
self.__project_name = project_name
# 生成一个唯一的id,随机生成一个8位的16进制字符串,小写
_id = hex(random.randint(0, 2 ** 32 - 1))[2:].zfill(8)
_id = hex(random.randint(0, 2**32 - 1))[2:].zfill(8)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.__run_id = "run-{}-{}".format(timestamp, _id)
# 操作员初始化
Expand All @@ -101,6 +101,11 @@ def __init__(
# ---------------------------------- 初始化日志记录器 ----------------------------------
swanlog.level = self.__check_log_level(log_level)
# ---------------------------------- 初始化配置 ----------------------------------
# 如果config是以下几个类别之一,则抛出异常
if isinstance(run_config, (int, float, str, bool, list, tuple, set)):
raise TypeError(
f"config: {run_config} (type: {type(run_config)}) is not a json serialized dict (Surpport type is dict, MutableMapping, omegaconf.DictConfig, Argparse.Namespace), please check it"
)
global config
config.update(run_config)
setattr(config, "_SwanLabConfig__on_setter", self.__operator.on_runtime_info_update)
Expand All @@ -123,10 +128,11 @@ def _(state: SwanLabRunState):
# 执行__save,必须在on_run之后,因为on_run之前部分的信息还没完全初始化
getattr(config, "_SwanLabConfig__save")()
# 系统信息采集
self.__operator.on_runtime_info_update(RuntimeInfo(
requirements=get_requirements(),
metadata=get_system_info(get_package_version(), self.settings.log_dir)
))
self.__operator.on_runtime_info_update(
RuntimeInfo(
requirements=get_requirements(), metadata=get_system_info(get_package_version(), self.settings.log_dir)
)
)

@property
def operator(self) -> SwanLabRunOperator:
Expand Down Expand Up @@ -302,10 +308,10 @@ def log(self, data: dict, step: int = None):
v = DataWrapper(k, [v])
# 为List[MediaType]或者List[Line]类型,且长度大于0,且所有元素类型相同
elif (
isinstance(v, list)
and len(v) > 0
and all([isinstance(i, (Line, MediaType)) for i in v])
and all([i.__class__ == v[0].__class__ for i in v])
isinstance(v, list)
and len(v) > 0
and all([isinstance(i, (Line, MediaType)) for i in v])
and all([i.__class__ == v[0].__class__ for i in v])
):
v = DataWrapper(k, v)
else:
Expand All @@ -323,11 +329,11 @@ def __str__(self) -> str:
return self.__run_id

def __register_exp(
self,
experiment_name: str,
description: str = None,
suffix: str = None,
num: int = None,
self,
experiment_name: str,
description: str = None,
suffix: str = None,
num: int = None,
) -> SwanLabExp:
"""
注册实验,将实验配置写入数据库中,完成实验配置的初始化
Expand Down
22 changes: 11 additions & 11 deletions swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,17 @@ def login(api_key: str = None):


def init(
project: str = None,
workspace: str = None,
experiment_name: str = None,
description: str = None,
config: Union[dict, str] = None,
logdir: str = None,
suffix: Union[str, None, bool] = "default",
mode: Literal["disabled", "cloud", "local"] = None,
load: str = None,
public: bool = None,
**kwargs,
project: str = None,
workspace: str = None,
experiment_name: str = None,
description: str = None,
config: Union[dict, str] = None,
logdir: str = None,
suffix: Union[str, None, bool] = "default",
mode: Literal["disabled", "cloud", "local"] = None,
load: str = None,
public: bool = None,
**kwargs,
) -> SwanLabRun:
"""
Start a new run to track and log. Once you have called this function, you can use 'swanlab.log' to log data to
Expand Down
76 changes: 52 additions & 24 deletions test/unit/data/run/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __iter__(self):
config = parse(config_data)
assert config["inf"] == Line.inf
assert config["nan"] == Line.nan

# ---------------------------------- dataclass support ----------------------------------
@dataclass
class MyData:
Expand Down Expand Up @@ -294,12 +295,14 @@ def test_use_dict(self):
"""
正常流程,输入字典
"""
run = SwanLabRun(run_config={
"a": 1,
"b": "mnist",
"c/d": [1, 2, 3],
"e/f/h": {"a": 1, "b": {"c": 2}},
})
run = SwanLabRun(
run_config={
"a": 1,
"b": "mnist",
"c/d": [1, 2, 3],
"e/f/h": {"a": 1, "b": {"c": 2}},
}
)
config = run.config
_config = get_config()
assert config["a"] == _config["a"] == 1
Expand All @@ -310,12 +313,16 @@ def test_use_omegaconf(self):
"""
正常流程,输入OmegaConf
"""
run = SwanLabRun(run_config=omegaconf.OmegaConf.create({
"a": 1,
"b": "mnist",
"c/d": [1, 2, 3],
"e/f/h": {"a": 1, "b": {"c": 2}},
}))
run = SwanLabRun(
run_config=omegaconf.OmegaConf.create(
{
"a": 1,
"b": "mnist",
"c/d": [1, 2, 3],
"e/f/h": {"a": 1, "b": {"c": 2}},
}
)
)
config = run.config
_config = get_config()
assert config["a"] == _config["a"] == 1
Expand All @@ -337,12 +344,16 @@ def test_use_config(self):
"""
正常流程,输入SwanLabConfig
"""
run = SwanLabRun(run_config=SwanLabConfig({
"a": 1,
"b": "mnist",
"c": [1, 2, 3],
"e/f/h": {"a": 1, "b": {"c": 2}},
}))
run = SwanLabRun(
run_config=SwanLabConfig(
{
"a": 1,
"b": "mnist",
"c": [1, 2, 3],
"e/f/h": {"a": 1, "b": {"c": 2}},
}
)
)
config = run.config
_config = get_config()
assert config["a"] == _config["a"] == 1
Expand All @@ -353,14 +364,31 @@ def test_after_finish(self):
"""
测试在finish之后config的变化
"""
run = SwanLabRun(run_config={
"a": 1,
"b": "mnist",
"c/d": [1, 2, 3],
"e/f/h": {"a": 1, "b": {"c": 2}},
})
run = SwanLabRun(
run_config={
"a": 1,
"b": "mnist",
"c/d": [1, 2, 3],
"e/f/h": {"a": 1, "b": {"c": 2}},
}
)
run.finish()
config = run.config
_config = get_config()
assert len(config) == 4
assert len(_config) == 0

def test_error_config_input(self):
"""
测试错误的输入
"""
with pytest.raises(TypeError):
SwanLabRun(run_config=1)
with pytest.raises(TypeError):
SwanLabRun(run_config="1")
with pytest.raises(TypeError):
SwanLabRun(run_config=[1, 2, 3])
with pytest.raises(TypeError):
SwanLabRun(run_config=(1, 2, 3))
with pytest.raises(TypeError):
SwanLabRun(run_config=True)

0 comments on commit 62d15f9

Please sign in to comment.