From ffaaf9749cf7ab0c3263c6dad4b3f640c8a651d7 Mon Sep 17 00:00:00 2001
From: Ze-Yi LIN <58305964+Zeyi-Lin@users.noreply.github.com>
Date: Fri, 17 Jan 2025 13:52:08 +0800
Subject: [PATCH] feat: sync wandb (#785)

Co-authored-by: KAAANG <79990647+SAKURA-CAT@users.noreply.github.com>
---
 swanlab/__init__.py      |  2 +
 swanlab/sync/__init__.py |  3 ++
 swanlab/sync/wandb.py    | 94 ++++++++++++++++++++++++++++++++++++++++
 test/wandb_sync.py       | 19 ++++++++
 4 files changed, 118 insertions(+)
 create mode 100644 swanlab/sync/__init__.py
 create mode 100644 swanlab/sync/wandb.py
 create mode 100644 test/wandb_sync.py

diff --git a/swanlab/__init__.py b/swanlab/__init__.py
index 3cdaea31a..d5cb43339 100755
--- a/swanlab/__init__.py
+++ b/swanlab/__init__.py
@@ -16,6 +16,7 @@ from .data.run.main import config
 
 from .package import get_package_version
 from .env import SwanLabEnv
+from .sync import sync_wandb
 
 # 设置默认环境变量
 SwanLabEnv.set_default()
@@ -38,4 +39,5 @@
     "get_config",
     "config",
     "__version__",
+    "sync_wandb",
 ]
diff --git a/swanlab/sync/__init__.py b/swanlab/sync/__init__.py
new file mode 100644
index 000000000..481f91b98
--- /dev/null
+++ b/swanlab/sync/__init__.py
@@ -0,0 +1,3 @@
+from .wandb import sync_wandb
+
+__all__ = ["sync_wandb"]
diff --git a/swanlab/sync/wandb.py b/swanlab/sync/wandb.py
new file mode 100644
index 000000000..049372a37
--- /dev/null
+++ b/swanlab/sync/wandb.py
@@ -0,0 +1,94 @@
+import swanlab
+
+
+def sync_wandb():
+    """
+    Mirror wandb runs into swanlab by monkey-patching the wandb SDK.
+
+    After this call, ``wandb.init`` also starts (or updates the config of) a
+    swanlab run, ``Run.log`` forwards int/float/bool/str values to
+    ``swanlab.log`` and ``Run.finish`` also calls ``swanlab.finish``. Calling
+    it more than once is a no-op, so values are never forwarded twice.
+
+    usecase:
+    ```python
+    import random
+
+    import swanlab
+    import wandb
+
+    swanlab.sync_wandb()
+
+    wandb.init(
+        project="test",
+        config={"a": 1, "b": 2},
+        name="test",
+    )
+
+    offset = random.random() / 5
+    for epoch in range(2, 10):
+        acc = 1 - 2 ** -epoch - random.random() / epoch - offset
+        loss = 2 ** -epoch + random.random() / epoch + offset
+        wandb.log({"acc": acc, "loss": loss})
+
+    wandb.finish()
+    ```
+
+    :raises ImportError: if wandb is not installed.
+    """
+    try:
+        import wandb
+        from wandb import sdk as wandb_sdk
+    except ImportError as err:
+        raise ImportError("please install wandb first, command: `pip install wandb`") from err
+
+    run_cls = wandb_sdk.wandb_run.Run
+
+    # Already patched: wrapping again would forward every value twice.
+    if getattr(wandb.init, "_swanlab_patched", False):
+        return
+
+    original_init = wandb.init
+    original_log = run_cls.log
+    # Capture the method that is actually replaced below. The module-level
+    # ``wandb.sdk.finish`` takes no run argument, so wrapping it here would
+    # pass the run instance on as ``exit_code``.
+    original_finish = run_cls.finish
+
+    def patched_init(*args, **kwargs):
+        project = kwargs.get('project', None)
+        name = kwargs.get('name', None)
+        config = kwargs.get('config', None)
+
+        if swanlab.data.get_run() is None:
+            swanlab.init(
+                project=project,
+                experiment_name=name,
+                config=config)
+        elif config is not None:
+            swanlab.config.update(config)
+
+        return original_init(*args, **kwargs)
+
+    def patched_log(self, data, step=None, *args, **kwargs):
+        # ``data`` and ``step`` are named so that keyword calls such as
+        # ``run.log(data=..., step=...)`` are mirrored as well.
+        filtered_data = {
+            key: value
+            for key, value in data.items()
+            if isinstance(value, (int, float, bool, str))
+        }
+
+        swanlab.log(data=filtered_data, step=step)
+
+        return original_log(self, data, step, *args, **kwargs)
+
+    def patched_finish(*args, **kwargs):
+        swanlab.finish()
+        return original_finish(*args, **kwargs)
+
+    patched_init._swanlab_patched = True
+
+    wandb.init = patched_init
+    run_cls.log = patched_log
+    run_cls.finish = patched_finish
\ No newline at end of file
diff --git a/test/wandb_sync.py b/test/wandb_sync.py
new file mode 100644
index 000000000..4e9f6eada
--- /dev/null
+++ b/test/wandb_sync.py
@@ -0,0 +1,19 @@
+import wandb
+import random
+import swanlab
+
+swanlab.sync_wandb()
+
+wandb.init(
+    project="test",
+    config={"a": 1, "b": 2},
+    name="test",
+    )
+
+epochs = 10
+offset = random.random() / 5
+for epoch in range(2, epochs):
+    acc = 1 - 2 ** -epoch - random.random() / epoch - offset
+    loss = 2 ** -epoch + random.random() / epoch + offset
+
+    wandb.log({"acc": acc, "loss": loss})
\ No newline at end of file