Skip to content

Commit

Permalink
feat: sync enhancement (#799)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeyi-Lin authored Jan 21, 2025
1 parent ec4312c commit 6a9c93e
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 45 deletions.
154 changes: 127 additions & 27 deletions swanlab/sync/tensorboard.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
import swanlab

def _extract_args(args, kwargs, param_names):
"""
从args和kwargs中提取参数值的通用函数
Args:
args: 位置参数元组
kwargs: 关键字参数字典
param_names: 参数名称列表
Returns:
tuple: 按param_names顺序返回提取的参数值
"""
values = []
for i, name in enumerate(param_names):
if len(args) > i:
values.append(args[i])
else:
values.append(kwargs.get(name, None))
return tuple(values)

def sync_tensorboardX():
"""
同步tensorboardX到swanlab
from tensorboardX import SummaryWriter
import numpy as np
import swanlab
swanlab.sync_tensorboardX()
writer = SummaryWriter('runs/example')
Expand All @@ -21,51 +41,92 @@ def sync_tensorboardX():
from tensorboardX import SummaryWriter
except ImportError:
raise ImportError("please install tensorboardX first, command: `pip install tensorboardX`")

original_init = SummaryWriter.__init__
original_add_scalar = SummaryWriter.add_scalar
original_add_image = SummaryWriter.add_image
original_close = SummaryWriter.close

def patched_init(self, *args, **kwargs):
logdir = args[0]


def patched_init(self, *args, **kwargs):
    # tensorboardX accepts both `logdir` and the legacy `log_dir` spelling;
    # pull out the full signature and prefer `logdir` when both are given.
    params = ['logdir', 'comment', 'purge_step', 'max_queue', 'flush_secs',
              'filename_suffix', 'write_to_disk', 'log_dir', 'comet_config']
    extracted = _extract_args(args, kwargs, params)
    tb_logdir = extracted[0] or extracted[7]  # logdir, falling back to log_dir

    tb_config = {'tensorboard_logdir': tb_logdir}

    # Start a swanlab run on first use; afterwards only merge the config.
    if swanlab.data.get_run() is None:
        swanlab.init(config=tb_config)
    else:
        swanlab.config.update(tb_config)

    return original_init(self, *args, **kwargs)

def patched_add_scalar(self, tag, scalar_value, global_step=None):

def patched_add_scalar(self, *args, **kwargs):
    # Mirror the scalar to swanlab before delegating to tensorboardX.
    tag, scalar_value, global_step = _extract_args(
        args, kwargs, ['tag', 'scalar_value', 'global_step']
    )
    swanlab.log(data={tag: scalar_value}, step=global_step)
    return original_add_scalar(self, *args, **kwargs)

def patched_add_image(self, *args, **kwargs):
    """Patched SummaryWriter.add_image: mirror the image to swanlab, then delegate.

    Accepts the same call signature as the original add_image; the image is
    normalized to HWC layout before being logged as a swanlab.Image.
    """
    import numpy as np

    tag, img_tensor, global_step, dataformats = _extract_args(
        args, kwargs, ['tag', 'img_tensor', 'global_step', 'dataformats']
    )
    dataformats = dataformats or 'CHW'  # tensorboardX's default layout

    # Move torch tensors to host memory and unwrap them to a numpy array.
    if hasattr(img_tensor, 'cpu'):
        img_tensor = img_tensor.cpu()
    if hasattr(img_tensor, 'numpy'):
        img_tensor = img_tensor.numpy()

    # Normalize every supported layout to HWC, which swanlab.Image expects.
    if dataformats == 'CHW':
        img_tensor = np.transpose(img_tensor, (1, 2, 0))
    elif dataformats == 'NCHW':
        # Fix: drop the batch dimension before transposing — applying a
        # 3-axis permutation to a 4-D array raises a ValueError.
        img_tensor = np.transpose(img_tensor[0], (1, 2, 0))
    elif dataformats == 'HW':
        # Grayscale image: add a trailing channel axis.
        img_tensor = np.expand_dims(img_tensor, axis=-1)
    elif dataformats == 'HWC':
        pass  # already in the target layout

    swanlab.log(data={tag: swanlab.Image(img_tensor)}, step=global_step)

    return original_add_image(self, *args, **kwargs)

def patched_close(self):
    """Patched SummaryWriter.close: close the writer, then finish the swanlab run."""
    # Call the original close method first.
    original_close(self)
    # Then shut down the SwanLab logger.
    swanlab.finish()

# Apply the monkey patches so every SummaryWriter call is mirrored to swanlab.
SummaryWriter.__init__ = patched_init
SummaryWriter.add_scalar = patched_add_scalar
SummaryWriter.add_image = patched_add_image
SummaryWriter.close = patched_close


def sync_tensorboard_torch():
"""
同步torch自带的tensorboard到swanlab
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import swanlab
swanlab.sync_tensorboard_torch()
writer = SummaryWriter('runs/example')
Expand All @@ -82,37 +143,76 @@ def sync_tensorboard_torch():

original_init = SummaryWriter.__init__
original_add_scalar = SummaryWriter.add_scalar
original_add_image = SummaryWriter.add_image
original_close = SummaryWriter.close

def patched_init(self, *args, **kwargs):
    """Patched SummaryWriter.__init__: record the log dir in swanlab, then delegate."""
    # torch's SummaryWriter signature starts with (log_dir=None, comment='', ...);
    # only log_dir is used here, so extract just that (the previously extracted
    # `comment` local was unused).
    (tb_logdir,) = _extract_args(args, kwargs, ['log_dir'])

    tb_config = {
        'tensorboard_logdir': tb_logdir,
    }

    # Start a swanlab run on first use; afterwards only merge the config.
    if swanlab.data.get_run() is None:
        swanlab.init(config=tb_config)
    else:
        swanlab.config.update(tb_config)

    return original_init(self, *args, **kwargs)

def patched_add_scalar(self, tag, scalar_value, global_step=None):

def patched_add_scalar(self, *args, **kwargs):
    # Mirror the scalar to swanlab before delegating to torch's tensorboard.
    tag, scalar_value, global_step = _extract_args(
        args, kwargs, ['tag', 'scalar_value', 'global_step']
    )
    swanlab.log(data={tag: scalar_value}, step=global_step)
    return original_add_scalar(self, *args, **kwargs)

def patched_add_image(self, *args, **kwargs):
    """Patched SummaryWriter.add_image: mirror the image to swanlab, then delegate.

    Accepts the same call signature as the original add_image; the image is
    normalized to HWC layout before being logged as a swanlab.Image.
    """
    import numpy as np

    tag, img_tensor, global_step, dataformats = _extract_args(
        args, kwargs, ['tag', 'img_tensor', 'global_step', 'dataformats']
    )

    dataformats = dataformats or 'CHW'  # torch tensorboard's default layout

    # Move torch tensors to host memory and unwrap them to a numpy array.
    if hasattr(img_tensor, 'cpu'):
        img_tensor = img_tensor.cpu()
    if hasattr(img_tensor, 'numpy'):
        img_tensor = img_tensor.numpy()

    # Normalize every supported layout to HWC, which swanlab.Image expects.
    if dataformats == 'CHW':
        img_tensor = np.transpose(img_tensor, (1, 2, 0))
    elif dataformats == 'NCHW':
        # Fix: drop the batch dimension before transposing — applying a
        # 3-axis permutation to a 4-D array raises a ValueError.
        img_tensor = np.transpose(img_tensor[0], (1, 2, 0))
    elif dataformats == 'HW':
        # Grayscale image: add a trailing channel axis.
        img_tensor = np.expand_dims(img_tensor, axis=-1)
    elif dataformats == 'HWC':
        pass  # already in the target layout

    swanlab.log(data={tag: swanlab.Image(img_tensor)}, step=global_step)

    return original_add_image(self, *args, **kwargs)

def patched_close(self):
    """Patched SummaryWriter.close: close the writer, then finish the swanlab run."""
    # Call the original close method first.
    original_close(self)
    # Then shut down the SwanLab logger.
    swanlab.finish()

# Apply the monkey patches so every SummaryWriter call is mirrored to swanlab.
SummaryWriter.__init__ = patched_init
SummaryWriter.add_scalar = patched_add_scalar
SummaryWriter.add_image = patched_add_image
SummaryWriter.close = patched_close


52 changes: 37 additions & 15 deletions swanlab/sync/wandb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,26 @@
import swanlab

def _extract_args(args, kwargs, param_names):
"""
从args和kwargs中提取参数值的通用函数
Args:
args: 位置参数元组
kwargs: 关键字参数字典
param_names: 参数名称列表
Returns:
tuple: 按param_names顺序返回提取的参数值
"""
values = []
for i, name in enumerate(param_names):
if len(args) > i:
values.append(args[i])
else:
values.append(kwargs.get(name, None))
return tuple(values)


def sync_wandb(mode:str="cloud", wandb_run:bool=True):
"""
sync wandb with swanlab, 暂时不支持log非标量类型
Expand Down Expand Up @@ -43,14 +64,15 @@ def sync_wandb(mode:str="cloud", wandb_run:bool=True):
original_config_update = wandb_sdk.wandb_config.Config.update

def patched_init(*args, **kwargs):
project = kwargs.get('project', None)
name = kwargs.get('name', None)
config = kwargs.get('config', None)
entity, project, dir, id, name, notes, tags, config, config_exclude_keys = _extract_args(
args, kwargs, ['entity', 'project', 'dir', 'id', 'name', 'notes', 'tags', 'config', 'config_exclude_keys']
)

if swanlab.data.get_run() is None:
swanlab.init(
project=project,
experiment_name=name,
description=notes,
config=config,
mode=mode)
else:
Expand All @@ -62,18 +84,18 @@ def patched_init(*args, **kwargs):
else:
return original_init(*args, **kwargs)

def patched_config_update(*args, **kwargs):
try:
config= args[1]
except:
config= None
if config is not None:
swanlab.config.update(config)
return original_config_update(*args, **kwargs)
def patched_config_update(self, *args, **kwargs):
    # Mirror wandb config updates into swanlab before delegating.
    config_dict, _allow_val_change = _extract_args(
        args, kwargs, ['d', 'allow_val_change']
    )
    if config_dict is not None:
        swanlab.config.update(config_dict)
    return original_config_update(self, *args, **kwargs)

def patched_log(*args, **kwargs):
data = args[1]
step = kwargs.get('step', None)
def patched_log(self, *args, **kwargs):
data, step, commit, sync = _extract_args(args, kwargs, ['data', 'step', 'commit', 'sync'])

if data is None:
return original_log(self, *args, **kwargs)

# 过滤掉非标量类型
filtered_data = {}
Expand All @@ -83,7 +105,7 @@ def patched_log(*args, **kwargs):

swanlab.log(data=filtered_data, step=step)

return original_log(*args, **kwargs)
return original_log(self, *args, **kwargs)

def patched_finish(*args, **kwargs):
swanlab.finish()
Expand Down
3 changes: 2 additions & 1 deletion test/sync_tensorboardX.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
import swanlab

swanlab.sync_tensorboardX()

writer = SummaryWriter('runs/example')

writer.add_image('random_image', np.random.randint(0, 255, (3, 100, 100)), global_step=20)

for i in range(100):
scalar_value = np.random.rand()
writer.add_scalar('random_scalar', scalar_value, i)
Expand Down
3 changes: 2 additions & 1 deletion test/sync_tensorboard_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
import swanlab

swanlab.sync_tensorboard_torch()

writer = SummaryWriter('runs/example')

writer.add_image('random_image', np.random.randint(0, 255, (3, 100, 100)), global_step=20)

for i in range(100):
scalar_value = np.random.rand()
writer.add_scalar('random_scalar', scalar_value, i)
Expand Down
6 changes: 5 additions & 1 deletion test/sync_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
import random
import swanlab

swanlab.sync_wandb()
swanlab.sync_wandb(wandb_run=False)

wandb.init(
project="test",
config={"a": 1, "b": 2},
name="test",
notes="test_wandb_sync",
)

wandb.config.update({"c": 3, "d": 4})
print(swanlab.config.get("c"))

epochs = 10
offset = random.random() / 5
for epoch in range(2, epochs):
Expand Down

0 comments on commit 6a9c93e

Please sign in to comment.