Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sync wandb #785

Merged
merged 6 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions swanlab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .data.run.main import config
from .package import get_package_version
from .env import SwanLabEnv
from .sync import sync_wandb

# 设置默认环境变量
SwanLabEnv.set_default()
Expand All @@ -38,4 +39,5 @@
"get_config",
"config",
"__version__",
"sync_wandb",
]
3 changes: 3 additions & 0 deletions swanlab/sync/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Re-export the wandb synchronisation entry point so it can be imported
# directly from this subpackage.
from .wandb import sync_wandb

# Names exported by `from <this package> import *`.
__all__ = ["sync_wandb"]
69 changes: 69 additions & 0 deletions swanlab/sync/wandb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import swanlab

def sync_wandb():
    """
    Mirror wandb runs to swanlab by monkey-patching wandb.

    After this call:

    - ``wandb.init`` also starts a swanlab run (or, if one already exists,
      merges the given ``config`` into ``swanlab.config``);
    - ``Run.log`` also forwards the scalar entries (int, float, bool, str)
      of the logged dict to ``swanlab.log``;
    - ``Run.finish`` also calls ``swanlab.finish``.

    Calling this function more than once is a no-op, so metrics are never
    mirrored twice.

    usecase:
    ```python
    import random

    import wandb
    import swanlab

    swanlab.sync_wandb()

    wandb.init(
        project="test",
        config={"a": 1, "b": 2},
        name="test",
    )

    offset = random.random() / 5
    for epoch in range(2, 10):
        acc = 1 - 2 ** -epoch - random.random() / epoch - offset
        loss = 2 ** -epoch + random.random() / epoch + offset
        wandb.log({"acc": acc, "loss": loss})
    ```

    Raises:
        ImportError: if wandb is not installed.
    """
    import functools
    from collections.abc import Mapping

    try:
        import wandb
        from wandb import sdk as wandb_sdk
    except ImportError as err:
        raise ImportError("please install wandb first, command: `pip install wandb`") from err

    # Already patched by an earlier call: wrapping again would mirror every
    # init/log/finish call to swanlab twice.
    if getattr(wandb.init, "_swanlab_synced", False):
        return

    run_cls = wandb_sdk.wandb_run.Run
    original_init = wandb.init
    original_log = run_cls.log
    # Must be the method that is replaced below (Run.finish), not the
    # module-level wandb_sdk.finish, which has a different signature.
    original_finish = run_cls.finish

    @functools.wraps(original_init)
    def patched_init(*args, **kwargs):
        project = kwargs.get("project")
        name = kwargs.get("name")
        config = kwargs.get("config")

        if swanlab.data.get_run() is None:
            swanlab.init(project=project, experiment_name=name, config=config)
        elif config is not None:
            # A swanlab run already exists; only merge the config, and only
            # when wandb.init was actually given one.
            swanlab.config.update(config)

        return original_init(*args, **kwargs)

    @functools.wraps(original_log)
    def patched_log(*args, **kwargs):
        # args[0] is the Run instance; data and step may each be passed
        # positionally or by keyword.
        data = args[1] if len(args) > 1 else kwargs.get("data")
        step = args[2] if len(args) > 2 else kwargs.get("step")

        if isinstance(data, Mapping):
            # Forward only scalar values; everything else is skipped.
            filtered_data = {
                key: value
                for key, value in data.items()
                if isinstance(value, (int, float, bool, str))
            }
            swanlab.log(data=filtered_data, step=step)

        # Invalid or missing data is left for wandb itself to report.
        return original_log(*args, **kwargs)

    @functools.wraps(original_finish)
    def patched_finish(*args, **kwargs):
        swanlab.finish()
        return original_finish(*args, **kwargs)

    # Marker checked at the top of this function to keep patching idempotent.
    patched_init._swanlab_synced = True

    wandb.init = patched_init
    run_cls.log = patched_log
    run_cls.finish = patched_finish
19 changes: 19 additions & 0 deletions test/wandb_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import wandb
import random
import swanlab

# Install the swanlab hooks before wandb is initialised so that the init and
# log calls below go through the patched wandb entry points.
swanlab.sync_wandb()

run_options = {
    "project": "test",
    "config": {"a": 1, "b": 2},
    "name": "test",
}
wandb.init(**run_options)

# Log synthetic accuracy/loss curves; epochs start at 2, so the division by
# the epoch number never hits zero.
total_epochs = 10
bias = random.random() / 5
for step_idx in range(2, total_epochs):
    decay = 2 ** -step_idx
    metrics = {}
    metrics["acc"] = 1 - decay - random.random() / step_idx - bias
    metrics["loss"] = decay + random.random() / step_idx + bias

    wandb.log(metrics)
Loading