Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions swanlab/cli/commands/converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@
type=str,
help="The directory where the tensorboard log files are stored.",
)
@click.option(
"--tb-types",
default=None,
type=str,
help="The types of the tensorboard log files to convert, default is all types.",
)

# wandb
@click.option(
Expand Down Expand Up @@ -103,6 +109,7 @@ def convert(
workspace: str,
logdir: str,
tb_logdir: str,
tb_types: str,
wb_project: str,
wb_entity: str,
wb_runid: str,
Expand All @@ -122,6 +129,7 @@ def convert(
workspace=workspace,
mode=mode,
logdir=logdir,
types=tb_types,
)
tfb_converter.run()

Expand Down
51 changes: 29 additions & 22 deletions swanlab/converter/tfb/tfb_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from datetime import datetime
from ._utils import find_tfevents, get_tf_events_tags_type, get_tf_events_tags_data
from swanlab.log import swanlog as swl
import time as t
import time


SUPPORTED_TYPES = ["scalar", "image", "audio", "text"]

class TFBConverter:
def __init__(
self,
Expand All @@ -14,13 +16,23 @@ def __init__(
workspace: str = None,
mode: str = "cloud",
logdir: str = None,
types: str = None,
**kwargs,
):
self.convert_dir = convert_dir
self.project = project
self.workspace = workspace
self.mode = mode
self.logdir = logdir
self.types = types
if self.types is None:
self.types = SUPPORTED_TYPES
else:
self.types = self.types.split(",")
self.types = [t.strip().lower() for t in self.types]
self.types = list(set(self.types))
if not all(type in SUPPORTED_TYPES for type in self.types):
raise ValueError(f"Unsupported types: {self.types}")

def run(self, depth=3):
swl.info("Start converting TFEvent files to SwanLab format...")
Expand Down Expand Up @@ -70,31 +82,26 @@ def run(self, depth=3):
data_by_tags = get_tf_events_tags_data(path, type_by_tags)

times = []
# 遍历数据
if data_by_tags:
# 打印并转换数据到SwanLab
handlers = {
"scalar": lambda v: v,
"image": lambda v: swanlab.Image(v),
"audio": lambda v: swanlab.Audio(v[0], sample_rate=v[1]),
"text": lambda v: swanlab.Text(v),
}
index = 0
for tag, data in data_by_tags.items():
for step, value, time in data:
times.append(time)
# 如果是标量
if type_by_tags[tag] == "scalar":
swanlab.log({tag: value}, step=step)
# 如果是图片
elif type_by_tags[tag] == "image":
swanlab.log({tag: swanlab.Image(value)}, step=step)
# 如果是音频
elif type_by_tags[tag] == "audio":
swanlab.log({tag: swanlab.Audio(value[0], sample_rate=value[1])}, step=step)
# 如果是文本
elif type_by_tags[tag] == "text":
swanlab.log({tag: swanlab.Text(value)}, step=step)
# TODO: 随着SwanLab的发展,支持转换更多类型
# TODO: 等未来上传方案优化后解除延时
if index % 5 == 0:
t.sleep(1)
tag_type = type_by_tags[tag]
if tag_type not in self.types:
continue
handler = handlers[tag_type]
index += 1
print(f"Index {index}: Metric: {tag} log finished")
for step, value, t in data:
times.append(t)
swanlab.log({tag: handler(value)}, step=step)
print(f"Metric [{index}]: {tag} log finished")
if index % 5 == 0:
time.sleep(1)

# 计算完整的运行时间
runtime = max(times) - min(times)
Expand Down