Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions swanlab/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from .experiments import Experiments
from .projects import Projects
from .user import User
from .users import Users
from .utils import self_hosted
from .workspace import Workspace
from .workspaces import Workspaces

Expand Down Expand Up @@ -55,6 +57,14 @@ def user(self, username: str = None) -> User:
"""
return User(client=self._client, login_user=self._login_user, username=username)

@self_hosted("root")
def users(self) -> Users:
"""
超级管理员获取所有用户
:return: User 实例,可对当前/指定用户进行操作
"""
return Users(self._client, login_user=self._login_user)

def projects(
self,
workspace: str,
Expand Down
2 changes: 1 addition & 1 deletion swanlab/api/user/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
create_api_key,
get_latest_api_key,
delete_api_key,
create_user,
)
from swanlab.core_python.api.user.self_hosted import create_user
from swanlab.core_python.client import Client


Expand Down
38 changes: 38 additions & 0 deletions swanlab/api/users/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
@author: cunyue
@file: __init__.py
@time: 2026/2/3 13:30
@description: OpenApi 中的用户对象迭代器
"""

from typing import Iterator

from swanlab.api.user import User
from swanlab.core_python import Client
from swanlab.core_python.api.user import get_users


class Users:
"""
Container for a collection of User objects.
You can iterate over the users by for-in loop.
"""

def __init__(self, client: Client, *, login_user: str) -> None:
self._client = client
self._login_user = login_user

def __iter__(self) -> Iterator[User]:
cur_page = 0
while True:
cur_page += 1
resp = get_users(
self._client,
page=cur_page,
size=20,
)
for u in resp['list']:
yield User(self._client, login_user=self._login_user, username=u['username'])

if cur_page >= resp['pages']:
break
7 changes: 5 additions & 2 deletions swanlab/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,11 @@ def wrapper(self, *args, **kwargs):
raise ValueError("SwanLab self-hosted instance has expired.")

# 2. 检测用户权限(商业版root用户功能)
if identity == "root" and not self_hosted_info.get("root", False):
raise ValueError("You don't have permission to perform this action. Please login as a root user")
if identity == 'root':
if not self_hosted_info.get('root', False):
raise ValueError("You don't have permission to perform this action. Please login as a root user")
if not getattr(self, 'is_self', True):
raise ValueError('This root-only action can only be performed by the logged-in root user.')

return func(self, *args, **kwargs)

Expand Down
3 changes: 2 additions & 1 deletion swanlab/core_python/api/user/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import TYPE_CHECKING, List

from swanlab.core_python.api.type import GroupType, ApiKeyType, WorkspaceType
from .self_hosted import get_self_hosted_init, create_user
from .self_hosted import get_self_hosted_init, create_user, get_users

if TYPE_CHECKING:
from swanlab.core_python.client import Client
Expand Down Expand Up @@ -79,4 +79,5 @@ def get_latest_api_key(client: "Client") -> ApiKeyType:
"get_latest_api_key",
"get_self_hosted_init",
"create_user",
"get_users",
]
14 changes: 13 additions & 1 deletion swanlab/core_python/api/user/self_hosted.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,22 @@ def get_self_hosted_init(client: "Client") -> SelfHostedInfoType:

def create_user(client: "Client", *, username: str, password: str) -> None:
"""
根用户添加用户
添加用户(私有化管理员限定)
:param client: 已登录的客户端实例
:param username: 用户名
:param password: 用户密码
"""
data = {"users": [{"username": username, "password": password}]}
client.post("/self_hosted/users", data=data)


def get_users(client: "Client", *, page: int = 1, size: int = 20):
"""
分页获取用户(管理员限定)
:param client: 已登录的客户端实例
:param page: 页码
:param size: 每页大小
"""
params = {"page": page, "size": size}
res = client.get("/self_hosted/users", params=params)
return res[0]
23 changes: 23 additions & 0 deletions test/unit/api/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import pytest

from swanlab.api.user import User
from swanlab.api.users import Users
from swanlab.core_python import Client
from swanlab.error import ApiError
from utils import create_user_data


def create_user(username=None):
Expand Down Expand Up @@ -84,3 +86,24 @@ def test_other_user():
assert other_user.generate_api_key() is None
with pytest.raises(ValueError):
assert other_user.delete_api_key(api_key='test_api_key') == False


def test_users():
"""测试能否分页获取所有用户"""
with patch('swanlab.api.users.get_users') as mock_get_users:
total = 80
page_size = 20

def side_effect(*args, **kwargs):
return create_user_data(page=kwargs.get("page", 1), total=total)

mock_get_users.side_effect = side_effect
client = MagicMock(spec=Client)
users = Users(client, login_user="test_user")

user_list = list(users)
assert len(user_list) == total
for i, user in enumerate(user_list):
assert user.username == f'user_{i}'

assert mock_get_users.call_count == (total + page_size - 1) // page_size
24 changes: 24 additions & 0 deletions test/unit/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,30 @@ def create_project_data(page: int = 1, total: int = 20) -> ProjResponseType:
}


def create_user_data(page: int = 1, total: int = 20) -> Dict:
"""
创建分页用户数据(用于模拟 get_users 的返回值)

:param page: 当前页数
:param total: 用户总数
:return: 包含 list, pages, total 等字段的字典
"""
page_size = 20
pages = (total + page_size - 1) // page_size
user_list = []

for j in range(page_size):
user_list.append({
'username': f'user_{ (page - 1) * page_size + j }'
})

return {
'list': user_list,
'pages': pages,
'total': total,
}


def create_csv_data(step_values, metric_name, metric_values):
"""
创建 CSV 格式的数据
Expand Down