Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .idea/SwanLab.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions .idea/dictionaries/project.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion swanlab/core_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
# FIXME 存在循环引用,我们需要更优雅的代码结构
# from . import auth
# from . import uploader
from .client import Client, create_client, reset_client, get_client, create_session
from .client import *
from .utils import timer
22 changes: 19 additions & 3 deletions swanlab/core_python/api/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,29 @@
@description: 定义实验相关的后端API接口
"""

from typing import Literal
from typing import Literal, TYPE_CHECKING

from swanlab.core_python.client import Client
if TYPE_CHECKING:
from swanlab.core_python.client import Client


def send_experiment_heartbeat(
    client: "Client",
    *,
    cuid: str,
    flag_id: str,
):
    """
    Send a heartbeat for an experiment so that it stays marked as active.

    :param client: logged-in client instance used to issue the request
    :param cuid: unique identifier of the experiment
    :param flag_id: flag ID of the experiment, sent as ``flagId`` in the request body
    """
    endpoint = f"/house/experiments/{cuid}/heartbeat"
    payload = {"flagId": flag_id}
    client.post(endpoint, payload)


def update_experiment_state(
client: Client,
client: "Client",
*,
username: str,
projname: str,
Expand Down
59 changes: 22 additions & 37 deletions swanlab/core_python/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,15 @@
from typing import Optional, Tuple, Dict, Union, List, AnyStr

import requests
from urllib3.exceptions import (
MaxRetryError,
TimeoutError,
NewConnectionError,
ConnectionError,
ReadTimeoutError,
ConnectTimeoutError,
)

from swanlab.error import NetworkError, ApiError

from swanlab.error import ApiError
from swanlab.log import swanlog
from swanlab.package import get_package_version
from .model import ProjectInfo, ExperimentInfo
from .session import create_session
from .utils import safe_request, ProjectInfo, ExperimentInfo
from .. import auth
from ..api.experiment import send_experiment_heartbeat
from ..utils import timer
from ...env import utc_time


Expand Down Expand Up @@ -114,11 +108,11 @@ def expname(self):
return self.__exp.name

@property
def web_proj_url(self):
def web_proj_url(self) -> str:
return f"{self.__login_info.web_host}/@{self.groupname}/{self.projname}"

@property
def web_exp_url(self):
def web_exp_url(self) -> str:
return f"{self.web_proj_url}/runs/{self.exp_id}"

# ---------------------------------- http方法 ----------------------------------
Expand Down Expand Up @@ -374,41 +368,32 @@ def reset_client():
client = None


def safe_request(func):
def create_client_heartbeat(interval: int = 10 * 60):
"""
在一些接口中我们不希望线程奔溃,而是返回一个错误对象
创建客户端心跳定时器,保持实验处于活跃状态
:param interval: 心跳间隔,单位秒,默认10分钟
:return: 心跳定时器实例
"""
cl = get_client()

def wrapper(*args, **kwargs) -> Tuple[Optional[Union[dict, str]], Optional[Exception]]:
# TODO 目前保证向下兼容,如果报错也不提示用户,后续使用safe_request装饰器
# func = safe_request(func=send_experiment_heartbeat)
def func(c: Client, *, cuid: str, flag_id: str):
try:
# 在装饰器中调用被装饰的异步函数
result = func(*args, **kwargs)
return result, None
except requests.exceptions.Timeout:
return None, NetworkError()
except requests.exceptions.ConnectionError:
return None, NetworkError()
# Catch urllib3 specific errors
except (
MaxRetryError,
TimeoutError,
NewConnectionError,
ConnectionError,
ReadTimeoutError,
ConnectTimeoutError,
):
return None, NetworkError()
except Exception as e:
return None, e

return wrapper
send_experiment_heartbeat(c, cuid=cuid, flag_id=flag_id)
except ApiError as e:
swanlog.debug(f"Failed to send heartbeat: {e}")

task = lambda: func(cl, cuid=cl.exp.cuid, flag_id=cl.exp.flag_id)
return timer.Timer(task, interval=interval, immediate=True).run()


__all__ = [
"get_client",
"reset_client",
"create_session",
"create_client",
"create_client_heartbeat",
"safe_request",
"decode_response",
"Client",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,53 @@
"""
@author: cunyue
@file: model.py
@time: 2025/6/16 14:55
@description: 实验、项目元信息
@file: utils.py
@time: 2025/12/31 13:29
@description: 客户端工具函数
"""

from typing import Optional
from typing import Tuple, Optional, Union

import requests
from urllib3.exceptions import (
MaxRetryError,
TimeoutError,
NewConnectionError,
ConnectionError,
ReadTimeoutError,
ConnectTimeoutError,
)

from swanlab.error import NetworkError


def safe_request(func):
    """
    Decorator that converts exceptions raised by ``func`` into return values.

    Some call sites must not let the calling thread crash, so the wrapped
    function returns a ``(result, error)`` pair instead of raising.

    :param func: callable to wrap
    :return: a wrapper that returns ``(result, None)`` on success,
        ``(None, NetworkError())`` for the requests/urllib3 timeout and
        connection errors listed in the handler below, and ``(None, e)``
        for any other ``Exception`` ``e``
    """
    # Imported locally so this block does not depend on a module-level import.
    import functools

    # Preserve the wrapped function's __name__/__doc__ for logs and debugging.
    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> Tuple[Optional[Union[dict, str]], Optional[Exception]]:
        try:
            # Call the wrapped function synchronously and report success.
            return func(*args, **kwargs), None
        except (
            requests.exceptions.Timeout,
            requests.exceptions.ConnectionError,
            # urllib3-specific errors. NOTE: TimeoutError and ConnectionError
            # here are the names imported from urllib3.exceptions at the top
            # of this module, not the builtins of the same name.
            MaxRetryError,
            TimeoutError,
            NewConnectionError,
            ConnectionError,
            ReadTimeoutError,
            ConnectTimeoutError,
        ):
            return None, NetworkError()
        except Exception as e:
            # Deliberate catch-all: hand the error back to the caller
            # instead of letting it propagate.
            return None, e

    return wrapper


class ProjectInfo:
Expand Down
12 changes: 11 additions & 1 deletion swanlab/data/callbacker/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ..run import get_run
from ...core_python import *
from ...core_python.api.experiment import update_experiment_state
from ...core_python.utils.timer import Timer
from ...log.type import LogData


Expand All @@ -36,6 +37,7 @@ class CloudPyCallback(SwanLabRunCallback):
def __init__(self):
super().__init__()
self.executor = ThreadPoolExecutor(max_workers=1)
self.heartbeat: Optional[Timer] = None

def __str__(self):
return "SwanLabCloudPyCallback"
Expand All @@ -62,12 +64,15 @@ def _converter_summarise_metric():
pass

def on_init(self, *args, **kwargs):
_ = self._create_client()
self._create_client()
# 检测是否有最新的版本
U.check_latest_version()
# 挂载项目、实验
with Status("Creating experiment...", spinner="dots"):
with Mounter() as mounter:
mounter.execute()
# 创建客户端心跳
self.heartbeat = create_client_heartbeat()

def _terminal_handler(self, log_data: LogData):
self.porter.trace_log(log_data)
Expand Down Expand Up @@ -100,6 +105,11 @@ def on_metric_create(self, metric_info: MetricInfo, *args, **kwargs):
self.porter.trace_metric(metric_info)

def on_stop(self, error: str = None, *args, **kwargs):
# 删除心跳
if self.heartbeat:
self.heartbeat.cancel()
self.heartbeat.join()
# 删除终端代理和系统回调
success = get_run().success
# FIXME 等合并 swankit 以后优化一下 interrupt 的传递问题
interrupt = kwargs.get("interrupt", False)
Expand Down