Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,9 @@ runs/

# test-integration-lightning
LightningTest/

# claude
CLAUDE.md
AGENTS.md
FEATURE_SUMMARY.md
.claude/
92 changes: 65 additions & 27 deletions swanlab/sync/tensorboard.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import functools
import swanlab


def _extract_args(args, kwargs, param_names):
"""
从args和kwargs中提取参数值的通用函数

Args:
args: 位置参数元组
kwargs: 关键字参数字典
param_names: 参数名称列表

Returns:
tuple: 按param_names顺序返回提取的参数值
"""
Expand All @@ -21,17 +23,22 @@ def _extract_args(args, kwargs, param_names):
return tuple(values)


def _create_patched_methods(SummaryWriter, logdir_extractor):
def _create_patched_methods(SummaryWriter, logdir_extractor, types=None):
"""
创建patched方法的工厂函数

Args:
SummaryWriter: SummaryWriter类
logdir_extractor: 提取logdir的函数

types: 要同步的数据类型列表,如 ['scalar', 'scalars', 'image', 'text']。
None 表示同步所有类型。

Returns:
tuple: (patched_init, patched_add_scalar, patched_add_image, patched_close)
"""
# 将 types 转换为 set 以提高查找性能
types_set = set(types) if types is not None else None

original_init = SummaryWriter.__init__
original_add_scalar = SummaryWriter.add_scalar
original_add_scalars = SummaryWriter.add_scalars
Expand All @@ -41,7 +48,7 @@ def _create_patched_methods(SummaryWriter, logdir_extractor):

def patched_init(self, *args, **kwargs):
tb_logdir = logdir_extractor(args, kwargs)

tb_config = {
'tensorboard_logdir': tb_logdir,
}
Expand All @@ -53,17 +60,23 @@ def patched_init(self, *args, **kwargs):

return original_init(self, *args, **kwargs)

@functools.wraps(original_add_scalar)
def patched_add_scalar(self, *args, **kwargs):
    """Mirror a scalar logged via ``add_scalar`` to swanlab, then delegate.

    When a ``types`` filter is active and does not include ``'scalar'``,
    the call is passed straight through to TensorBoard without mirroring.
    """
    sync_enabled = types_set is None or 'scalar' in types_set
    if sync_enabled:
        name, value, step = _extract_args(
            args, kwargs, ['tag', 'scalar_value', 'global_step']
        )
        swanlab.log(data={name: value}, step=int(step))
    return original_add_scalar(self, *args, **kwargs)

@functools.wraps(original_add_scalars)
def patched_add_scalars(self, *args, **kwargs):
if types_set is not None and 'scalars' not in types_set:
return original_add_scalars(self, *args, **kwargs)
# writer.add_scalars('Loss', {'train': loss_train, 'val': loss_val}, global_step=step)
tag, scalar_value_dict, global_step = _extract_args(
args, kwargs, ['tag', 'scalar_value_dict', 'global_step']
Expand All @@ -73,20 +86,23 @@ def patched_add_scalars(self, *args, **kwargs):
swanlab.log(data=data, step=int(global_step))
return original_add_scalars(self, *args, **kwargs)

@functools.wraps(original_add_image)
def patched_add_image(self, *args, **kwargs):
if types_set is not None and 'image' not in types_set:
return original_add_image(self, *args, **kwargs)
import numpy as np

tag, img_tensor, global_step, dataformats = _extract_args(
args, kwargs, ['tag', 'img_tensor', 'global_step', 'dataformats']
)
dataformats = dataformats or 'CHW' # 设置默认值

# Convert to numpy array if it's a tensor
if hasattr(img_tensor, 'cpu'):
img_tensor = img_tensor.cpu()
if hasattr(img_tensor, 'numpy'):
img_tensor = img_tensor.numpy()

# Handle different input formats
if dataformats == 'CHW':
# Convert CHW to HWC for swanlab
Expand All @@ -100,20 +116,22 @@ def patched_add_image(self, *args, **kwargs):
elif dataformats == 'HWC':
# Already in correct format
pass

data = {tag: swanlab.Image(img_tensor)}
swanlab.log(data=data, step=int(global_step))

return original_add_image(self, *args, **kwargs)


@functools.wraps(original_add_text)
def patched_add_text(self, *args, **kwargs):
    """Mirror text logged via ``add_text`` to swanlab as ``swanlab.Text``,
    then delegate to the original TensorBoard implementation.

    Skipped entirely when a ``types`` filter is active without ``'text'``.
    """
    sync_enabled = types_set is None or 'text' in types_set
    if sync_enabled:
        name, content, step = _extract_args(
            args, kwargs, ['tag', 'text_string', 'global_step']
        )
        swanlab.log(data={name: swanlab.Text(content)}, step=int(step))
    return original_add_text(self, *args, **kwargs)


def patched_close(self):
# 调用原始的close方法
Expand Down Expand Up @@ -142,78 +160,98 @@ def _apply_patches(SummaryWriter, patched_methods):
SummaryWriter.close = patched_close


def _sync_tensorboard_generic(import_func, logdir_extractor):
def _sync_tensorboard_generic(import_func, logdir_extractor, types=None):
"""
通用的tensorboard同步函数

Args:
import_func: 导入SummaryWriter的函数
logdir_extractor: 提取logdir的函数
types: 要同步的数据类型列表,如 ['scalar', 'scalars', 'image', 'text']。
None 表示同步所有类型。
"""
try:
SummaryWriter = import_func()
except ImportError as e:
raise ImportError(f"Import failed: {e}")

patched_methods = _create_patched_methods(SummaryWriter, logdir_extractor)
patched_methods = _create_patched_methods(SummaryWriter, logdir_extractor, types)
_apply_patches(SummaryWriter, patched_methods)


def sync_tensorboardX(types=None):
    """
    Sync tensorboardX logging calls to swanlab.

    Example::

        from tensorboardX import SummaryWriter
        import numpy as np
        import swanlab

        # sync every data type
        swanlab.sync_tensorboardX()

        # sync scalar data only
        swanlab.sync_tensorboardX(types=['scalar', 'scalars'])

        writer = SummaryWriter('runs/example')

        for i in range(100):
            scalar_value = np.random.rand()
            writer.add_scalar('random_scalar', scalar_value, i)

        writer.close()

    Args:
        types: list of data types to sync; valid values are ``'scalar'``,
            ``'scalars'``, ``'image'``, ``'text'``. ``None`` means sync
            every type.
    """
    def import_tensorboardx():
        # Imported lazily so swanlab itself does not require tensorboardX.
        from tensorboardX import SummaryWriter
        return SummaryWriter

    def extract_logdir_tensorboardx(args, kwargs):
        # tensorboardX exposes the log directory both as the first positional
        # parameter 'logdir' and as the keyword alias 'log_dir'; accept either.
        logdir, _, _, _, _, _, _, log_dir, _ = _extract_args(
            args, kwargs,
            ['logdir', 'comment', 'purge_step', 'max_queue', 'flush_secs',
             'filename_suffix', 'write_to_disk', 'log_dir', 'comet_config']
        )
        return logdir or log_dir

    _sync_tensorboard_generic(import_tensorboardx, extract_logdir_tensorboardx, types)

def sync_tensorboard_torch(types=None):
    """
    Sync torch's built-in tensorboard logging calls to swanlab.

    Example::

        from torch.utils.tensorboard import SummaryWriter
        import numpy as np
        import swanlab

        # sync every data type
        swanlab.sync_tensorboard_torch()

        # sync scalar data only (exclude text, images, etc.)
        swanlab.sync_tensorboard_torch(types=['scalar', 'scalars'])

        writer = SummaryWriter('runs/example')

        for i in range(100):
            scalar_value = np.random.rand()
            writer.add_scalar('random_scalar', scalar_value, i)

        writer.close()

    Args:
        types: list of data types to sync; valid values are ``'scalar'``,
            ``'scalars'``, ``'image'``, ``'text'``. ``None`` means sync
            every type.
    """
    def import_torch_tensorboard():
        # Imported lazily so swanlab itself does not require torch.
        from torch.utils.tensorboard import SummaryWriter
        return SummaryWriter

    def extract_logdir_torch(args, kwargs):
        # torch's SummaryWriter takes (log_dir, comment, ...); only the
        # log directory is needed for the swanlab config.
        logdir, _ = _extract_args(args, kwargs, ['log_dir', 'comment'])
        return logdir

    _sync_tensorboard_generic(import_torch_tensorboard, extract_logdir_torch, types)
5 changes: 4 additions & 1 deletion test/sync_tensorboardX.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@

import swanlab

swanlab.sync_tensorboardX()
# 只同步标量数据(排除文本、图像等)
swanlab.sync_tensorboardX(types=['scalar', 'scalars'])
writer = SummaryWriter('runs/example')

# 这些不会被同步(因为 types 过滤)
writer.add_image('random_image', np.random.randint(0, 255, (3, 100, 100)), global_step=20)
writer.add_text('random_text', 'hello', global_step=10)

for i in range(100):
scalar_value = np.random.rand()
# 这些会被同步(标量数据)
writer.add_scalar('random_scalar', scalar_value, i)
writer.add_scalars('random_scalars', {'scalar1': scalar_value, 'scalar2': scalar_value * 2}, i)

Expand Down
5 changes: 4 additions & 1 deletion test/sync_tensorboard_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@

import swanlab

swanlab.sync_tensorboard_torch()
# 只同步标量数据(排除文本、图像等)
swanlab.sync_tensorboard_torch(types=['scalar', 'scalars'])
writer = SummaryWriter('runs/example')

# 这些不会被同步(因为 types 过滤)
writer.add_image('random_image', np.random.randint(0, 255, (3, 100, 100)), global_step=20)
writer.add_text('random_text', 'hello', global_step=10)

for i in range(100):
scalar_value = np.random.rand()
# 这些会被同步(标量数据)
writer.add_scalar('random_scalar', scalar_value, i)
writer.add_scalars('random_scalars', {'scalar1': scalar_value, 'scalar2': scalar_value * 2}, i)

Expand Down