Skip to content
7 changes: 7 additions & 0 deletions swanlab/core_python/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ def response_interceptor(response: requests.Response, *args, **kwargs):
"""
捕获所有的http不为2xx的错误,以ApiError的形式抛出
"""
# 1. 日志打印
swanlog.debug(
f"HTTP Request: {response.request.method.upper()} {response.url} | "
f"Response Status: {response.status_code} | "
f"Body: {decode_response(response)}"
)
# 2. 如果状态码不为2xx,抛出异常
if response.status_code // 100 != 2:
traceid = f"Trace id: {response.headers.get('traceid')}"
request = f"{response.request.method.upper()} {response.url}"
Expand Down
108 changes: 108 additions & 0 deletions swanlab/core_python/uploader/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
@author: cunyue
@file: batch.py
@time: 2025/12/4 13:50
@description: 分批上传函数
"""

import time
from typing import Union, Literal, List, TypedDict, Dict

from swanlab.core_python import get_client
from swanlab.log import swanlog


# Payload schema for a metric upload request body.
MetricDict = TypedDict(
    "MetricDict",
    {
        # CUID of the project the metrics are reported under.
        "projectId": str,
        # CUID of the experiment the metrics are reported under.
        "experimentId": str,
        # Metric category, e.g. "log" or a scalar/media type value.
        "type": str,
        # The serialized metric records to upload.
        "metrics": List[dict],
        # Upload flag id carried from the client; may be None when unset.
        "flagId": Union[str, None],
    },
)


def create_data(metrics: List[dict], metrics_type: str) -> MetricDict:
    """
    Build the request payload that accompanies a metric upload.

    :param metrics: list of serialized metric records to send
    :param metrics_type: category tag for the records (e.g. "log", scalar/media type)
    :return: payload dict including project/experiment routing information
    """
    client = get_client()
    # Experiments such as "Move" must report data to their root experiment,
    # so prefer the root CUIDs when they are present.
    experiment_id = client.exp.root_exp_cuid or client.exp.cuid
    project_id = client.exp.root_proj_cuid or client.proj.cuid
    assert project_id is not None, "Project ID is empty."
    assert experiment_id is not None, "Experiment ID is empty."
    return {
        "projectId": project_id,
        "experimentId": experiment_id,
        "type": metrics_type,
        "metrics": metrics,
        "flagId": client.exp.flag_id,
    }


def _generate_chunks(data: Union[MetricDict, Dict, List], per_request_len: int):
"""
生成器:统一处理字典和列表的分片逻辑
yield: 分片后的数据块
"""
# 情况1: 不分批
if per_request_len == -1:
yield data
return

# 情况2: 字典分批
if isinstance(data, dict):
metrics = data.get('metrics', [])
# 如果 metrics 为空或长度不足,视为不需要分片的一整块
if len(metrics) <= per_request_len:
yield data
else:
for i in range(0, len(metrics), per_request_len):
yield {
**data,
"metrics": metrics[i : i + per_request_len],
}

# 情况3: 列表分批
elif isinstance(data, list):
if len(data) <= per_request_len:
yield data
else:
for i in range(0, len(data), per_request_len):
yield data[i : i + per_request_len]


def trace_metrics(
    url: str,
    data: Union[MetricDict, list] = None,
    method: Literal['post', 'put'] = 'post',
    per_request_len: int = 1000,
):
    """
    Upload metric data, splitting it into chunks when needed.

    :param url: target endpoint for the upload
    :param data: payload to send; either a MetricDict (its "metrics" list is
        chunked) or a plain list (chunked directly)
    :param method: HTTP verb used for every request, 'post' or 'put'
    :param per_request_len: maximum number of records per request; -1 sends
        the whole payload in a single request
    """
    # Decide whether chunked mode is actually in effect: a limit must be set
    # AND the payload must exceed it. The flag drives the inter-request sleep.
    is_split_mode = False
    if per_request_len != -1:
        total_len = len(data.get('metrics', [])) if isinstance(data, dict) else len(data or [])
        # Nothing to upload.
        if total_len == 0:
            return
        is_split_mode = total_len > per_request_len
    client = get_client()
    # Send every chunk produced by the generator.
    for chunk in _generate_chunks(data, per_request_len):
        # TODO: pre-send check temporarily disabled — if the client became
        # pending mid-upload, the remaining sends should be aborted.
        # if client.pending:
        #     break

        _, resp = getattr(client, method)(url, chunk)
        # Post-send check: a 202 response flips the client to pending so that
        # subsequent uploads are skipped.
        if resp and resp.status_code == 202:
            client.pending = True
            swanlog.warning(f"Client set to pending due to 202 response: {url}")
        # Throttle between requests when the payload was actually split.
        # (explicit `if` instead of the original `is_split_mode and time.sleep(1)`
        # short-circuit expression, which hid control flow in a discarded value)
        if is_split_mode:
            time.sleep(1)
94 changes: 6 additions & 88 deletions swanlab/core_python/uploader/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,96 +5,15 @@
@description: 定义上传函数
"""

import time
from typing import List, Union, Literal, TypedDict
from typing import List

from swanlab.log import swanlog
from .batch import MetricDict, create_data, trace_metrics
from .model import ColumnModel, MediaModel, ScalarModel, FileModel, LogModel
from ..client import get_client, sync_error_handler, decode_response
from ...error import ApiError

house_url = '/house/metrics'


# 上传指标数据
class MetricDict(TypedDict):
projectId: str
experimentId: str
type: str
metrics: List[dict]
flagId: Union[str, None]


def create_data(metrics: List[dict], metrics_type: str) -> MetricDict:
"""
携带上传日志的指标信息
"""
client = get_client()
# Move 等实验需要将数据上传到根实验上
exp_id = client.exp.root_exp_cuid or client.exp.cuid
proj_id = client.exp.root_proj_cuid or client.proj.cuid
assert proj_id is not None, "Project ID is empty."
assert exp_id is not None, "Experiment ID is empty."
flag_id = client.exp.flag_id
return {
"projectId": proj_id,
"experimentId": exp_id,
"type": metrics_type,
"metrics": metrics,
"flagId": flag_id,
}


def trace_metrics(
url: str,
data: Union[MetricDict, list] = None,
method: Literal['post', 'put'] = 'post',
per_request_len: int = 1000,
):
"""
创建指标数据方法,如果 client 处于挂起状态,则不进行上传
:param url: 上传的URL地址
:param data: 上传的数据,可以是字典或列表
:param method: 请求方法,默认为 'post'
:param per_request_len: 每次请求的最大数据长度,如果设置为-1则不进行分批上传
"""
# TODO 用装饰器设置client的pending状态
client = get_client()
if client.pending:
return
if per_request_len == -1:
_, resp = getattr(client, method)(url, data)
if resp.status_code == 202:
client.pending = True
return
return
# 分批上传
if isinstance(data, dict):
need_split = len(data['metrics']) > per_request_len
# 1. 指标数据
for i in range(0, len(data['metrics']), per_request_len):
_, resp = getattr(client, method)(
url,
{
**data,
"metrics": data['metrics'][i : i + per_request_len],
},
)
if resp.status_code == 202:
client.pending = True
return
if need_split:
time.sleep(1)
else:
need_split = len(data) > per_request_len
# 2. 列表数据(列等)
for i in range(0, len(data), per_request_len):
_, resp = getattr(client, method)(url, data[i : i + per_request_len])
if resp.status_code == 202:
client.pending = True
return
if need_split:
time.sleep(1)
HOUSE_URL = '/house/metrics'


@sync_error_handler
Expand All @@ -109,7 +28,7 @@ def upload_logs(logs: List[LogModel]):
if len(metrics) == 0:
return swanlog.debug("No logs to upload.")
data = create_data(metrics, "log")
trace_metrics(house_url, data)
trace_metrics(HOUSE_URL, data)
return None


Expand All @@ -126,7 +45,7 @@ def upload_media_metrics(media_metrics: List[MediaModel]):
if not client.pending:
client.upload_files(buffers)
# 上传指标信息
trace_metrics(house_url, create_data([x.to_dict() for x in media_metrics], MediaModel.type.value))
trace_metrics(HOUSE_URL, create_data([x.to_dict() for x in media_metrics], MediaModel.type.value))


@sync_error_handler
Expand All @@ -135,7 +54,7 @@ def upload_scalar_metrics(scalar_metrics: List[ScalarModel]):
上传指标的标量数据
"""
data = create_data([x.to_dict() for x in scalar_metrics], ScalarModel.type.value)
trace_metrics(house_url, data)
trace_metrics(HOUSE_URL, data)


@sync_error_handler
Expand Down Expand Up @@ -189,7 +108,6 @@ def upload_columns(columns: List[ColumnModel]):


__all__ = [
"trace_metrics",
"MetricDict",
"upload_logs",
"upload_media_metrics",
Expand Down
5 changes: 4 additions & 1 deletion swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ def init(
# 注册settings
merge_settings(settings)
user_settings = get_settings()
swanlog.level = kwargs.get("log_level", "info")
# 加载日志级别配置
log_level = kwargs.get("log_level")
if log_level is not None:
swanlog.level = log_level
# 获取本地文件夹配置,默认从当前工作目录下的swanlog文件夹中读取
folder_settings = read_folder_settings(get_swanlog_dir())
# ---------------------------------- 一些变量、格式检查 ----------------------------------
Expand Down
4 changes: 4 additions & 0 deletions swanlab/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class SwanLabEnv(enum.Enum):
swanlab环境变量枚举类,包含swankit的共享环境变量
"""

LOG_LEVEL = "SWANLAB_LOG_LEVEL"
"""
swanlab日志的输出级别,默认为info,可选值有debug、info、warning、error、critical
"""
SWANLAB_FOLDER = "SWANLAB_SAVE_DIR"
"""
swanlab全局文件夹保存的路径,默认为用户主目录下的.swanlab文件夹
Expand Down
8 changes: 5 additions & 3 deletions swanlab/log/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
@description: 标准输出、标准错误流拦截代理,支持外界设置/取消回调,基础作用为输出日志
"""

import os
import re
import sys
from typing import List, Tuple, Callable
from typing import List, Tuple, Callable, Union

from swanlab.env import create_time
from swanlab.env import create_time, SwanLabEnv
from swanlab.toolkit import SwanKitLogger
from .counter import AtomicCounter
from .type import LogHandler, LogType, WriteHandler, LogData, ProxyType
Expand All @@ -22,7 +23,8 @@ class SwanLog(SwanKitLogger):
继承自 SwanKitLogger 的同时增加标准输出、标准错误留拦截代理功能
"""

def __init__(self, name=__name__.lower(), level="info"):
def __init__(self, name=__name__.lower(), level: Union[str, None] = None):
level = (level or os.getenv(SwanLabEnv.LOG_LEVEL.value, 'info')).lower()
super().__init__(name=name, level=level)
self.__original_level = level
# 当前已经代理的输出行数
Expand Down
38 changes: 21 additions & 17 deletions swanlab/toolkit/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,16 @@
@description: swanlab 终端日志输出器,添加了一些好看的样式
"""

from typing import Literal, Union

from rich.console import Console
from rich.text import Text

__all__ = ['SwanKitLogger']

Levels = Union[Literal["debug", "info", "warning", "error", "critical"], str]


class SwanKitLogger:

def __init__(self, name=__name__.lower(), level: Levels = "info", file=None):
def __init__(self, name=__name__.lower(), level: str = "info", file=None):
self.console = Console(file=file)
self.__level: int = 0
self.__config = {
"debug": (
10,
Expand All @@ -43,20 +38,27 @@ def __init__(self, name=__name__.lower(), level: Levels = "info", file=None):
),
}
self.__can_log = True

self.level = level
# 默认日志等级为 info
self.__level: int = self.__config.get(level, self.__config["info"])[0]

@property
def level(self):
return self.__level
"""
获取当前日志等级
:return:
"""
for k, v in self.__config.items():
if v[0] == self.__level:
return k
raise AttributeError(f"level {self.__level} not found.")

@level.setter
def level(self, level: Levels):
def level(self, level: str):
"""
设置日志等级
:param level: 日志等级,可选值为 debug, info, warning, error, critical,如果传入的值不在可选值中,则默认为 info
"""
if level not in ("debug", "info", "warning", "error", "critical"):
if level not in self.__config:
_level = 20 # info
else:
_level = self.__config[level][0]
Expand All @@ -80,12 +82,14 @@ def __print(self, log_level: str, *args, **kwargs):
"""
if not self.__can_log:
return
level, prefix = self.__config[log_level]
if level < self.__level:
return
if kwargs.get("sep") == '':
prefix += " "
self.console.print(prefix, *args, **kwargs)
if log_level in self.__config:
level, prefix = self.__config[log_level]
if level < self.__level:
return
# 处理 sep 参数,即使设置了 sep='' ,前缀后也会有一个空格
if kwargs.get("sep") == '':
prefix += " "
self.console.print(prefix, *args, **kwargs)

# 发送调试消息
def debug(self, *args, **kwargs):
Expand Down
6 changes: 6 additions & 0 deletions test/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,15 @@ def setup_each():
os.mkdir(TEMP_PATH)
yield
import swanlab
from swanlab.log import swanlog

if swanlab.get_run() is not None:
swanlab.finish()
# 终端代理有可能没有释放
try:
swanlog.reset()
except RuntimeError:
pass
from swanlab.data.store import reset_run_store
from swanlab.data.porter import DataPorter

Expand Down
Loading