Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions test/unit/core_python/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@
@description: $END$
"""

from unittest.mock import Mock, patch

import pytest
import requests
import responses
from requests.adapters import HTTPAdapter
from responses import registries
from urllib3.util.retry import Retry

from swanlab.core_python import create_session
from swanlab.core_python.session import TimeoutHTTPAdapter, DEFAULT_TIMEOUT
from swanlab.package import get_package_version


Expand Down Expand Up @@ -101,3 +107,162 @@ def request_callback(request):

# 验证合并而非覆盖(两个头都存在)
assert len(captured_headers) >= 2


# ========== TimeoutHTTPAdapter Tests ==========


def test_timeout_adapter_initialization_with_timeout():
"""
测试TimeoutHTTPAdapter初始化时正确设置timeout参数
"""
timeout_value = 30
adapter = TimeoutHTTPAdapter(timeout=timeout_value)

assert adapter.timeout == timeout_value


def test_timeout_adapter_initialization_without_timeout():
"""
测试TimeoutHTTPAdapter初始化时未提供timeout参数的情况
"""
adapter = TimeoutHTTPAdapter()

assert adapter.timeout is None


def test_timeout_adapter_initialization_with_other_params():
"""
测试TimeoutHTTPAdapter初始化时传递其他HTTPAdapter参数
"""
retry = Retry(total=3)
adapter = TimeoutHTTPAdapter(max_retries=retry, timeout=45)

assert adapter.timeout == 45
assert adapter.max_retries == retry


def test_timeout_adapter_uses_default_timeout():
"""
测试TimeoutHTTPAdapter在未显式指定timeout时使用默认timeout
"""
test_url = "https://api.example.com/timeout-test"

# 创建adapter并挂载到session
adapter = TimeoutHTTPAdapter(timeout=25)

session = requests.Session()
session.mount("https://", adapter)

# 创建一个PreparedRequest来测试send方法
req = requests.Request('GET', test_url)
prepared = session.prepare_request(req)

# 直接调用adapter的send方法,不传timeout
# 我们期望adapter会自动添加timeout=25
with patch.object(HTTPAdapter, 'send') as mock_parent_send:
mock_parent_send.return_value = Mock(status_code=200, text="OK")

adapter.send(prepared)

# 验证父类的send方法被调用时包含了timeout参数
call_kwargs = mock_parent_send.call_args[1]
assert 'timeout' in call_kwargs
assert call_kwargs['timeout'] == 25


def test_timeout_adapter_respects_explicit_timeout():
"""
测试TimeoutHTTPAdapter在显式指定timeout时覆盖默认timeout
"""
test_url = "https://api.example.com/explicit-timeout"

# 创建adapter,设置默认timeout为30
adapter = TimeoutHTTPAdapter(timeout=30)

session = requests.Session()
session.mount("https://", adapter)

# 创建一个PreparedRequest来测试send方法
req = requests.Request('GET', test_url)
prepared = session.prepare_request(req)

# 直接调用adapter的send方法,显式传timeout=10
with patch.object(HTTPAdapter, 'send') as mock_parent_send:
mock_parent_send.return_value = Mock(status_code=200, text="OK")

adapter.send(prepared, timeout=10)

# 验证父类的send方法被调用时使用了显式指定的timeout=10
call_kwargs = mock_parent_send.call_args[1]
assert 'timeout' in call_kwargs
assert call_kwargs['timeout'] == 10


def test_timeout_adapter_with_none_timeout():
"""
测试TimeoutHTTPAdapter当timeout为None时不注入超时
"""
test_url = "https://api.example.com/none-timeout"

# 创建adapter,timeout为None
adapter = TimeoutHTTPAdapter(timeout=None)

session = requests.Session()
session.mount("https://", adapter)

# 创建一个PreparedRequest来测试send方法
req = requests.Request('GET', test_url)
prepared = session.prepare_request(req)

# 直接调用adapter的send方法,不传timeout
with patch.object(HTTPAdapter, 'send') as mock_parent_send:
mock_parent_send.return_value = Mock(status_code=200, text="OK")

adapter.send(prepared)

# 验证父类的send方法被调用时不应包含timeout参数
call_kwargs = mock_parent_send.call_args[1]
assert 'timeout' not in call_kwargs


@responses.activate
def test_create_session_uses_default_timeout():
"""
测试create_session创建的会话使用DEFAULT_TIMEOUT
"""
test_url = "https://api.example.com/session-timeout"

responses.add(responses.GET, test_url, body="OK", status=200)

session = create_session()

# 获取adapter
adapter = session.get_adapter(test_url)
assert isinstance(adapter, TimeoutHTTPAdapter)
assert adapter.timeout == DEFAULT_TIMEOUT

# 创建一个PreparedRequest来测试send方法
req = requests.Request('GET', test_url)
prepared = session.prepare_request(req)

# 验证实际请求使用该timeout
with patch.object(HTTPAdapter, 'send') as mock_parent_send:
mock_parent_send.return_value = Mock(status_code=200, text="OK")

adapter.send(prepared)

# 验证使用了DEFAULT_TIMEOUT
call_kwargs = mock_parent_send.call_args[1]
assert 'timeout' in call_kwargs
assert call_kwargs['timeout'] == DEFAULT_TIMEOUT


def test_timeout_adapter_inherits_from_httpAdapter():
"""
测试TimeoutHTTPAdapter正确继承自HTTPAdapter
"""
adapter = TimeoutHTTPAdapter(timeout=20)

assert isinstance(adapter, HTTPAdapter)
assert isinstance(adapter, TimeoutHTTPAdapter)
16 changes: 9 additions & 7 deletions tutils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
# ---------------------------------- 检查swanboard、swankit包的版本号与当前系统是否一致 ----------------------------------

swanboard = subprocess.run("pip show swanboard", shell=True, capture_output=True).stdout.decode()
swanboard_version = [i.split(": ")[1] for i in swanboard.split("\n") if i.startswith("Version")][0].split("\r")[0]
with open(os.path.join(swanlab_dir, "requirements.txt"), "r") as f:
packages = f.read().split("\n")
packages = [x for x in packages if "swanboard" in x]
for i in packages:
if "swanboard" in i and swanboard_version not in i:
raise Exception(f"swanboard过时,运行 pip install -r requirements.txt 进行更新.")
swanboard_versions = [i.split(':', 1)[1].strip() for i in swanboard.splitlines() if i.startswith('Version:')]
if swanboard_versions:
swanboard_version = swanboard_versions[0]
with open(os.path.join(swanlab_dir, "requirements.txt"), "r") as f:
packages = f.read().split("\n")
packages = [x for x in packages if "swanboard" in x]
for i in packages:
if "swanboard" in i and swanboard_version not in i:
raise Exception(f"swanboard过时,运行 pip install -r requirements.txt 进行更新.")

# ---------------------------------- 检查是否跳过云测试 ----------------------------------
load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env"))
Expand Down