From 14b4796d366b6590077e805addd51fa352119069 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Thu, 16 Jan 2025 22:42:23 +0800 Subject: [PATCH 1/6] integration wandb --- swanlab/integration/wandb.py | 73 ++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 swanlab/integration/wandb.py diff --git a/swanlab/integration/wandb.py b/swanlab/integration/wandb.py new file mode 100644 index 00000000..f5e6421c --- /dev/null +++ b/swanlab/integration/wandb.py @@ -0,0 +1,73 @@ +""" +swanlab.init(sync_wandb=True) +""" +import swanlab +try: + import wandb + from wandb import sdk as wandb_sdk +except ImportError: + raise ImportError("please install wandb first, command: `pip install wandb`") + +def sync_wandb_patches(): + original_init = wandb.init + original_log = wandb_sdk.wandb_run.Run.log + original_finish = wandb_sdk.wandb_run.Run.finish + + def patched_init(*args, **kwargs): + project = kwargs.get('project', None) + name = kwargs.get('name', None) + config = kwargs.get('config', None) + + swanlab.init( + project=project, + experiment_name=name, + config=config) + + return original_init(*args, **kwargs) + + def patched_log(*args, **kwargs): + data = args[1] + step = kwargs.get('step', None) + + # 过滤掉非标量类型 + filtered_data = {} + for key, value in data.items(): + if isinstance(value, (int, float, bool, str)): + filtered_data[key] = value + + print("Data:", filtered_data) + print("Step:", step) + + swanlab.log(data=filtered_data, step=step) + + return original_log(*args, **kwargs) + + def patched_finish(*args, **kwargs): + swanlab.finish() + return original_finish(*args, **kwargs) + + wandb.init = patched_init + wandb_sdk.wandb_run.Run.log = patched_log + wandb_sdk.wandb_run.Run.finish = patched_finish + + +if __name__ == "__main__": + import random + + # 在使用前调用apply_patches + sync_wandb_patches() + + wandb.init( + project="test", + config={"a": 1, "b": 2}, + name="test", + ) + + epochs = 10 + offset = random.random() / 5 + for epoch in range(2, epochs): + acc = 1 - 2 ** -epoch - random.random() / epoch - offset + loss = 2 ** -epoch + random.random() / epoch + offset + + # 记录训练指标 + wandb.log({"acc": acc, "loss": loss}) From 2539e4a1cb263cb7394121796aa0b7b39f8b5d0e Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Thu, 16 Jan 2025 22:59:49 +0800 Subject: [PATCH 2/6] make perfect --- swanlab/__init__.py | 2 ++ swanlab/sync/__init__.py | 3 +++ swanlab/{integration => sync}/wandb.py | 22 +++++++++++----------- 3 files changed, 16 insertions(+), 11 deletions(-) create mode 100644 swanlab/sync/__init__.py rename swanlab/{integration => sync}/wandb.py (83%) diff --git a/swanlab/__init__.py b/swanlab/__init__.py index 3cdaea31..d5cb4333 100755 --- a/swanlab/__init__.py +++ b/swanlab/__init__.py @@ -16,6 +16,7 @@ from .data.run.main import config from .package import get_package_version from .env import SwanLabEnv +from .sync import sync_wandb # 设置默认环境变量 SwanLabEnv.set_default() @@ -38,4 +39,5 @@ "get_config", "config", "__version__", + "sync_wandb", ] diff --git a/swanlab/sync/__init__.py b/swanlab/sync/__init__.py new file mode 100644 index 00000000..481f91b9 --- /dev/null +++ b/swanlab/sync/__init__.py @@ -0,0 +1,3 @@ +from .wandb import sync_wandb + +__all__ = ["sync_wandb"] diff --git a/swanlab/integration/wandb.py b/swanlab/sync/wandb.py similarity index 83% rename from swanlab/integration/wandb.py rename to swanlab/sync/wandb.py index f5e6421c..21e458ca 100644 --- a/swanlab/integration/wandb.py +++ b/swanlab/sync/wandb.py @@ -8,20 +8,23 @@ except ImportError: raise ImportError("please install wandb first, command: `pip install wandb`") -def sync_wandb_patches(): +def sync_wandb(): original_init = wandb.init original_log = wandb_sdk.wandb_run.Run.log - original_finish = wandb_sdk.wandb_run.Run.finish + original_finish = wandb_sdk.finish def patched_init(*args, **kwargs): project = kwargs.get('project', None) name = kwargs.get('name', None) config = kwargs.get('config', None) - swanlab.init( - project=project, - experiment_name=name, - config=config) + if swanlab.data.get_run() is None: + swanlab.init( + project=project, + experiment_name=name, + config=config) + else: + swanlab.config.update(config) return original_init(*args, **kwargs) @@ -35,9 +38,6 @@ def patched_log(*args, **kwargs): if isinstance(value, (int, float, bool, str)): filtered_data[key] = value - print("Data:", filtered_data) - print("Step:", step) - swanlab.log(data=filtered_data, step=step) return original_log(*args, **kwargs) @@ -54,8 +54,8 @@ def patched_finish(*args, **kwargs): if __name__ == "__main__": import random - # 在使用前调用apply_patches - sync_wandb_patches() + # 在使用前调用sync_wandb + sync_wandb() wandb.init( project="test", From 6833ba5917bf3ad66f39a24e58888d71d393d606 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Thu, 16 Jan 2025 23:11:14 +0800 Subject: [PATCH 3/6] fix import --- swanlab/sync/wandb.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/swanlab/sync/wandb.py b/swanlab/sync/wandb.py index 21e458ca..d8ad8ef2 100644 --- a/swanlab/sync/wandb.py +++ b/swanlab/sync/wandb.py @@ -1,14 +1,34 @@ """ -swanlab.init(sync_wandb=True) +import wandb +import random +import swanlab + +swanlab.sync_wandb() +# swanlab.init(project="sync_wandb") + +wandb.init( + project="test", + config={"a": 1, "b": 2}, + name="test", + ) + +epochs = 10 +offset = random.random() / 5 +for epoch in range(2, epochs): + acc = 1 - 2 ** -epoch - random.random() / epoch - offset + loss = 2 ** -epoch + random.random() / epoch + offset + + wandb.log({"acc": acc, "loss": loss}) """ import swanlab -try: - import wandb - from wandb import sdk as wandb_sdk -except ImportError: - raise ImportError("please install wandb first, command: `pip install wandb`") def sync_wandb(): + try: + import wandb + from wandb import sdk as wandb_sdk + except ImportError: + raise ImportError("please install wandb first, command: `pip install wandb`") + original_init = wandb.init original_log = wandb_sdk.wandb_run.Run.log original_finish = wandb_sdk.finish @@ -53,6 +73,7 @@ def patched_finish(*args, **kwargs): if __name__ == "__main__": import random + import wandb # 在使用前调用sync_wandb sync_wandb() From 17b01dc4df7e1a214692bb64ea11a563182e8c44 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Fri, 17 Jan 2025 12:52:14 +0800 Subject: [PATCH 4/6] del annotation --- swanlab/sync/wandb.py | 47 +------------------------------------------ 1 file changed, 1 insertion(+), 46 deletions(-) diff --git a/swanlab/sync/wandb.py b/swanlab/sync/wandb.py index d8ad8ef2..1791faf3 100644 --- a/swanlab/sync/wandb.py +++ b/swanlab/sync/wandb.py @@ -1,25 +1,3 @@ -""" -import wandb -import random -import swanlab - -swanlab.sync_wandb() -# swanlab.init(project="sync_wandb") - -wandb.init( - project="test", - config={"a": 1, "b": 2}, - name="test", - ) - -epochs = 10 -offset = random.random() / 5 -for epoch in range(2, epochs): - acc = 1 - 2 ** -epoch - random.random() / epoch - offset - loss = 2 ** -epoch + random.random() / epoch + offset - - wandb.log({"acc": acc, "loss": loss}) -""" import swanlab def sync_wandb(): @@ -68,27 +46,4 @@ def patched_finish(*args, **kwargs): wandb.init = patched_init wandb_sdk.wandb_run.Run.log = patched_log - wandb_sdk.wandb_run.Run.finish = patched_finish - - -if __name__ == "__main__": - import random - import wandb - - # 在使用前调用sync_wandb - sync_wandb() - - wandb.init( - project="test", - config={"a": 1, "b": 2}, - name="test", - ) - - epochs = 10 - offset = random.random() / 5 - for epoch in range(2, epochs): - acc = 1 - 2 ** -epoch - random.random() / epoch - offset - loss = 2 ** -epoch + random.random() / epoch + offset - - # 记录训练指标 - wandb.log({"acc": acc, "loss": loss}) + wandb_sdk.wandb_run.Run.finish = patched_finish \ No newline at end of file From fa369e3b8d6c1d014635f28e5962de1f34017d9b Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Fri, 17 Jan 2025 12:55:10 +0800 Subject: [PATCH 5/6] add test --- swanlab/sync/wandb.py | 20 ++++++++++++++++++++ test/integration/wandb/wandb_sync.py | 19 +++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 test/integration/wandb/wandb_sync.py diff --git a/swanlab/sync/wandb.py b/swanlab/sync/wandb.py index 1791faf3..049372a3 100644 --- a/swanlab/sync/wandb.py +++ b/swanlab/sync/wandb.py @@ -1,6 +1,26 @@ import swanlab def sync_wandb(): + """ + sync wandb with swanlab + + usecase: + ```python + import swanlab + swanlab.sync_wandb() + + wandb.init( + project="test", + config={"a": 1, "b": 2}, + name="test", + ) + + for epoch in range(10): + acc = 1 - 2 ** -epoch - random.random() / epoch - offset + loss = 2 ** -epoch + random.random() / epoch + offset + wandb.log({"acc": acc, "loss": loss}) + ``` + """ try: import wandb from wandb import sdk as wandb_sdk diff --git a/test/integration/wandb/wandb_sync.py b/test/integration/wandb/wandb_sync.py new file mode 100644 index 00000000..4e9f6ead --- /dev/null +++ b/test/integration/wandb/wandb_sync.py @@ -0,0 +1,19 @@ +import wandb +import random +import swanlab + +swanlab.sync_wandb() + +wandb.init( + project="test", + config={"a": 1, "b": 2}, + name="test", + ) + +epochs = 10 +offset = random.random() / 5 +for epoch in range(2, epochs): + acc = 1 - 2 ** -epoch - random.random() / epoch - offset + loss = 2 ** -epoch + random.random() / epoch + offset + + wandb.log({"acc": acc, "loss": loss}) \ No newline at end of file From 70e20f224cb6916487968a4934b66ac595c2fe40 Mon Sep 17 00:00:00 2001 From: KAAANG <79990647+SAKURA-CAT@users.noreply.github.com> Date: Fri, 17 Jan 2025 13:44:35 +0800 Subject: [PATCH 6/6] Update wandb_sync.py --- test/{integration/wandb => }/wandb_sync.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/{integration/wandb => }/wandb_sync.py (100%) diff --git a/test/integration/wandb/wandb_sync.py b/test/wandb_sync.py similarity index 100% rename from test/integration/wandb/wandb_sync.py rename to test/wandb_sync.py