diff --git a/swanlab/sync/wandb.py b/swanlab/sync/wandb.py index 049372a3..bb496e0d 100644 --- a/swanlab/sync/wandb.py +++ b/swanlab/sync/wandb.py @@ -1,23 +1,33 @@ import swanlab -def sync_wandb(): +def sync_wandb(mode:str="cloud", wandb_run:bool=True): """ - sync wandb with swanlab + sync wandb with swanlab, 暂时不支持log非标量类型 + + - mode: "cloud", "local" or "disabled". https://docs.swanlab.cn/api/py-init.html + - wandb_run: 如果此参数设置为False,则不会将数据上传到wandb,等同于设置wandb.init(mode="offline")。 usecase: ```python + import wandb + import random import swanlab + swanlab.sync_wandb() - + # swanlab.init(project="sync_wandb") + wandb.init( project="test", config={"a": 1, "b": 2}, name="test", ) - for epoch in range(10): + epochs = 10 + offset = random.random() / 5 + for epoch in range(2, epochs): acc = 1 - 2 ** -epoch - random.random() / epoch - offset loss = 2 ** -epoch + random.random() / epoch + offset + wandb.log({"acc": acc, "loss": loss}) ``` """ @@ -30,6 +40,7 @@ def sync_wandb(): original_init = wandb.init original_log = wandb_sdk.wandb_run.Run.log original_finish = wandb_sdk.finish + original_config_update = wandb_sdk.wandb_config.Config.update def patched_init(*args, **kwargs): project = kwargs.get('project', None) @@ -40,11 +51,25 @@ def patched_init(*args, **kwargs): swanlab.init( project=project, experiment_name=name, - config=config) + config=config, + mode=mode) else: swanlab.config.update(config) - return original_init(*args, **kwargs) + if wandb_run is False: + kwargs["mode"] = "offline" + return original_init(*args, **kwargs) + else: + return original_init(*args, **kwargs) + + def patched_config_update(*args, **kwargs): + try: + config= args[1] + except: + config= None + if config is not None: + swanlab.config.update(config) + return original_config_update(*args, **kwargs) def patched_log(*args, **kwargs): data = args[1] @@ -66,4 +91,5 @@ def patched_finish(*args, **kwargs): wandb.init = patched_init wandb_sdk.wandb_run.Run.log = patched_log - wandb_sdk.wandb_run.Run.finish = patched_finish \ No newline at end of file + wandb_sdk.wandb_run.Run.finish = patched_finish + wandb_sdk.wandb_config.Config.update = patched_config_update \ No newline at end of file