Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions swanlab/core_python/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ def response_interceptor(response: requests.Response, *args, **kwargs):
"""
捕获所有的http不为2xx的错误,以ApiError的形式抛出
"""
# 1. 日志打印
swanlog.debug(
f"HTTP Request: {response.request.method.upper()} {response.url} | "
f"Response Status: {response.status_code} | "
f"Body: {decode_response(response)}"
)
# 2. 如果状态码不为2xx,抛出异常
if response.status_code // 100 != 2:
traceid = f"Trace id: {response.headers.get('traceid')}"
request = f"{response.request.method.upper()} {response.url}"
Expand Down
20 changes: 19 additions & 1 deletion swanlab/core_python/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@

from swanlab.package import get_package_version

# 设置默认超时时间为 60s
DEFAULT_TIMEOUT = 60


# 创建一个自定义的 HTTPAdapter,用于注入默认超时
class TimeoutHTTPAdapter(HTTPAdapter):
    """HTTPAdapter that injects a default timeout into outgoing requests.

    The ``timeout`` keyword is consumed at construction time; any request
    sent through this adapter without an explicit ``timeout`` falls back
    to that default.
    """

    def __init__(self, *args, **kwargs):
        # Pop ``timeout`` before delegating — the base HTTPAdapter does not
        # accept it. Defaults to None (no implicit timeout).
        self.timeout = kwargs.pop("timeout", None)
        super().__init__(*args, **kwargs)

    def send(self, request, **kwargs):
        # Apply the default only when the caller did not pass a timeout key.
        if self.timeout is not None:
            kwargs.setdefault("timeout", self.timeout)
        return super().send(request, **kwargs)


def create_session() -> requests.Session:
"""
Expand All @@ -24,7 +42,7 @@ def create_session() -> requests.Session:
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=frozenset(["GET", "POST", "PUT", "DELETE", "PATCH"]),
)
adapter = HTTPAdapter(max_retries=retry)
adapter = TimeoutHTTPAdapter(max_retries=retry, timeout=DEFAULT_TIMEOUT)
session.mount("https://", adapter)
session.mount("http://", adapter)
session.headers["swanlab-sdk"] = get_package_version()
Expand Down
108 changes: 108 additions & 0 deletions swanlab/core_python/uploader/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
@author: cunyue
@file: batch.py
@time: 2025/12/4 13:50
@description: 分批上传函数
"""

import time
from typing import Union, Literal, List, TypedDict, Dict

from swanlab.core_python import get_client
from swanlab.log import swanlog


# Upload-metrics payload schema
class MetricDict(TypedDict):
    """Payload structure for a metrics upload request."""

    projectId: str
    experimentId: str
    type: str
    metrics: List[dict]
    flagId: Union[str, None]


def create_data(metrics: List[dict], metrics_type: str) -> MetricDict:
    """
    Wrap *metrics* with the experiment metadata required for upload.

    :param metrics: serialized metric dicts to send
    :param metrics_type: metric category tag for the payload's ``type`` field
    :return: payload dict ready for the metrics endpoint
    """
    client = get_client()
    # Experiments such as "Move" must report their data to the root
    # experiment/project, so root identifiers take precedence when present.
    experiment_id = client.exp.root_exp_cuid or client.exp.cuid
    project_id = client.exp.root_proj_cuid or client.proj.cuid
    assert project_id is not None, "Project ID is empty."
    assert experiment_id is not None, "Experiment ID is empty."
    payload: MetricDict = {
        "projectId": project_id,
        "experimentId": experiment_id,
        "type": metrics_type,
        "metrics": metrics,
        "flagId": client.exp.flag_id,
    }
    return payload


def _generate_chunks(data: Union[MetricDict, Dict, List], per_request_len: int):
"""
生成器:统一处理字典和列表的分片逻辑
yield: 分片后的数据块
"""
# 情况1: 不分批
if per_request_len == -1:
yield data
return

# 情况2: 字典分批
if isinstance(data, dict):
metrics = data.get('metrics', [])
# 如果 metrics 为空或长度不足,视为不需要分片的一整块
if len(metrics) <= per_request_len:
yield data
else:
for i in range(0, len(metrics), per_request_len):
yield {
**data,
"metrics": metrics[i : i + per_request_len],
}

# 情况3: 列表分批
elif isinstance(data, list):
if len(data) <= per_request_len:
yield data
else:
for i in range(0, len(data), per_request_len):
yield data[i : i + per_request_len]


def trace_metrics(
    url: str,
    data: Union['MetricDict', list] = None,
    method: Literal['post', 'put'] = 'post',
    per_request_len: int = 1000,
):
    """
    Upload metrics, splitting oversized payloads into multiple requests.

    :param url: upload endpoint
    :param data: metrics payload — a MetricDict or a plain list
    :param method: HTTP verb on the client, 'post' or 'put'
    :param per_request_len: maximum metrics per request; -1 disables splitting
    """
    # Decide whether split mode is active (this controls the inter-request
    # sleep): splitting only happens when enabled (!= -1) AND the payload is
    # actually larger than the per-request limit.
    is_split_mode = False
    if per_request_len != -1:
        total_len = len(data.get('metrics', [])) if isinstance(data, dict) else len(data or [])
        # Nothing to upload.
        if total_len == 0:
            return
        is_split_mode = total_len > per_request_len
    client = get_client()
    # Send every chunk produced by the generator.
    for chunk in _generate_chunks(data, per_request_len):
        # TODO: pre-flight check temporarily disabled
        # Abort remaining chunks if the client became pending mid-upload:
        # if client.pending:
        #     break

        # Dispatch via the client's post/put method.
        _, resp = getattr(client, method)(url, chunk)
        # Post-flight check: a 202 response means the server asked us to
        # stop sending, so mark the client as pending.
        if resp and resp.status_code == 202:
            client.pending = True
            swanlog.warning(f"Client set to pending due to 202 response: {url}")
        # Throttle between requests when splitting. (Idiom fix: explicit
        # ``if`` instead of the short-circuit ``is_split_mode and time.sleep(1)``,
        # which abused a boolean expression for its side effect.)
        if is_split_mode:
            time.sleep(1)
94 changes: 6 additions & 88 deletions swanlab/core_python/uploader/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,96 +5,15 @@
@description: 定义上传函数
"""

import time
from typing import List, Union, Literal, TypedDict
from typing import List

from swanlab.log import swanlog
from .batch import MetricDict, create_data, trace_metrics
from .model import ColumnModel, MediaModel, ScalarModel, FileModel, LogModel
from ..client import get_client, sync_error_handler, decode_response
from ...error import ApiError

house_url = '/house/metrics'


# 上传指标数据
class MetricDict(TypedDict):
projectId: str
experimentId: str
type: str
metrics: List[dict]
flagId: Union[str, None]


def create_data(metrics: List[dict], metrics_type: str) -> MetricDict:
"""
携带上传日志的指标信息
"""
client = get_client()
# Move 等实验需要将数据上传到根实验上
exp_id = client.exp.root_exp_cuid or client.exp.cuid
proj_id = client.exp.root_proj_cuid or client.proj.cuid
assert proj_id is not None, "Project ID is empty."
assert exp_id is not None, "Experiment ID is empty."
flag_id = client.exp.flag_id
return {
"projectId": proj_id,
"experimentId": exp_id,
"type": metrics_type,
"metrics": metrics,
"flagId": flag_id,
}


def trace_metrics(
url: str,
data: Union[MetricDict, list] = None,
method: Literal['post', 'put'] = 'post',
per_request_len: int = 1000,
):
"""
创建指标数据方法,如果 client 处于挂起状态,则不进行上传
:param url: 上传的URL地址
:param data: 上传的数据,可以是字典或列表
:param method: 请求方法,默认为 'post'
:param per_request_len: 每次请求的最大数据长度,如果设置为-1则不进行分批上传
"""
# TODO 用装饰器设置client的pending状态
client = get_client()
if client.pending:
return
if per_request_len == -1:
_, resp = getattr(client, method)(url, data)
if resp.status_code == 202:
client.pending = True
return
return
# 分批上传
if isinstance(data, dict):
need_split = len(data['metrics']) > per_request_len
# 1. 指标数据
for i in range(0, len(data['metrics']), per_request_len):
_, resp = getattr(client, method)(
url,
{
**data,
"metrics": data['metrics'][i : i + per_request_len],
},
)
if resp.status_code == 202:
client.pending = True
return
if need_split:
time.sleep(1)
else:
need_split = len(data) > per_request_len
# 2. 列表数据(列等)
for i in range(0, len(data), per_request_len):
_, resp = getattr(client, method)(url, data[i : i + per_request_len])
if resp.status_code == 202:
client.pending = True
return
if need_split:
time.sleep(1)
HOUSE_URL = '/house/metrics'


@sync_error_handler
Expand All @@ -109,7 +28,7 @@ def upload_logs(logs: List[LogModel]):
if len(metrics) == 0:
return swanlog.debug("No logs to upload.")
data = create_data(metrics, "log")
trace_metrics(house_url, data)
trace_metrics(HOUSE_URL, data)
return None


Expand All @@ -126,7 +45,7 @@ def upload_media_metrics(media_metrics: List[MediaModel]):
if not client.pending:
client.upload_files(buffers)
# 上传指标信息
trace_metrics(house_url, create_data([x.to_dict() for x in media_metrics], MediaModel.type.value))
trace_metrics(HOUSE_URL, create_data([x.to_dict() for x in media_metrics], MediaModel.type.value))


@sync_error_handler
Expand All @@ -135,7 +54,7 @@ def upload_scalar_metrics(scalar_metrics: List[ScalarModel]):
上传指标的标量数据
"""
data = create_data([x.to_dict() for x in scalar_metrics], ScalarModel.type.value)
trace_metrics(house_url, data)
trace_metrics(HOUSE_URL, data)


@sync_error_handler
Expand Down Expand Up @@ -189,7 +108,6 @@ def upload_columns(columns: List[ColumnModel]):


__all__ = [
"trace_metrics",
"MetricDict",
"upload_logs",
"upload_media_metrics",
Expand Down
12 changes: 7 additions & 5 deletions swanlab/data/modules/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,20 @@ def parse(self, **kwargs) -> Optional[ParseResult]:
buffers.append(r)
more.append(i.get_more())
result.strings = data
# 过滤掉空列表
result.buffers = self.__filter_list(buffers)
result.more = self.__filter_list(more)
# more 字段需要保留列表结构,即使全为 None,否则前端访问 more[0] 会报错
result.more = self.__filter_list(more, keep_none_list=True)
self.__result = result
return self.__result

@staticmethod
def __filter_list(li: List, keep_none_list: bool = False) -> Optional[List]:
    """
    Filter a list. By default, returns None when the list is empty or every
    element is None; otherwise returns the list itself.

    :param li: list to filter
    :param keep_none_list: keep a non-empty list even when all of its
        elements are None (needed so the frontend can index into ``more``)
    :return: the original list, or None when it is filtered out
    """
    if not li:
        return None
    if keep_none_list:
        return li
    return li if any(item is not None for item in li) else None

Expand Down
5 changes: 4 additions & 1 deletion swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ def init(
# 注册settings
merge_settings(settings)
user_settings = get_settings()
swanlog.level = kwargs.get("log_level", "info")
# 加载日志级别配置
log_level = kwargs.get("log_level")
if log_level is not None:
swanlog.level = log_level
# 获取本地文件夹配置,默认从当前工作目录下的swanlog文件夹中读取
folder_settings = read_folder_settings(get_swanlog_dir())
# ---------------------------------- 一些变量、格式检查 ----------------------------------
Expand Down
4 changes: 4 additions & 0 deletions swanlab/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class SwanLabEnv(enum.Enum):
swanlab环境变量枚举类,包含swankit的共享环境变量
"""

LOG_LEVEL = "SWANLAB_LOG_LEVEL"
"""
swanlab日志的输出级别,默认为info,可选值有debug、info、warning、error、critical
"""
SWANLAB_FOLDER = "SWANLAB_SAVE_DIR"
"""
swanlab全局文件夹保存的路径,默认为用户主目录下的.swanlab文件夹
Expand Down
8 changes: 5 additions & 3 deletions swanlab/log/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
@description: 标准输出、标准错误流拦截代理,支持外界设置/取消回调,基础作用为输出日志
"""

import os
import re
import sys
from typing import List, Tuple, Callable
from typing import List, Tuple, Callable, Union

from swanlab.env import create_time
from swanlab.env import create_time, SwanLabEnv
from swanlab.toolkit import SwanKitLogger
from .counter import AtomicCounter
from .type import LogHandler, LogType, WriteHandler, LogData, ProxyType
Expand All @@ -22,7 +23,8 @@ class SwanLog(SwanKitLogger):
继承自 SwanKitLogger 的同时增加标准输出、标准错误留拦截代理功能
"""

def __init__(self, name=__name__.lower(), level="info"):
def __init__(self, name=__name__.lower(), level: Union[str, None] = None):
level = (level or os.getenv(SwanLabEnv.LOG_LEVEL.value, 'info')).lower()
super().__init__(name=name, level=level)
self.__original_level = level
# 当前已经代理的输出行数
Expand Down
1 change: 1 addition & 0 deletions swanlab/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
"SlackCallback",
"LogdirFileWriter",
"BarkCallback",
"TelegramCallback",
]
Loading