From 6a9c93e238aaf5950b6dca4c728eb4c21f0bea62 Mon Sep 17 00:00:00 2001 From: Ze-Yi LIN <58305964+Zeyi-Lin@users.noreply.github.com> Date: Tue, 21 Jan 2025 18:06:26 +0800 Subject: [PATCH] feat: sync enhancement (#799) --- swanlab/sync/tensorboard.py | 154 +++++++++++++++++++++++++++------ swanlab/sync/wandb.py | 52 +++++++---- test/sync_tensorboardX.py | 3 +- test/sync_tensorboard_torch.py | 3 +- test/sync_wandb.py | 6 +- 5 files changed, 173 insertions(+), 45 deletions(-) diff --git a/swanlab/sync/tensorboard.py b/swanlab/sync/tensorboard.py index f2fd9353..b41255fb 100644 --- a/swanlab/sync/tensorboard.py +++ b/swanlab/sync/tensorboard.py @@ -1,13 +1,33 @@ import swanlab +def _extract_args(args, kwargs, param_names): + """ + 从args和kwargs中提取参数值的通用函数 + + Args: + args: 位置参数元组 + kwargs: 关键字参数字典 + param_names: 参数名称列表 + + Returns: + tuple: 按param_names顺序返回提取的参数值 + """ + values = [] + for i, name in enumerate(param_names): + if len(args) > i: + values.append(args[i]) + else: + values.append(kwargs.get(name, None)) + return tuple(values) + def sync_tensorboardX(): """ 同步tensorboardX到swanlab - + from tensorboardX import SummaryWriter import numpy as np import swanlab - + swanlab.sync_tensorboardX() writer = SummaryWriter('runs/example') @@ -21,51 +41,92 @@ def sync_tensorboardX(): from tensorboardX import SummaryWriter except ImportError: raise ImportError("please install tensorboardX first, command: `pip install tensorboardX`") - + original_init = SummaryWriter.__init__ original_add_scalar = SummaryWriter.add_scalar + original_add_image = SummaryWriter.add_image original_close = SummaryWriter.close - def patched_init(self, *args, **kwargs): - logdir = args[0] - + + def patched_init(self, *args, **kwargs): + logdir, _, _, _, _, _, _, log_dir, _ = _extract_args(args, kwargs, ['logdir', 'comment', 'purge_step', 'max_queue', 'flush_secs', 'filename_suffix', 'write_to_disk', 'log_dir', 'comet_config']) + tb_logdir = logdir or log_dir + tb_config = { - 'tensorboard_logdir': logdir, + 'tensorboard_logdir': tb_logdir, } - + if swanlab.data.get_run() is None: swanlab.init(config=tb_config) else: swanlab.config.update(tb_config) - + return original_init(self, *args, **kwargs) - - def patched_add_scalar(self, tag, scalar_value, global_step=None): + + def patched_add_scalar(self, *args, **kwargs): + tag, scalar_value, global_step = _extract_args( + args, kwargs, ['tag', 'scalar_value', 'global_step'] + ) + data = {tag: scalar_value} swanlab.log(data=data, step=global_step) - return original_add_scalar(self, tag, scalar_value, global_step) + return original_add_scalar(self, *args, **kwargs) + + def patched_add_image(self, *args, **kwargs): + import numpy as np + + tag, img_tensor, global_step, dataformats = _extract_args( + args, kwargs, ['tag', 'img_tensor', 'global_step', 'dataformats'] + ) + dataformats = dataformats or 'CHW' # 设置默认值 + + # Convert to numpy array if it's a tensor + if hasattr(img_tensor, 'cpu'): + img_tensor = img_tensor.cpu() + if hasattr(img_tensor, 'numpy'): + img_tensor = img_tensor.numpy() + + # Handle different input formats + if dataformats == 'CHW': + # Convert CHW to HWC for swanlab + img_tensor = np.transpose(img_tensor, (1, 2, 0)) + elif dataformats == 'NCHW': + # Take first image if batch dimension exists and convert to HWC + img_tensor = np.transpose(img_tensor, (1, 2, 0)) + elif dataformats == 'HW': + # Add channel dimension for grayscale + img_tensor = np.expand_dims(img_tensor, axis=-1) + elif dataformats == 'HWC': + # Already in correct format + pass + data = {tag: swanlab.Image(img_tensor)} + swanlab.log(data=data, step=global_step) + + return original_add_image(self, *args, **kwargs) + def patched_close(self): # 调用原始的close方法 original_close(self) # 关闭SwanLab记录器 swanlab.finish() - + # 应用monkey patch SummaryWriter.__init__ = patched_init SummaryWriter.add_scalar = patched_add_scalar + SummaryWriter.add_image = patched_add_image SummaryWriter.close = patched_close def sync_tensorboard_torch(): """ 同步torch自带的tensorboard到swanlab - + from torch.utils.tensorboard import SummaryWriter import numpy as np import swanlab - + swanlab.sync_tensorboard_torch() writer = SummaryWriter('runs/example') @@ -82,37 +143,76 @@ def sync_tensorboard_torch(): original_init = SummaryWriter.__init__ original_add_scalar = SummaryWriter.add_scalar + original_add_image = SummaryWriter.add_image original_close = SummaryWriter.close - + def patched_init(self, *args, **kwargs): - logdir = args[0] - + logdir, comment = _extract_args(args, kwargs, ['log_dir', 'comment']) + tb_logdir = logdir + tb_config = { - 'tensorboard_logdir': logdir, + 'tensorboard_logdir': tb_logdir, } - + if swanlab.data.get_run() is None: swanlab.init(config=tb_config) else: swanlab.config.update(tb_config) - + return original_init(self, *args, **kwargs) - - def patched_add_scalar(self, tag, scalar_value, global_step=None): + + def patched_add_scalar(self, *args, **kwargs): + tag, scalar_value, global_step = _extract_args( + args, kwargs, ['tag', 'scalar_value', 'global_step'] + ) + data = {tag: scalar_value} swanlab.log(data=data, step=global_step) - return original_add_scalar(self, tag, scalar_value, global_step) + return original_add_scalar(self, *args, **kwargs) + + def patched_add_image(self, *args, **kwargs): + import numpy as np + + tag, img_tensor, global_step, dataformats = _extract_args( + args, kwargs, ['tag', 'img_tensor', 'global_step', 'dataformats'] + ) + + dataformats = dataformats or 'CHW' # 设置默认值 + + # Convert to numpy array if it's a tensor + if hasattr(img_tensor, 'cpu'): + img_tensor = img_tensor.cpu() + if hasattr(img_tensor, 'numpy'): + img_tensor = img_tensor.numpy() + # Handle different input formats + if dataformats == 'CHW': + # Convert CHW to HWC for swanlab + img_tensor = np.transpose(img_tensor, (1, 2, 0)) + elif dataformats == 'NCHW': + # Take first image if batch dimension exists and convert to HWC + img_tensor = np.transpose(img_tensor, (1, 2, 0)) + elif dataformats == 'HW': + # Add channel dimension for grayscale + img_tensor = np.expand_dims(img_tensor, axis=-1) + elif dataformats == 'HWC': + # Already in correct format + pass + + data = {tag: swanlab.Image(img_tensor)} + swanlab.log(data=data, step=global_step) + + return original_add_image(self, *args, **kwargs) + def patched_close(self): # 调用原始的close方法 original_close(self) # 关闭SwanLab记录器 swanlab.finish() - + # 应用monkey patch SummaryWriter.__init__ = patched_init SummaryWriter.add_scalar = patched_add_scalar + SummaryWriter.add_image = patched_add_image SummaryWriter.close = patched_close - - \ No newline at end of file diff --git a/swanlab/sync/wandb.py b/swanlab/sync/wandb.py index bb496e0d..cee7eec4 100644 --- a/swanlab/sync/wandb.py +++ b/swanlab/sync/wandb.py @@ -1,5 +1,26 @@ import swanlab +def _extract_args(args, kwargs, param_names): + """ + 从args和kwargs中提取参数值的通用函数 + + Args: + args: 位置参数元组 + kwargs: 关键字参数字典 + param_names: 参数名称列表 + + Returns: + tuple: 按param_names顺序返回提取的参数值 + """ + values = [] + for i, name in enumerate(param_names): + if len(args) > i: + values.append(args[i]) + else: + values.append(kwargs.get(name, None)) + return tuple(values) + + def sync_wandb(mode:str="cloud", wandb_run:bool=True): """ sync wandb with swanlab, 暂时不支持log非标量类型 @@ -43,14 +64,15 @@ def sync_wandb(mode:str="cloud", wandb_run:bool=True): original_config_update = wandb_sdk.wandb_config.Config.update def patched_init(*args, **kwargs): - project = kwargs.get('project', None) - name = kwargs.get('name', None) - config = kwargs.get('config', None) + entity, project, dir, id, name, notes, tags, config, config_exclude_keys = _extract_args( + args, kwargs, ['entity', 'project', 'dir', 'id', 'name', 'notes', 'tags', 'config', 'config_exclude_keys'] + ) if swanlab.data.get_run() is None: swanlab.init( project=project, experiment_name=name, + description=notes, config=config, mode=mode) else: @@ -62,18 +84,18 @@ def patched_init(*args, **kwargs): else: return original_init(*args, **kwargs) - def patched_config_update(*args, **kwargs): - try: - config= args[1] - except: - config= None - if config is not None: - swanlab.config.update(config) - return original_config_update(*args, **kwargs) + def patched_config_update(self, *args, **kwargs): + d, _ = _extract_args(args, kwargs, ['d', 'allow_val_change']) + + if d is not None: + swanlab.config.update(d) + return original_config_update(self, *args, **kwargs) - def patched_log(*args, **kwargs): - data = args[1] - step = kwargs.get('step', None) + def patched_log(self, *args, **kwargs): + data, step, commit, sync = _extract_args(args, kwargs, ['data', 'step', 'commit', 'sync']) + + if data is None: + return original_log(self, *args, **kwargs) # 过滤掉非标量类型 filtered_data = {} @@ -83,7 +105,7 @@ def patched_log(*args, **kwargs): swanlab.log(data=filtered_data, step=step) - return original_log(*args, **kwargs) + return original_log(self, *args, **kwargs) def patched_finish(*args, **kwargs): swanlab.finish() diff --git a/test/sync_tensorboardX.py b/test/sync_tensorboardX.py index a317c474..baf05b9a 100644 --- a/test/sync_tensorboardX.py +++ b/test/sync_tensorboardX.py @@ -11,9 +11,10 @@ import swanlab swanlab.sync_tensorboardX() - writer = SummaryWriter('runs/example') +writer.add_image('random_image', np.random.randint(0, 255, (3, 100, 100)), global_step=20) + for i in range(100): scalar_value = np.random.rand() writer.add_scalar('random_scalar', scalar_value, i) diff --git a/test/sync_tensorboard_torch.py b/test/sync_tensorboard_torch.py index 0c93b599..e3fdef02 100644 --- a/test/sync_tensorboard_torch.py +++ b/test/sync_tensorboard_torch.py @@ -11,9 +11,10 @@ import swanlab swanlab.sync_tensorboard_torch() - writer = SummaryWriter('runs/example') +writer.add_image('random_image', np.random.randint(0, 255, (3, 100, 100)), global_step=20) + for i in range(100): scalar_value = np.random.rand() writer.add_scalar('random_scalar', scalar_value, i) diff --git a/test/sync_wandb.py b/test/sync_wandb.py index 4e9f6ead..d9122b19 100644 --- a/test/sync_wandb.py +++ b/test/sync_wandb.py @@ -2,14 +2,18 @@ import random import swanlab -swanlab.sync_wandb() +swanlab.sync_wandb(wandb_run=False) wandb.init( project="test", config={"a": 1, "b": 2}, name="test", + notes="test_wandb_sync", ) +wandb.config.update({"c": 3, "d": 4}) +print(swanlab.config.get("c")) + epochs = 10 offset = random.random() / 5 for epoch in range(2, epochs):