Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions swanlab/core_python/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from swanlab.error import ApiError
from swanlab.log import swanlog
from swanlab.package import get_package_version
from .session import create_session
from .session import create_session, SessionWithRetry
from .utils import safe_request, ProjectInfo, ExperimentInfo
from .. import auth
from ..api.experiment import send_experiment_heartbeat
Expand Down Expand Up @@ -49,7 +49,7 @@ class Client:
def __init__(self, login_info: auth.LoginInfo):
self.__login_info = login_info
# 当前会话
self.__session: Optional[requests.Session] = None
self.__session: Optional[SessionWithRetry] = None
self.__version = get_package_version()
self.__create_session()

Expand Down Expand Up @@ -139,9 +139,9 @@ def __create_session(self):
创建会话,这将在HTTP类实例化时调用
添加了重试策略
"""
session = create_session()
session.headers["swanlab-sdk"] = self.__version
session.cookies.update({"sid": self.__login_info.sid})
swr = create_session()
swr.headers["swanlab-sdk"] = self.__version
swr.cookies.update({"sid": self.__login_info.sid})

# 注册响应钩子
def response_interceptor(response: requests.Response, *args, **kwargs):
Expand All @@ -161,44 +161,44 @@ def response_interceptor(response: requests.Response, *args, **kwargs):
resp = f"{response.status_code} {response.reason}"
raise ApiError(response, traceid, request, resp)

session.hooks["response"] = response_interceptor
swr.hooks["response"] = response_interceptor

self.__session = session
self.__session = swr

def post(self, url: str, data: Union[dict, list] = None):
def post(self, url: str, data: Union[dict, list] = None, retries: Optional[int] = None):
"""
post请求
"""
url = self.__login_info.api_host + url
self.__before_request()
resp = self.__session.post(url, json=data)
resp = self.__session.post(url, json=data, retries=retries)
return decode_response(resp), resp

def put(self, url: str, data: dict = None):
def put(self, url: str, data: dict = None, retries: Optional[int] = None):
"""
put请求
"""
url = self.__login_info.api_host + url
self.__before_request()
resp = self.__session.put(url, json=data)
resp = self.__session.put(url, json=data, retries=retries)
return decode_response(resp), resp

def get(self, url: str, params: dict = None):
def get(self, url: str, params: dict = None, retries: Optional[int] = None):
"""
get请求
"""
url = self.__login_info.api_host + url
self.__before_request()
resp = self.__session.get(url, params=params)
resp = self.__session.get(url, params=params, retries=retries)
return decode_response(resp), resp

def patch(self, url: str, data: dict = None):
def patch(self, url: str, data: dict = None, retries: Optional[int] = None):
"""
patch请求
"""
url = self.__login_info.api_host + url
self.__before_request()
resp = self.__session.patch(url, json=data)
resp = self.__session.patch(url, json=data, retries=retries)
return decode_response(resp), resp

# ---------------------------------- 训练相关接口 ----------------------------------
Expand Down
69 changes: 66 additions & 3 deletions swanlab/core_python/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,26 @@
@description: 创建会话
"""

import copy
from typing import Optional

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from swanlab.package import get_package_version

RETRY_HEADER = "X-Custom-Retry"
# 设置默认超时时间为 60s
DEFAULT_TIMEOUT = 60


# 创建一个自定义的 HTTPAdapter,用于注入默认超时
class TimeoutHTTPAdapter(HTTPAdapter):
"""
创建一个自定义的 HTTPAdapter,用于注入默认超时
并可以通过请求 headers 中的 {RETRY_HEADER} 指定并临时修改重试次数
"""

def __init__(self, *args, **kwargs):
# 从 kwargs 中取出默认超时时间,如果没有则设为 None
self.timeout = kwargs.pop("timeout", None)
Expand All @@ -27,15 +35,70 @@ def send(self, request, **kwargs):
if "timeout" not in kwargs and self.timeout is not None:
kwargs["timeout"] = self.timeout

# 检查 headers 中是否有用户注入的重试次数
_retry = request.headers.pop(RETRY_HEADER, None)
if _retry is not None:
_adapter = copy.copy(self)
try:
retry_count = int(_retry)
if retry_count < 0:
raise ValueError("Retry count must be a non-negative integer.")
_adapter.max_retries = self.max_retries.new(total=retry_count)
except ValueError:
raise ValueError(
f"Invalid retry count in {RETRY_HEADER}: '{_retry}'. Must be a non-negative integer."
) from None

return _adapter.send(request, **kwargs)

return super().send(request, **kwargs)


def create_session() -> requests.Session:
class SessionWithRetry(requests.Session):
"""
自定义会话,用于自定义会话重试次数
可以接受一个 retries 参数,并将其放在 headers 中的 {RETRY_HEADER} 字段中
"""

def request(self, method, url, *args, **kwargs):
retries = kwargs.pop('retries', None)

# 将用户指定的重试次数注入到 headers 中
if retries is not None:
kwargs.setdefault('headers', {})[RETRY_HEADER] = str(retries)

return super().request(method, url, *args, **kwargs)

# ---------------------------------- 重写方法的函数签名,避免IDE警告 ----------------------------------

def get(self, url, params=None, retries: Optional[int] = None, **kwargs):
return self.request("GET", url, params=params, retries=retries, **kwargs)

def options(self, url, retries: Optional[int] = None, **kwargs):
return self.request("OPTIONS", url, retries=retries, **kwargs)

def head(self, url, retries: Optional[int] = None, **kwargs):
return self.request("HEAD", url, retries=retries, **kwargs)

def post(self, url, data=None, json=None, retries: Optional[int] = None, **kwargs):
return self.request("POST", url, data=data, json=json, retries=retries, **kwargs)

def put(self, url, data=None, retries: Optional[int] = None, **kwargs):
return self.request("PUT", url, data=data, retries=retries, **kwargs)

def patch(self, url, data=None, retries: Optional[int] = None, **kwargs):
return self.request("PATCH", url, data=data, retries=retries, **kwargs)

def delete(self, url, retries: Optional[int] = None, **kwargs):
return self.request("DELETE", url, retries=retries, **kwargs)


def create_session() -> SessionWithRetry:
"""
创建一个带重试机制的会话
:return: requests.Session
"""
session = requests.Session()
session = SessionWithRetry()
retry = Retry(
total=5,
backoff_factor=0.5,
Expand Down
2 changes: 1 addition & 1 deletion swanlab/core_python/uploader/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def trace_metrics(
# break

# 调用被装饰的发送函数
_, resp = getattr(client, method)(url, chunk)
_, resp = getattr(client, method)(url, chunk, retries=0)
# 后置检查
if resp and resp.status_code == 202:
client.pending = True
Expand Down
31 changes: 31 additions & 0 deletions test/unit/core_python/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import requests
import responses
from requests.adapters import HTTPAdapter
from requests.exceptions import RetryError
from responses import registries
from urllib3.util.retry import Retry

Expand All @@ -34,6 +35,36 @@ def test_retry(url):
assert len(responses.calls) == 6


@pytest.mark.parametrize(
("url", "retries"), [("https://api.example.com/retry", 1), ("http://api.example.com/retry", 0)]
)
@responses.activate(registry=registries.OrderedRegistry)
def test_custom_retry(url, retries):
"""
测试自定义重试次数
"""

[responses.add(responses.POST, url, body="Error", status=500) for _ in range(2)]
s = create_session()
with pytest.raises(RetryError):
s.post(url, retries=retries)

assert len(responses.calls) == retries + 1


@pytest.mark.parametrize("url", ["https://api.example.com/retry"])
@responses.activate(registry=registries.OrderedRegistry)
def test_custom_retry_with_not_number(url):
"""
测试自定义重试次数
"""

[responses.add(responses.POST, url, body="Error", status=500) for _ in range(2)]
s = create_session()
with pytest.raises(ValueError):
s.post(url, retries="not-a-number")


@responses.activate(registry=registries.OrderedRegistry)
def test_session_headers():
"""
Expand Down