Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions swanlab/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
from .deprecated import OpenApi
from .experiment import Experiment
from .experiments import Experiments
from .project import Project
from .projects import Projects
from .user import User
from .workspace import Workspace
from .workspaces import Workspaces
from ..core_python.api.project import get_project_info
from ..core_python.api.user import get_workspace_info


class Api:
Expand Down Expand Up @@ -57,14 +60,14 @@ def user(self, username: str = None) -> User:

def projects(
self,
workspace: str,
path: str,
sort: Optional[List[str]] = None,
search: Optional[str] = None,
detail: Optional[bool] = True,
) -> Projects:
"""
获取指定工作空间(组织)下的所有项目信息
:param workspace: 工作空间(组织)名称
:param path: 工作空间(组织)名称 'username'
:param sort: 排序方式,可选
:param search: 搜索关键词,可选
:param detail: 是否返回详细信息,可选
Expand All @@ -73,12 +76,24 @@ def projects(
return Projects(
self._client,
web_host=self._web_host,
workspace=workspace,
path=path,
sort=sort,
search=search,
detail=detail,
)

def project(
self,
path: str,
) -> Project:
"""
获取指定工作空间(组织)下的指定项目信息
:param path: 项目路径 'username/project'
:return: Project 实例,单个项目的信息
"""
data = get_project_info(self._client, path=path)
return Project(self._client, web_host=self._web_host, data=data)

def runs(self, path: str, filters: Dict[str, object] = None) -> Experiments:
"""
获取指定项目下的所有实验信息
Expand Down Expand Up @@ -135,7 +150,8 @@ def workspace(
"""
if username is None:
username = self._login_user
return Workspace(client=self._client, workspace=username)
data = get_workspace_info(self._client, path=username)
return Workspace(self._client, data=data)


__all__ = ["Api", "OpenApi"]
8 changes: 5 additions & 3 deletions swanlab/api/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from swanlab.api.utils import Label, get_properties
from swanlab.api.workspace import Workspace
from swanlab.core_python.api.type import ProjectType
from swanlab.core_python.api.user import get_workspace_info
from swanlab.core_python.client import Client


Expand All @@ -19,10 +20,10 @@ class Project:
Representing a single project with some of its properties.
"""

def __init__(self, client: Client, *, data: ProjectType, web_host: str) -> None:
def __init__(self, client: Client, *, web_host: str, data: ProjectType) -> None:
self._client = client
self._data = data
self._web_host = web_host
self._data = data

@property
def name(self) -> str:
Expand Down Expand Up @@ -78,7 +79,8 @@ def workspace(self) -> Workspace:
"""
Project workspace object.
"""
return Workspace(client=self._client, workspace=self._data["group"]["username"])
data = get_workspace_info(self._client, path=self._data["group"]["username"])
return Workspace(self._client, data=data)

@property
def labels(self) -> List[Label]:
Expand Down
8 changes: 4 additions & 4 deletions swanlab/api/projects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ def __init__(
client: Client,
*,
web_host: str,
workspace: str,
path: str,
sort: Optional[List[str]] = None,
search: Optional[str] = None,
detail: Optional[bool] = True,
) -> None:
self._client = client
self._web_host = web_host
self._workspace = workspace
self._path = path
self._sort = sort
self._search = search
self._detail = detail
Expand All @@ -43,15 +43,15 @@ def __iter__(self) -> Iterator[Project]:
cur_page += 1
resp: ProjResponseType = get_workspace_projects(
self._client,
workspace=self._workspace,
path=self._path,
page=cur_page,
size=20,
sort=self._sort,
search=self._search,
detail=self._detail,
)
for p in resp['list']:
yield Project(self._client, data=p, web_host=self._web_host)
yield Project(self._client, web_host=self._web_host, data=p)

if cur_page >= resp['pages']:
break
Expand Down
8 changes: 1 addition & 7 deletions swanlab/api/workspace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,11 @@
from swanlab.api.utils import get_properties
from swanlab.core_python import Client
from swanlab.core_python.api.type import WorkspaceType, RoleType
from swanlab.core_python.api.user import get_workspace_info


class Workspace:
def __init__(self, *, data: WorkspaceType = None, client: Client = None, workspace: str = None) -> None:
def __init__(self, client: Client, *, data: WorkspaceType):
self._client = client

if data is None:
if workspace is None or client is None:
raise ValueError('workspace or client cannot both None')
data = get_workspace_info(self._client, workspace=workspace)
self._data = data

@property
Expand Down
4 changes: 2 additions & 2 deletions swanlab/api/workspaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def get_all_workspaces(self, username: str = None):

def __iter__(self) -> Iterator[Workspace]:
for space in self.get_all_workspaces():
data = get_workspace_info(self._client, workspace=space)
yield Workspace(data=data)
data = get_workspace_info(self._client, path=space)
yield Workspace(self._client, data=data)


__all__ = ['Workspaces']
20 changes: 15 additions & 5 deletions swanlab/core_python/api/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from typing import Optional, List, TYPE_CHECKING

from swanlab.core_python.api.type import ProjResponseType
from swanlab.core_python.api.type import ProjResponseType, ProjectType

if TYPE_CHECKING:
from swanlab.core_python.client import Client
Expand All @@ -16,7 +16,7 @@
def get_workspace_projects(
client: "Client",
*,
workspace: str,
path: str,
page: int = 1,
size: int = 20,
sort: Optional[List[str]] = None,
Expand All @@ -26,7 +26,7 @@ def get_workspace_projects(
"""
获取指定页数和条件下的项目信息
:param client: 已登录的客户端实例
:param workspace: 工作空间名称
:param path: 工作空间名称
:param page: 页码
:param size: 每页项目数量
:param sort: 排序规则, 可选
Expand All @@ -40,8 +40,18 @@ def get_workspace_projects(
'search': search,
'detail': detail,
}
res = client.get(f"/project/{workspace}", params=dict(params))
res = client.get(f"/project/{path}", params=dict(params))
return res[0]


__all__ = ["get_workspace_projects"]
def get_project_info(client: "Client", *, path: str) -> ProjectType:
"""
获取指定路径的项目信息
:param client: 已登录的客户端实例
:param path: 项目路径 'username/project'
"""
res = client.get(f"/project/{path}")
return res[0]


__all__ = ["get_workspace_projects", "get_project_info"]
6 changes: 3 additions & 3 deletions swanlab/core_python/api/user/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def get_user_groups(client: "Client", *, username: str) -> List[GroupType]:
return res[0]


def get_workspace_info(client: "Client", *, workspace: str) -> WorkspaceType:
def get_workspace_info(client: "Client", *, path: str) -> WorkspaceType:
"""
获取指定工作空间的信息
:param client: 已登录的客户端实例
:param workspace: 工作空间名称
:param path: 工作空间名称
"""
res = client.get(f"/group/{workspace}")
res = client.get(f"/group/{path}")
return res[0]


Expand Down
2 changes: 1 addition & 1 deletion test/unit/api/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def side_effect(*args, **kwargs):
mock_projects = Projects(
MagicMock(spec=Client),
web_host=get_host_web(),
workspace='test_user',
path='test_user',
)
projects = list(mock_projects)
assert len(projects) == total
Expand Down