From 62d15f96cba0d7cfa529b9727c57a6b438e88c37 Mon Sep 17 00:00:00 2001 From: Ze-Yi LIN <58305964+Zeyi-Lin@users.noreply.github.com> Date: Tue, 13 Aug 2024 18:02:01 +0800 Subject: [PATCH] bugfix/config_json_serializable (#671) * fix swanlab config --- swanlab/data/run/config.py | 9 +++- swanlab/data/run/main.py | 52 +++++++++++---------- swanlab/data/sdk.py | 22 ++++----- test/unit/data/run/test_config.py | 76 +++++++++++++++++++++---------- 4 files changed, 100 insertions(+), 59 deletions(-) diff --git a/swanlab/data/run/config.py b/swanlab/data/run/config.py index 772781fb1..9ca171d67 100644 --- a/swanlab/data/run/config.py +++ b/swanlab/data/run/config.py @@ -50,7 +50,11 @@ def json_serializable(obj): # 对于可变映射,递归调用此函数处理值,并将key转换为字典 elif isinstance(obj, MutableMapping): return {str(key): json_serializable(value) for key, value in obj.items()} - raise TypeError(f"Object {obj} is not JSON serializable") + + try: + return str(obj) + except Exception: + raise TypeError(f"Object: {obj} is not JSON serializable") def third_party_config_process(data) -> dict: @@ -84,16 +88,19 @@ def parse(config) -> dict: """ if config is None: return {} + # 1. 第三方配置类型判断与转换 try: return third_party_config_process(config) except TypeError: pass + # 2. 将config转换为可被json序列化的字典 try: return json_serializable(config) except TypeError: # noqa pass + # 3. 尝试序列化,序列化成功直接返回 try: return json.loads(json.dumps(config)) diff --git a/swanlab/data/run/main.py b/swanlab/data/run/main.py index 9a91e82f4..5610228a6 100644 --- a/swanlab/data/run/main.py +++ b/swanlab/data/run/main.py @@ -42,15 +42,15 @@ class SwanLabRun: """ def __init__( - self, - project_name: str = None, - experiment_name: str = None, - description: str = None, - run_config=None, - log_level: str = None, - suffix: str = None, - exp_num: int = None, - operator: SwanLabRunOperator = SwanLabRunOperator(), + self, + project_name: str = None, + experiment_name: str = None, + description: str = None, + run_config=None, + log_level: str = None, + suffix: str = None, + exp_num: int = None, + operator: SwanLabRunOperator = SwanLabRunOperator(), ): """ Initializing the SwanLabRun class involves configuring the settings and initiating other logging processes. @@ -86,7 +86,7 @@ def __init__( # ---------------------------------- 初始化类内参数 ---------------------------------- self.__project_name = project_name # 生成一个唯一的id,随机生成一个8位的16进制字符串,小写 - _id = hex(random.randint(0, 2 ** 32 - 1))[2:].zfill(8) + _id = hex(random.randint(0, 2**32 - 1))[2:].zfill(8) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") self.__run_id = "run-{}-{}".format(timestamp, _id) # 操作员初始化 @@ -101,6 +101,11 @@ def __init__( # ---------------------------------- 初始化日志记录器 ---------------------------------- swanlog.level = self.__check_log_level(log_level) # ---------------------------------- 初始化配置 ---------------------------------- + # 如果config是以下几个类别之一,则抛出异常 + if isinstance(run_config, (int, float, str, bool, list, tuple, set)): + raise TypeError( + f"config: {run_config} (type: {type(run_config)}) is not a json serialized dict (Surpport type is dict, MutableMapping, omegaconf.DictConfig, Argparse.Namespace), please check it" + ) global config config.update(run_config) setattr(config, "_SwanLabConfig__on_setter", self.__operator.on_runtime_info_update) @@ -123,10 +128,11 @@ def _(state: SwanLabRunState): # 执行__save,必须在on_run之后,因为on_run之前部分的信息还没完全初始化 getattr(config, "_SwanLabConfig__save")() # 系统信息采集 - self.__operator.on_runtime_info_update(RuntimeInfo( - requirements=get_requirements(), - metadata=get_system_info(get_package_version(), self.settings.log_dir) - )) + self.__operator.on_runtime_info_update( + RuntimeInfo( + requirements=get_requirements(), metadata=get_system_info(get_package_version(), self.settings.log_dir) + ) + ) @property def operator(self) -> SwanLabRunOperator: @@ -302,10 +308,10 @@ def log(self, data: dict, step: int = None): v = DataWrapper(k, [v]) # 为List[MediaType]或者List[Line]类型,且长度大于0,且所有元素类型相同 elif ( - isinstance(v, list) - and len(v) > 0 - and all([isinstance(i, (Line, MediaType)) for i in v]) - and all([i.__class__ == v[0].__class__ for i in v]) + isinstance(v, list) + and len(v) > 0 + and all([isinstance(i, (Line, MediaType)) for i in v]) + and all([i.__class__ == v[0].__class__ for i in v]) ): v = DataWrapper(k, v) else: @@ -323,11 +329,11 @@ def __str__(self) -> str: return self.__run_id def __register_exp( - self, - experiment_name: str, - description: str = None, - suffix: str = None, - num: int = None, + self, + experiment_name: str, + description: str = None, + suffix: str = None, + num: int = None, ) -> SwanLabExp: """ 注册实验,将实验配置写入数据库中,完成实验配置的初始化 diff --git a/swanlab/data/sdk.py b/swanlab/data/sdk.py index e0a9a2a24..c632e49e5 100644 --- a/swanlab/data/sdk.py +++ b/swanlab/data/sdk.py @@ -70,17 +70,17 @@ def login(api_key: str = None): def init( - project: str = None, - workspace: str = None, - experiment_name: str = None, - description: str = None, - config: Union[dict, str] = None, - logdir: str = None, - suffix: Union[str, None, bool] = "default", - mode: Literal["disabled", "cloud", "local"] = None, - load: str = None, - public: bool = None, - **kwargs, + project: str = None, + workspace: str = None, + experiment_name: str = None, + description: str = None, + config: Union[dict, str] = None, + logdir: str = None, + suffix: Union[str, None, bool] = "default", + mode: Literal["disabled", "cloud", "local"] = None, + load: str = None, + public: bool = None, + **kwargs, ) -> SwanLabRun: """ Start a new run to track and log. Once you have called this function, you can use 'swanlab.log' to log data to diff --git a/test/unit/data/run/test_config.py b/test/unit/data/run/test_config.py index 40051de90..2b2159d34 100644 --- a/test/unit/data/run/test_config.py +++ b/test/unit/data/run/test_config.py @@ -68,6 +68,7 @@ def __iter__(self): config = parse(config_data) assert config["inf"] == Line.inf assert config["nan"] == Line.nan + # ---------------------------------- dataclass support ---------------------------------- @dataclass class MyData: @@ -294,12 +295,14 @@ def test_use_dict(self): """ 正常流程,输入字典 """ - run = SwanLabRun(run_config={ - "a": 1, - "b": "mnist", - "c/d": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - }) + run = SwanLabRun( + run_config={ + "a": 1, + "b": "mnist", + "c/d": [1, 2, 3], + "e/f/h": {"a": 1, "b": {"c": 2}}, + } + ) config = run.config _config = get_config() assert config["a"] == _config["a"] == 1 @@ -310,12 +313,16 @@ def test_use_omegaconf(self): """ 正常流程,输入OmegaConf """ - run = SwanLabRun(run_config=omegaconf.OmegaConf.create({ - "a": 1, - "b": "mnist", - "c/d": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - })) + run = SwanLabRun( + run_config=omegaconf.OmegaConf.create( + { + "a": 1, + "b": "mnist", + "c/d": [1, 2, 3], + "e/f/h": {"a": 1, "b": {"c": 2}}, + } + ) + ) config = run.config _config = get_config() assert config["a"] == _config["a"] == 1 @@ -337,12 +344,16 @@ def test_use_config(self): """ 正常流程,输入SwanLabConfig """ - run = SwanLabRun(run_config=SwanLabConfig({ - "a": 1, - "b": "mnist", - "c": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - })) + run = SwanLabRun( + run_config=SwanLabConfig( + { + "a": 1, + "b": "mnist", + "c": [1, 2, 3], + "e/f/h": {"a": 1, "b": {"c": 2}}, + } + ) + ) config = run.config _config = get_config() assert config["a"] == _config["a"] == 1 @@ -353,14 +364,31 @@ def test_after_finish(self): """ 测试在finish之后config的变化 """ - run = SwanLabRun(run_config={ - "a": 1, - "b": "mnist", - "c/d": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - }) + run = SwanLabRun( + run_config={ + "a": 1, + "b": "mnist", + "c/d": [1, 2, 3], + "e/f/h": {"a": 1, "b": {"c": 2}}, + } + ) run.finish() config = run.config _config = get_config() assert len(config) == 4 assert len(_config) == 0 + + def test_error_config_input(self): + """ + 测试错误的输入 + """ + with pytest.raises(TypeError): + SwanLabRun(run_config=1) + with pytest.raises(TypeError): + SwanLabRun(run_config="1") + with pytest.raises(TypeError): + SwanLabRun(run_config=[1, 2, 3]) + with pytest.raises(TypeError): + SwanLabRun(run_config=(1, 2, 3)) + with pytest.raises(TypeError): + SwanLabRun(run_config=True)