Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sync wandb add params #794

Merged
merged 1 commit into from
Jan 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions swanlab/sync/wandb.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
import swanlab

def sync_wandb():
def sync_wandb(mode:str="cloud", wandb_run:bool=True):
"""
sync wandb with swanlab
sync wandb with swanlab, 暂时不支持log非标量类型

- mode: "cloud", "local" or "disabled". https://docs.swanlab.cn/api/py-init.html
- wandb_run: 如果此参数设置为False,则不会将数据上传到wandb,等同于设置wandb.init(mode="offline")。

usecase:
```python
import wandb
import random
import swanlab

swanlab.sync_wandb()

# swanlab.init(project="sync_wandb")

wandb.init(
project="test",
config={"a": 1, "b": 2},
name="test",
)

for epoch in range(10):
epochs = 10
offset = random.random() / 5
for epoch in range(2, epochs):
acc = 1 - 2 ** -epoch - random.random() / epoch - offset
loss = 2 ** -epoch + random.random() / epoch + offset

wandb.log({"acc": acc, "loss": loss})
```
"""
Expand All @@ -30,6 +40,7 @@ def sync_wandb():
original_init = wandb.init
original_log = wandb_sdk.wandb_run.Run.log
original_finish = wandb_sdk.finish
original_config_update = wandb_sdk.wandb_config.Config.update

def patched_init(*args, **kwargs):
project = kwargs.get('project', None)
Expand All @@ -40,11 +51,25 @@ def patched_init(*args, **kwargs):
swanlab.init(
project=project,
experiment_name=name,
config=config)
config=config,
mode=mode)
else:
swanlab.config.update(config)

return original_init(*args, **kwargs)
if wandb_run is False:
kwargs["mode"] = "offline"
return original_init(*args, **kwargs)
else:
return original_init(*args, **kwargs)

def patched_config_update(*args, **kwargs):
try:
config= args[1]
except:
config= None
if config is not None:
swanlab.config.update(config)
return original_config_update(*args, **kwargs)

def patched_log(*args, **kwargs):
data = args[1]
Expand All @@ -66,4 +91,5 @@ def patched_finish(*args, **kwargs):

wandb.init = patched_init
wandb_sdk.wandb_run.Run.log = patched_log
wandb_sdk.wandb_run.Run.finish = patched_finish
wandb_sdk.wandb_run.Run.finish = patched_finish
wandb_sdk.wandb_config.Config.update = patched_config_update
Loading