Skip to content

Commit

Permalink
feat: sync wandb (#785)
Browse files Browse the repository at this point in the history
Co-authored-by: KAAANG <[email protected]>
  • Loading branch information
Zeyi-Lin and SAKURA-CAT authored Jan 17, 2025
1 parent 19bd1dc commit ffaaf97
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 0 deletions.
2 changes: 2 additions & 0 deletions swanlab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .data.run.main import config
from .package import get_package_version
from .env import SwanLabEnv
from .sync import sync_wandb

# Set the default environment variables
SwanLabEnv.set_default()
Expand All @@ -38,4 +39,5 @@
"get_config",
"config",
"__version__",
"sync_wandb",
]
3 changes: 3 additions & 0 deletions swanlab/sync/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Re-export sync_wandb from the sibling `wandb` module and declare it as the
# only public name of this package.
from .wandb import sync_wandb

__all__ = ["sync_wandb"]
69 changes: 69 additions & 0 deletions swanlab/sync/wandb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import swanlab

def sync_wandb():
    """
    Mirror wandb runs into SwanLab by monkey-patching wandb.

    After this call:

    * ``wandb.init`` also starts a SwanLab run (or, when a SwanLab run already
      exists, updates its config) using the ``project``, ``name`` and
      ``config`` keyword arguments;
    * ``Run.log`` also forwards the int/float/bool/str values of the logged
      dict to ``swanlab.log``;
    * ``Run.finish`` also calls ``swanlab.finish``.

    Calling this function more than once is a no-op, so values are never
    forwarded twice.

    usecase:
    ```python
    import random
    import wandb
    import swanlab

    swanlab.sync_wandb()
    wandb.init(
        project="test",
        config={"a": 1, "b": 2},
        name="test",
    )
    offset = random.random() / 5
    for epoch in range(2, 10):
        acc = 1 - 2 ** -epoch - random.random() / epoch - offset
        loss = 2 ** -epoch + random.random() / epoch + offset
        wandb.log({"acc": acc, "loss": loss})
    ```

    Raises:
        ImportError: if wandb is not installed.
    """
    import functools

    try:
        import wandb
        from wandb import sdk as wandb_sdk
    except ImportError as err:
        raise ImportError("please install wandb first, command: `pip install wandb`") from err

    # Already patched by an earlier call: patching again would stack wrappers
    # and forward every value to SwanLab more than once.
    if getattr(wandb.init, "_swanlab_synced", False):
        return

    run_cls = wandb_sdk.wandb_run.Run

    original_init = wandb.init
    original_log = run_cls.log
    # Capture the very method that gets replaced below, so the wrapper
    # delegates to the real Run.finish with the run instance as `self`.
    original_finish = run_cls.finish

    @functools.wraps(original_init)
    def patched_init(*args, **kwargs):
        project = kwargs.get("project")
        name = kwargs.get("name")
        config = kwargs.get("config")

        if swanlab.data.get_run() is None:
            swanlab.init(
                project=project,
                experiment_name=name,
                config=config,
            )
        elif config is not None:
            # Nothing to merge when wandb.init was called without a config.
            swanlab.config.update(config)

        return original_init(*args, **kwargs)

    @functools.wraps(original_log)
    def patched_log(*args, **kwargs):
        # args[0] is the Run instance; `data` and `step` may each be given
        # positionally or by keyword.
        data = args[1] if len(args) > 1 else kwargs.get("data")
        step = args[2] if len(args) > 2 else kwargs.get("step")

        if data is not None:
            # Keep only int/float/bool/str values; everything else is
            # dropped from the SwanLab copy.
            filtered_data = {
                key: value
                for key, value in data.items()
                if isinstance(value, (int, float, bool, str))
            }
            swanlab.log(data=filtered_data, step=step)

        # wandb still receives the call exactly as the user made it.
        return original_log(*args, **kwargs)

    @functools.wraps(original_finish)
    def patched_finish(*args, **kwargs):
        swanlab.finish()
        return original_finish(*args, **kwargs)

    # Marker read by the guard above on subsequent calls.
    patched_init._swanlab_synced = True

    wandb.init = patched_init
    run_cls.log = patched_log
    run_cls.finish = patched_finish
19 changes: 19 additions & 0 deletions test/wandb_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import wandb
import random
import swanlab

# Install the SwanLab hooks before any wandb call is made.
swanlab.sync_wandb()

wandb.init(project="test", name="test", config={"a": 1, "b": 2})

epochs = 10
offset = random.random() / 5

# Start at 2 so the `/ epoch` terms never divide by zero.
for epoch in range(2, epochs):
    decay = 2 ** -epoch
    acc = 1 - decay - random.random() / epoch - offset
    loss = decay + random.random() / epoch + offset

    metrics = {"acc": acc, "loss": loss}
    wandb.log(metrics)

0 comments on commit ffaaf97

Please sign in to comment.