Skip to content

Commit

Permalink
Merge branch 'rs/org-quota-support' into rs/for-alpha-release
Browse files Browse the repository at this point in the history
  • Loading branch information
romasku committed Nov 16, 2021
2 parents 169adf2 + 63ad7eb commit 2dd7f4c
Show file tree
Hide file tree
Showing 4 changed files with 314 additions and 11 deletions.
116 changes: 114 additions & 2 deletions neuro_admin_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,16 +735,28 @@ def _parse_org_cluster(
return OrgCluster(
cluster_name=cluster_name,
org_name=payload["org_name"],
balance=self._parse_balance(payload.get("balance")),
quota=self._parse_quota(payload.get("quota")),
)

async def create_org_cluster(
self,
cluster_name: str,
org_name: str,
quota: Quota = Quota(),
balance: Balance = Balance(),
) -> OrgCluster:
payload = {
payload: Dict[str, Any] = {
"org_name": org_name,
"quota": {},
"balance": {},
}
if quota.total_running_jobs is not None:
payload["quota"]["total_running_jobs"] = quota.total_running_jobs
if balance.credits is not None:
payload["balance"]["credits"] = str(balance.credits)
if balance.spent_credits is not None:
payload["balance"]["spent_credits"] = str(balance.spent_credits)
async with self._request(
"POST",
f"clusters/{cluster_name}/orgs",
Expand Down Expand Up @@ -780,9 +792,19 @@ async def get_org_cluster(
return self._parse_org_cluster(cluster_name, raw_data)

async def update_org_cluster(self, org_cluster: OrgCluster) -> OrgCluster:
payload = {
payload: Dict[str, Any] = {
"org_name": org_cluster.org_name,
"quota": {},
"balance": {},
}
if org_cluster.quota.total_running_jobs is not None:
payload["quota"][
"total_running_jobs"
] = org_cluster.quota.total_running_jobs
if org_cluster.balance.credits is not None:
payload["balance"]["credits"] = str(org_cluster.balance.credits)
if org_cluster.balance.spent_credits is not None:
payload["balance"]["spent_credits"] = str(org_cluster.balance.spent_credits)
async with self._request(
"PUT",
f"clusters/{org_cluster.cluster_name}/orgs/{org_cluster.org_name}",
Expand All @@ -803,6 +825,96 @@ async def delete_org_cluster(
) as resp:
resp.raise_for_status()

async def update_org_cluster_quota(
self,
cluster_name: str,
org_name: str,
quota: Quota,
*,
idempotency_key: Optional[str] = None,
) -> OrgCluster:
payload = {"quota": {"total_running_jobs": quota.total_running_jobs}}
params = {}
if idempotency_key:
params["idempotency_key"] = idempotency_key
async with self._request(
"PATCH",
f"clusters/{cluster_name}/orgs/{org_name}/quota",
json=payload,
params=params,
) as resp:
resp.raise_for_status()
raw_org_cluster = await resp.json()
return self._parse_org_cluster(cluster_name, raw_org_cluster)

async def update_org_cluster_quota_by_delta(
self,
cluster_name: str,
org_name: str,
delta: Quota,
*,
idempotency_key: Optional[str] = None,
) -> OrgCluster:
payload = {"additional_quota": {"total_running_jobs": delta.total_running_jobs}}
params = {}
if idempotency_key:
params["idempotency_key"] = idempotency_key
async with self._request(
"PATCH",
f"clusters/{cluster_name}/orgs/{org_name}/quota",
json=payload,
params=params,
) as resp:
resp.raise_for_status()
raw_org_cluster = await resp.json()
return self._parse_org_cluster(cluster_name, raw_org_cluster)

async def update_org_cluster_balance(
self,
cluster_name: str,
org_name: str,
credits: Optional[Decimal],
*,
idempotency_key: Optional[str] = None,
) -> OrgCluster:
payload = {
"credits": str(credits) if credits else None,
}
params = {}
if idempotency_key:
params["idempotency_key"] = idempotency_key
async with self._request(
"PATCH",
f"clusters/{cluster_name}/orgs/{org_name}/balance",
json=payload,
params=params,
) as resp:
resp.raise_for_status()
raw_org_cluster = await resp.json()
return self._parse_org_cluster(cluster_name, raw_org_cluster)

async def update_org_cluster_balance_by_delta(
self,
cluster_name: str,
org_name: str,
delta: Decimal,
*,
idempotency_key: Optional[str] = None,
) -> OrgCluster:
payload = {"additional_credits": str(delta)}
params = {}
if idempotency_key:
params["idempotency_key"] = idempotency_key
async with self._request(
"PATCH",
f"clusters/{cluster_name}/orgs/{org_name}/balance",
json=payload,
params=params,
) as resp:
resp.raise_for_status()
raw_org_cluster = await resp.json()
return self._parse_org_cluster(cluster_name, raw_org_cluster)

def _parse_org_payload(self, payload: Dict[str, Any]) -> Org:
return Org(
name=payload["name"],
Expand Down
14 changes: 8 additions & 6 deletions neuro_admin_client/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,6 @@ class OrgUserWithInfo(OrgUser):
user_info: UserInfo


@dataclass(frozen=True)
class OrgCluster:
org_name: str
cluster_name: str


@dataclass(frozen=True)
class Balance:
credits: Optional[Decimal] = None
Expand All @@ -80,6 +74,14 @@ class Quota:
total_running_jobs: Optional[int] = None


@dataclass(frozen=True)
class OrgCluster:
org_name: str
cluster_name: str
balance: Balance
quota: Quota


@unique
class ClusterUserRoleType(str, Enum):
ADMIN = "admin"
Expand Down
108 changes: 107 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,18 +527,38 @@ async def handle_org_user_list(
return aiohttp.web.json_response(resp)

def _serialize_org_cluster(self, org_cluster: OrgCluster) -> Dict[str, Any]:
return {
res: Dict[str, Any] = {
"org_name": org_cluster.org_name,
"quota": {},
"balance": {
"spent_credits": str(org_cluster.balance.spent_credits),
},
}
if org_cluster.quota.total_running_jobs is not None:
res["quota"]["total_running_jobs"] = org_cluster.quota.total_running_jobs
if org_cluster.balance.credits is not None:
res["balance"]["credits"] = str(org_cluster.balance.credits)
return res

async def handle_org_cluster_post(
self, request: aiohttp.web.Request
) -> aiohttp.web.Response:
cluster_name = request.match_info["cname"]
payload = await request.json()
credits_raw = payload.get("balance", {}).get("credits")
spend_credits_raw = payload.get("balance", {}).get("spend_credits_raw")
new_org_cluster = OrgCluster(
cluster_name=cluster_name,
org_name=payload["org_name"],
quota=Quota(
total_running_jobs=payload.get("quota", {}).get("total_running_jobs")
),
balance=Balance(
credits=Decimal(credits_raw) if credits_raw else None,
spent_credits=Decimal(spend_credits_raw)
if spend_credits_raw
else Decimal(0),
),
)
self.org_clusters.append(new_org_cluster)
return aiohttp.web.json_response(
Expand All @@ -553,9 +573,20 @@ async def handle_org_cluster_put(
cluster_name = request.match_info["cname"]
org_name = request.match_info["oname"]
payload = await request.json()
credits_raw = payload.get("balance", {}).get("credits")
spend_credits_raw = payload.get("balance", {}).get("spend_credits_raw")
new_org_cluster = OrgCluster(
cluster_name=cluster_name,
org_name=payload["org_name"],
quota=Quota(
total_running_jobs=payload.get("quota", {}).get("total_running_jobs")
),
balance=Balance(
credits=Decimal(credits_raw) if credits_raw else None,
spent_credits=Decimal(spend_credits_raw)
if spend_credits_raw
else Decimal(0),
),
)
assert new_org_cluster.org_name == org_name
self.org_clusters = [
Expand Down Expand Up @@ -610,6 +641,72 @@ async def handle_org_cluster_list(
]
return aiohttp.web.json_response(resp)

async def handle_org_cluster_patch_quota(
self, request: aiohttp.web.Request
) -> aiohttp.web.Response:
cluster_name = request.match_info["cname"]
org_name = request.match_info["oname"]
payload = await request.json()

for index, org_cluster in enumerate(self.org_clusters):
if (
org_cluster.cluster_name == cluster_name
and org_cluster.org_name == org_name
):
quota = org_cluster.quota
if "quota" in payload:
quota = replace(
quota,
total_running_jobs=payload["quota"].get("total_running_jobs"),
)
if (
"additional_quota" in payload
and quota.total_running_jobs is not None
):
quota = replace(
quota,
total_running_jobs=quota.total_running_jobs
+ payload["additional_quota"].get("total_running_jobs"),
)
org_cluster = replace(org_cluster, quota=quota)
self.org_clusters[index] = org_cluster
return aiohttp.web.json_response(
self._serialize_org_cluster(org_cluster)
)
raise aiohttp.web.HTTPNotFound

async def handle_org_cluster_patch_balance(
self, request: aiohttp.web.Request
) -> aiohttp.web.Response:
cluster_name = request.match_info["cname"]
org_name = request.match_info["oname"]
payload = await request.json()

for index, org_cluster in enumerate(self.org_clusters):
if (
org_cluster.cluster_name == cluster_name
and org_cluster.org_name == org_name
):
balance = org_cluster.balance
if "credits" in payload:
credits = (
Decimal(payload["credits"]) if payload["credits"] else None
)
balance = replace(balance, credits=credits)
if payload.get("additional_credits") and balance.credits is not None:
additional_credits = Decimal(payload["additional_credits"])
balance = replace(
balance, credits=balance.credits + additional_credits
)
org_cluster = replace(org_cluster, balance=balance)
self.org_clusters[index] = org_cluster
return aiohttp.web.json_response(
self._serialize_org_cluster(
org_cluster,
)
)
raise aiohttp.web.HTTPNotFound


@pytest.fixture
async def mock_admin_server(
Expand Down Expand Up @@ -758,6 +855,15 @@ def _create_app() -> aiohttp.web.Application:
"/api/v1/clusters/{cname}/orgs/{oname}/users/{uname}/spending",
admin_server.handle_cluster_user_add_spending,
),
# Org quota patch
aiohttp.web.patch(
"/api/v1/clusters/{cname}/orgs/{oname}/balance",
admin_server.handle_org_cluster_patch_balance,
),
aiohttp.web.patch(
"/api/v1/clusters/{cname}/orgs/{oname}/quota",
admin_server.handle_org_cluster_patch_quota,
),
)
)
return app
Expand Down
Loading

0 comments on commit 2dd7f4c

Please sign in to comment.