Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 34 additions & 38 deletions swanlab/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def __init__(self, api_key: Optional[str] = None, host: Optional[str] = None, we
self._web_host = self._login_info.web_host
self._login_user = self._login_info.username

@self_hosted("root")
def users(self) -> Users:
"""
超级管理员获取所有用户
:return: User 实例,可对当前/指定用户进行操作
"""
return Users(self._client, login_user=self._login_user)

def user(self, username: str = None) -> User:
"""
获取用户实例,用于操作用户相关信息
Expand All @@ -60,18 +68,34 @@ def user(self, username: str = None) -> User:
"""
return User(client=self._client, login_user=self._login_user, username=username)

@self_hosted("root")
def users(self) -> Users:
def workspaces(
self,
username: str = None,
):
"""
超级管理员获取所有用户
:return: User 实例,可对当前/指定用户进行操作
获取当前登录用户的工作空间迭代器
当username为其他用户时,可以作为visitor访问其工作空间
"""
return Users(self._client, login_user=self._login_user)
if username is None:
username = self._login_user
return Workspaces(self._client, username=username)

def workspace(
self,
username: str = None,
):
"""
获取当前登录用户的工作空间
"""
if username is None:
username = self._login_user
data = get_workspace_info(self._client, path=username)
return Workspace(self._client, data=data, web_host=self._web_host, login_info=self._login_info)

def projects(
self,
path: str,
sort: Optional[List[str]] = None,
sort: Optional[str] = None,
search: Optional[str] = None,
detail: Optional[bool] = True,
) -> Projects:
Expand Down Expand Up @@ -102,7 +126,7 @@ def project(
:return: Project 实例,单个项目的信息
"""
data = get_project_info(self._client, path=path)
return Project(self._client, web_host=self._web_host, data=data)
return Project(self._client, web_host=self._web_host, data=data, login_info=self._login_info)

def runs(self, path: str, filters: Dict[str, object] = None) -> Experiments:
"""
Expand All @@ -122,46 +146,18 @@ def run(
:param path: 实验路径,格式为 'username/project/run_id'
:return: Experiment 实例,包含实验信息
"""
# TODO: 待后端完善后替换成专用的接口
if len(path.split('/')) != 3:
raise ValueError(f"User's {path} is invaded. Correct path should be like 'username/project/run_id'")
_data = get_single_experiment(self._client, path=path)
data = get_single_experiment(self._client, path=path)
proj_path = path.rsplit('/', 1)[0]
data = get_project_experiments(
self._client, path=proj_path, filters={'name': _data['name'], 'created_at': _data['createdAt']}
)
return Experiment(
self._client,
data=data[0],
data=data,
path=proj_path,
web_host=self._web_host,
login_user=self._login_user,
line_count=1,
)

def workspaces(
self,
username: str = None,
):
"""
获取当前登录用户的工作空间迭代器
当username为其他用户时,可以作为visitor访问其工作空间
"""
if username is None:
username = self._login_user
return Workspaces(self._client, username=username)

def workspace(
self,
username: str = None,
):
"""
获取当前登录用户的工作空间
"""
if username is None:
username = self._login_user
data = get_workspace_info(self._client, path=username)
return Workspace(self._client, data=data)


__all__ = ["Api", "OpenApi"]
__all__ = ["Api", "OpenApi"]
130 changes: 86 additions & 44 deletions swanlab/api/experiment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,47 @@
from swanlab.log import swanlog


class Profile:
"""Experiment profile containing config, metadata, requirements, and conda info."""

def __init__(self, data: Dict):
self._data = data

@staticmethod
def _clean_field(value: Any) -> Any:
"""Recursively clean config field, removing desc/sort and keeping value."""
if isinstance(value, dict):
if 'value' in value:
# Standard format: {'desc': ..., 'sort': ..., 'value': ...}
return Profile._clean_field(value['value'])
else:
# Nested dict without standard format, clean recursively
return {k: Profile._clean_field(v) for k, v in value.items()}
elif isinstance(value, list):
return [Profile._clean_field(item) for item in value]
return value

@property
def config(self) -> Dict:
"""Experiment configuration (cleaned, without desc/sort fields)."""
raw_config = self._data.get('config', {})
return {k: Profile._clean_field(v) for k, v in raw_config.items()}

@property
def metadata(self) -> Dict:
"""Experiment metadata."""
return self._data.get('metadata', {})

@property
def requirements(self) -> str:
"""Python requirements."""
return self._data.get('requirements', '')

@property
def conda(self) -> str:
"""Conda environment."""
return self._data.get('conda', '')


class Experiment:
def __init__(
Expand All @@ -31,14 +72,21 @@ def name(self) -> str:
"""
Experiment name.
"""
return self._data['name']
return self._data.get('name', '')

@property
def id(self) -> str:
"""
Experiment CUID.
"""
return self._data['cuid']
return self._data.get('cuid', '')

@property
def path(self) -> str:
"""
Experiment path in format 'username/project/id'.
"""
return f"{self._path}/{self.id}"

@property
def url(self) -> str:
Expand All @@ -52,83 +100,71 @@ def created_at(self) -> str:
"""
Experiment creation timestamp
"""
return self._data['createdAt']
return self._data.get('createdAt', '')

@property
def description(self) -> str:
def finished_at(self) -> str:
"""
Experiment description.
Experiment finished timestamp
"""
return self._data['description']
return self._data.get('finishedAt', '')

@property
def labels(self) -> List[Label]:
def profile(self) -> Profile:
"""
List of Label attached to this experiment.
Experiment profile containing config, metadata, requirements, and conda.
"""
return [Label(label['name']) for label in self._data['labels']]
return Profile(self._data.get('profile', {}))

@property
def state(self) -> str:
def show(self) -> bool:
"""
Experiment state.
Whether the experiment is visible.
"""
return self._data['state']
return self._data.get('show', True)

@property
def group(self) -> str:
"""
Experiment group.
"""
return self._data['cluster']

@property
def job(self) -> str:
def description(self) -> str:
"""
Experiment job type.
Experiment description.
"""
return self._data['job']
return self._data.get('description', '')

@property
def user(self) -> User:
def labels(self) -> List[Label]:
"""
Experiment user.
List of Label attached to this experiment.
"""
return User(client=self._client, login_user=self._login_user, username=self._data['user']['username'])
return [Label(label['name']) for label in self._data.get('labels', [])]

@property
def metric_keys(self) -> List[str]:
def state(self) -> str:
"""
List of metric keys.
Experiment state.
"""
return list(self.summary.keys())
return self._data.get('state', '')

@property
def history_line_count(self) -> int:
def group(self) -> str:
"""
The number of historical experiments in this project.
Experiment group.
"""
return self._line_count
return self._data.get('cluster', '')

@property
def root_exp_id(self) -> str:
def job_type(self) -> str:
"""
Root experiment cuid. If the experiment is a root experiment, it will be None.
Experiment job type.
"""
return self._data['rootExpId']
return self._data.get('job', '')

@property
def root_pro_id(self) -> str:
def user(self) -> User:
"""
Root project cuid. If the experiment is a root experiment, it will be None.
Experiment user.
"""
return self._data['rootProId']

def json(self):
"""
JSON-serializable dict of all @property values.
"""
return get_properties(self)
username = self._data.get('user', {}).get('username', '')
return User(client=self._client, login_user=self._login_user, username=username)

def metrics(self, keys: List[str] = None, x_axis: str = None, sample: int = None, pandas: bool = True) -> Any:
"""
Expand Down Expand Up @@ -217,6 +253,12 @@ def strip_suffix(col, suffix="_step"):
result_df = result_df.head(sample)

return result_df

def json(self):
"""
JSON-serializable dict of all @property values.
"""
return get_properties(self)


__all__ = ['Experiment']
__all__ = ['Experiment', 'Profile']
23 changes: 19 additions & 4 deletions swanlab/api/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
"""

from functools import cached_property
from typing import List, Dict
from typing import List, Dict, Optional

from swanlab.api.experiments import Experiments
from swanlab.api.utils import Label, get_properties
from swanlab.api.workspace import Workspace
from swanlab.core_python.api.type import ProjectType
from swanlab.core_python.api.user import get_workspace_info
from swanlab.core_python.auth.providers.api_key import LoginInfo
from swanlab.core_python.client import Client


Expand All @@ -20,10 +22,13 @@ class Project:
Representing a single project with some of its properties.
"""

def __init__(self, client: Client, *, web_host: str, data: ProjectType) -> None:
def __init__(
self, client: Client, *, web_host: str, data: ProjectType, login_info: Optional[LoginInfo] = None
) -> None:
self._client = client
self._web_host = web_host
self._data = data
self._login_info = login_info

@property
def name(self) -> str:
Expand Down Expand Up @@ -80,7 +85,7 @@ def workspace(self) -> Workspace:
Project workspace object.
"""
data = get_workspace_info(self._client, path=self._data["group"]["username"])
return Workspace(self._client, data=data)
return Workspace(self._client, data=data, web_host=self._web_host, login_info=self._login_info)

@property
def labels(self) -> List[Label]:
Expand All @@ -97,8 +102,18 @@ def count(self) -> Dict[str, int]:
"""
return self._data['_count']

def runs(self, filters: Optional[Dict[str, object]] = None) -> Experiments:
"""
Get all experiments in this project.
:param filters: Filter conditions for experiments, optional
:return: Experiments instance, iterable to get experiment information
"""
if self._login_info is None:
raise RuntimeError("login_info is required to access runs. Use api.project() instead of creating Project directly.")
return Experiments(self._client, path=self._data['path'], login_info=self._login_info, filters=filters)

def json(self):
"""
JSON-serializable dict of all @property values.
"""
return get_properties(self)
return get_properties(self)
2 changes: 1 addition & 1 deletion swanlab/api/projects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
*,
web_host: str,
path: str,
sort: Optional[List[str]] = None,
sort: Optional[str] = None,
search: Optional[str] = None,
detail: Optional[bool] = True,
) -> None:
Expand Down
Loading