Skip to content

Commit

Permalink
Makes ResourceLimitCalculator support proportional quotas. (apple#350)
Browse files Browse the repository at this point in the history
* Makes ResourceLimitCalculator support proportional quotas.

* Removes `_convert_and_validate_resources`.

* Adds comments to the resource limit calculation logic.

* Fixes imports.

* Fixes quota_test.py.

* Addresses review.
  • Loading branch information
ruomingp authored Mar 3, 2024
1 parent 4e06f61 commit 0e11e0a
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 100 deletions.
32 changes: 3 additions & 29 deletions axlearn/cloud/common/quota.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@

"""Utilities to retrieve quotas."""
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Protocol

import toml
from absl import logging
from tensorflow import io as tf_io

from axlearn.cloud.common.types import ProjectResourceMap, ResourceMap
Expand Down Expand Up @@ -46,37 +44,13 @@ def get_resource_limits(path: str) -> QuotaInfo:
with tf_io.gfile.GFile(path, mode="r") as f:
cfg = toml.loads(f.read())
if cfg["toml-schema"]["version"] == "1":
return _convert_and_validate_resources(
QuotaInfo(
total_resources=cfg["total_resources"],
project_resources=cfg["project_resources"],
)
return QuotaInfo(
total_resources=cfg["total_resources"],
project_resources=cfg["project_resources"],
)
raise ValueError(f"Unsupported schema version {cfg['toml-schema']['version']}")


def _convert_and_validate_resources(info: QuotaInfo) -> QuotaInfo:
    """Converts fractional project quotas in `info` to absolute resource amounts.

    Each value in `info.project_resources` is interpreted as a fraction of the matching entry
    in `info.total_resources` and is overwritten in place with `fraction * total`. A resource
    type that is missing from `info.total_resources` is treated as having a total of 0.

    Over-subscription is reported, not rejected: if the converted amounts for a resource type
    add up to more than its total (plus a 0.01 tolerance), a warning is logged.

    Args:
        info: The quota info to convert. Its `project_resources` mapping is mutated.

    Returns:
        The same `info` object, after conversion.
    """
    totals = info.total_resources
    # Sum of the converted amounts per resource type, across all projects.
    allocated = defaultdict(float)
    for resources in info.project_resources.values():
        # Only existing keys are reassigned, so updating `resources` while iterating is safe.
        for resource_type, fraction in resources.items():
            amount = fraction * float(totals.get(resource_type, 0))
            resources[resource_type] = amount
            allocated[resource_type] += amount

    for resource_type, total in allocated.items():
        limit = totals.get(resource_type, 0)
        if total > limit + 0.01:
            logging.warning(
                "Sum of %s project resources (%s) exceeds total (%s)",
                resource_type,
                total,
                limit,
            )
    return info


def get_user_projects(path: str, user_id: str) -> List[str]:
"""Attempts to read project membership for the given user.
Expand Down
16 changes: 8 additions & 8 deletions axlearn/cloud/common/quota_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
"resource_type2": 8,
},
"project_resources": {
"team1": {"resource_type1": 0.30000001},
"team2": {"resource_type1": 0.60000001, "resource_type2": 1.0},
"team1": {"resource_type1": 0.3},
"team2": {"resource_type1": 0.6, "resource_type2": 1.0},
},
"project_membership": {
"team1": ["user1"],
Expand All @@ -40,8 +40,8 @@ def test_resource_limits(self):
QuotaInfo(
total_resources={"resource_type1": 16, "resource_type2": 8},
project_resources={
"team1": {"resource_type1": 4.80000016},
"team2": {"resource_type1": 9.60000016, "resource_type2": 8.0},
"team1": {"resource_type1": 0.3},
"team2": {"resource_type1": 0.6, "resource_type2": 1.0},
},
),
get_resource_limits(f.name),
Expand All @@ -57,8 +57,8 @@ def test_resource_limits(self):
QuotaInfo(
total_resources={"resource_type1": 16, "resource_type2": 8},
project_resources={
"team1": {"resource_type1": 12.8},
"team2": {"resource_type1": 9.60000016, "resource_type2": 8.0},
"team1": {"resource_type1": 0.8},
"team2": {"resource_type1": 0.6, "resource_type2": 1.0},
},
),
get_resource_limits(f.name),
Expand All @@ -75,8 +75,8 @@ def test_resource_limits(self):
QuotaInfo(
total_resources={"resource_type1": 16},
project_resources={
"team1": {"resource_type1": 4.80000016},
"team2": {"resource_type1": 9.60000016, "resource_type2": 0.0},
"team1": {"resource_type1": 0.3},
"team2": {"resource_type1": 0.6, "resource_type2": 1.0},
},
),
get_resource_limits(f.name),
Expand Down
140 changes: 80 additions & 60 deletions axlearn/cloud/common/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,76 +146,96 @@ def calculate(
) -> Dict[str, float]:
"""Calculates per-project limits on available resources, quotas, and demands.
We assume that `limit` and `demands` are all integers, reflecting number of resource units,
e.g., number of GPUs. The allocations will also be integers.
TODO(rpang): change the API to take integers.
Args:
limit: The total amount of available resources.
quotas: A mapping from project ids to quotas. If a project id is missing, assume
quota of 0.
quota of 0. Quotas must be non-negative, but do not have to add up to `limit`.
Available resources will be allocated proportional to quotas.
demands: A mapping from project ids to demands. If a project id is missing, assume
demand of 0.
Returns:
A mapping from project ids to resource limits.
A mapping from project ids to resource limits for each project id in `demands`.
Raises:
ValueError: if total quota exceeds `limit`.
ValueError: if any quota is negative.
"""
total_quota = sum(quotas.values())
if total_quota > limit + _EPSILON:
raise ValueError(f"Total quotas ({total_quota}) exceeds limit ({limit})")
if not quotas or not demands:
return {}

demands_within_quota = set(
project_id
for project_id, quota in quotas.items()
if demands.get(project_id, 0) <= quota + _EPSILON
)
if not demands_within_quota:
# No spare capacity from any project. Set limits to project quotas.
return {project_id: quotas.get(project_id, 0) for project_id in demands}

# Mapping from project ids to resource limits.
project_limits = {}

# There is some spare capacity. Compute the per-project limits as follows:
# (1) For each project where demand <= quota, simply set project limit to its demand;
# (2) Re-compute the project limits with the remaining capacity and projects, where:
# new_limit = limit - sum(demands of projects within quota)
# new_demands = demands of remaining projects
# new_quota = quotas of remaining projects scaled to new_limit
new_limit = limit
# Mapping from project ids to demands for projects not in `demands_within_quota`.
new_demands = {}
# Mapping from project ids to quotas for projects not in `demands_within_quota`.
remaining_quotas = {}
for project_id, demand in demands.items():
if project_id in demands_within_quota:
project_limits[project_id] = demand
new_limit -= demand
else:
new_demands[project_id] = demand
remaining_quotas[project_id] = quotas.get(project_id, 0)

if new_limit > _EPSILON and remaining_quotas:
remaining_quota_sum = sum(remaining_quotas.values())
if remaining_quota_sum == 0:
# This happens when the only projects whose demands exceed quotas are those with
# zero quotas (aka "best-effort quotas").
#
# In this case we divide new_limit evenly among the remaining projects.
new_quotas = {
project_id: new_limit / len(remaining_quotas) for project_id in remaining_quotas
}
else:
# Scale quotas by (new_limit / remaining_quota_sum).
new_quotas = {
project_id: quota * new_limit / remaining_quota_sum
for project_id, quota in remaining_quotas.items()
}
# Call `self.calculate` again with the remaining projects.
new_limits = self.calculate(limit=new_limit, quotas=new_quotas, demands=new_demands)
# Merge the results into `project_limits`.
project_limits.update(new_limits)
for project_id, quota in quotas.items():
if quota < 0:
raise ValueError(f"Negative quota for {project_id}: {quota}")

project_limits = {project_id: 0 for project_id in demands}
remaining_demands = {
project_id: demand for project_id, demand in demands.items() if demand > 0
}
# Below we take a multi-pass approach to distribute `limit` to `project_limits` according
# to `quotas` and `demands`. In each pass we compute `active_quotas` according to
# `remaining_demands` and allocate resources proportional to quotas, approximately
# `min(demand, limit * (active_quota / active_quota_sum))`, to each project.
#
# As John Peebles pointed out, this is also roughly equivalent to allocating resources in
# one pass in the ascending order of `demand / quota`.
while limit > 0 and remaining_demands:
# A project is "active" if it has some remaining demand.
active_quotas = {
project_id: quotas.get(project_id, 0)
for project_id, demand in remaining_demands.items()
if demand > 0
}
active_quota_sum = sum(active_quotas.values())
if active_quota_sum == 0:
# When only best-effort quotas remain, allocate limits evenly.
active_quotas = {project_id: 1 for project_id in remaining_demands}
active_quota_sum = sum(active_quotas.values())
logging.vlog(
1,
"limit=%s active_quotas=%s remaining_demands=%s",
limit,
active_quotas,
remaining_demands,
)
# Sort projects by descending quotas.
project_id_order = [
project_id
for _, project_id in sorted(
[(quota, project_id) for project_id, quota in active_quotas.items()],
reverse=True,
)
]

def _allocate(allocation: float, *, project_id: str) -> float:
project_limits[project_id] += allocation
remaining_demands[project_id] -= allocation
if remaining_demands[project_id] <= 0:
# Remove from `remaining_demands` if the demand is now fully met.
del remaining_demands[project_id]
return allocation

new_limit = limit
# Try to allocate resources in the order of `project_id_order`.
for project_id in project_id_order:
if project_id not in remaining_demands:
continue
# The limit we can allocate to `project_id` in this round is proportional to
# its active quota but no more than `new_limit`. We round the limit, assuming
# resources can only be allocated by whole units (like GPUs).
available_limit = min(
new_limit, round(limit * active_quotas[project_id] / active_quota_sum)
)
allocation = min(available_limit, remaining_demands[project_id])
logging.vlog(
2, "Allocating %s (<=%s) to '%s'", allocation, available_limit, project_id
)
new_limit -= _allocate(allocation, project_id=project_id)
if new_limit == limit:
# Allocate to the first project.
new_limit -= _allocate(limit, project_id=project_id_order[0])
limit = new_limit
return project_limits


Expand Down
55 changes: 52 additions & 3 deletions axlearn/cloud/common/scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,55 @@ def test_basic(self):
class ResourceLimitCalculatorTest(absltest.TestCase):
"""Tests ResourceLimitCalculator."""

def test_proportional_quotas(self):
    """Checks how `limit` is split when every project demands more than is available."""
    calculator: ResourceLimitCalculator = ResourceLimitCalculator.default_config().instantiate()
    # `quotas` do not have to add up to `limit`.
    quotas = {"a": 0.4, "b": 0.2, "c": 0.1}
    # Each project asks for far more than any `limit` used below.
    saturated_demands = {"a": 100, "b": 100, "c": 100}
    # Maps `limit` to the expected per-project limits. Quota is allocated proportionally,
    # with each project's share rounded to whole units.
    expected_by_limit = {
        7: {"a": 4, "b": 2, "c": 1},
        8: {"a": 5, "b": 2, "c": 1},
        9: {"a": 5, "b": 3, "c": 1},
        10: {"a": 6, "b": 3, "c": 1},
    }
    for limit, expected in expected_by_limit.items():
        # Pass a fresh copy of the demands on every call, as the original test did.
        actual = calculator.calculate(
            limit=limit, quotas=quotas, demands=dict(saturated_demands)
        )
        self.assertDictEqual(expected, actual)

def test_unallocated_resources(self):
    """Checks the split of 4 units among three projects when no quotas are given."""
    calculator: ResourceLimitCalculator = ResourceLimitCalculator.default_config().instantiate()
    demands = {"a": 100, "b": 100, "c": 100}
    actual = calculator.calculate(limit=4, quotas={}, demands=demands)
    # An arbitrary project ("c" here) gets the remaining quota.
    expected = {"a": 1, "b": 1, "c": 2}
    self.assertDictEqual(expected, actual)

def test_extreme_cases(self):
calculator: ResourceLimitCalculator = ResourceLimitCalculator.default_config().instantiate()
# Empty demands.
Expand All @@ -96,7 +145,7 @@ def test_extreme_cases(self):
)
# Empty quota.
self.assertDictEqual(
{}, calculator.calculate(limit=10, quotas={}, demands={"a": 8, "b": 2})
{"a": 8, "b": 2}, calculator.calculate(limit=10, quotas={}, demands={"a": 8, "b": 2})
)
# Demand from one project only.
self.assertDictEqual(
Expand Down Expand Up @@ -301,14 +350,14 @@ def test_init(self, dry_run: bool):
user_id="d",
project_id="project2",
creation_time=yesterday + timedelta(seconds=4),
resources={"v5": 3},
resources={"v5": 4},
),
# Should run -- within the 2.5 excess v5 quota.
"e": JobMetadata(
user_id="e",
project_id="project3",
creation_time=yesterday + timedelta(seconds=5),
resources={"v5": 2.5},
resources={"v5": 2},
),
# Should run. Even though it has no project, there is excess v3 quota.
"f": JobMetadata(
Expand Down

0 comments on commit 0e11e0a

Please sign in to comment.