Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Kueue. #1754

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions metaflow/metaflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,10 @@
ARGO_WORKFLOWS_KUBERNETES_SECRETS = from_conf("ARGO_WORKFLOWS_KUBERNETES_SECRETS", "")
ARGO_WORKFLOWS_ENV_VARS_TO_SKIP = from_conf("ARGO_WORKFLOWS_ENV_VARS_TO_SKIP", "")

##
# Kueue Configuration
##
# Default for the `kueue_enabled` attribute of the @kubernetes decorator; it is
# applied when a step leaves that attribute unset (None).
KUEUE_ENABLED = from_conf("KUEUE_ENABLED", False)
# Default for the `kueue_localqueue_name` attribute of the @kubernetes decorator;
# it is written to the `kueue.x-k8s.io/queue-name` label when Kueue is enabled.
KUEUE_LOCALQUEUE_NAME = from_conf("KUEUE_LOCALQUEUE_NAME", "")

##
# Argo Events Configuration
##
Expand Down
30 changes: 26 additions & 4 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,15 +1368,34 @@ def _container_templates(self):
tmpfs_size = resources["tmpfs_size"]
tmpfs_path = resources["tmpfs_path"]
tmpfs_tempdir = resources["tmpfs_tempdir"]
# Set shared_memory to 0 if it isn't specified. This results
# in Kubernetes using its default value when the pod is created.
shared_memory = resources.get("shared_memory", 0)

tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)

if tmpfs_enabled and tmpfs_tempdir:
env["METAFLOW_TEMPDIR"] = tmpfs_path

# Set shared_memory to 0 if it isn't specified. This results
# in Kubernetes using its default value when the pod is created.
shared_memory = resources.get("shared_memory", 0)

kueue_enabled = resources["kueue_enabled"]
kueue_localqueue_name = resources["kueue_localqueue_name"]
kueue_annotations = {}
kueue_labels = {}
if kueue_enabled:
kueue_annotations["kueue.x-k8s.io/retriable-in-group"] = "false"
kueue_annotations["kueue.x-k8s.io/pod-group-total-count"] = str(
1
) # For now, might change with @parallel support
kueue_labels["kueue.x-k8s.io/queue-name"] = kueue_localqueue_name
kueue_labels["kueue.x-k8s.io/managed"] = "true"
kueue_labels["kueue.x-k8s.io/pod-group-name"] = (
"{{workflow.name}}-" + node.name
)
if node.is_inside_foreach:
kueue_labels["kueue.x-k8s.io/pod-group-name"] = \
kueue_labels["kueue.x-k8s.io/pod-group-name"] + \
"-{{inputs.parameters.split-index}}"
# Create a ContainerTemplate for this node. Ideally, we would have
# liked to inline this ContainerTemplate and avoid scanning the workflow
# twice, but due to issues with variable substitution, we will have to
Expand All @@ -1399,13 +1418,16 @@ def _container_templates(self):
minutes_between_retries=minutes_between_retries,
)
.metadata(
ObjectMeta().annotation("metaflow/step_name", node.name)
ObjectMeta()
.annotation("metaflow/step_name", node.name)
# Unfortunately, we can't set the task_id since it is generated
# inside the pod. However, it can be inferred from the annotation
# set by argo-workflows - `workflows.argoproj.io/outputs` - refer
# the field 'task-id' in 'parameters'
# .annotation("metaflow/task_id", ...)
.annotation("metaflow/attempt", retry_count)
.annotations(kueue_annotations)
.labels(kueue_labels)
)
# Set emptyDir volume for state management
.empty_dir_volume("out")
Expand Down
6 changes: 6 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ def create_job(
tolerations=None,
labels=None,
shared_memory=None,
kueue_enabled=None,
kueue_localqueue_name=None,
):
if env is None:
env = {}
Expand All @@ -183,6 +185,8 @@ def create_job(
KubernetesClient()
.job(
generate_name="t-{uid}-".format(uid=str(uuid4())[:8]),
run_id=run_id,
task_id=task_id,
namespace=namespace,
service_account=service_account,
secrets=secrets,
Expand Down Expand Up @@ -215,6 +219,8 @@ def create_job(
tmpfs_path=tmpfs_path,
persistent_volume_claims=persistent_volume_claims,
shared_memory=shared_memory,
kueue_enabled=kueue_enabled,
kueue_localqueue_name=kueue_localqueue_name,
)
.environment_variable("METAFLOW_CODE_SHA", code_package_sha)
.environment_variable("METAFLOW_CODE_URL", code_package_url)
Expand Down
14 changes: 14 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

import sys
import time
import traceback
Expand Down Expand Up @@ -108,6 +109,15 @@ def kubernetes():
multiple=False,
)
@click.option("--shared-memory", default=None, help="Size of shared memory in MiB")
@click.option(
"--kueue-enabled",
is_flag=True,
default=None,
help="Whether to use Kueue for scheduling Kubernetes jobs/pods",
)
@click.option(
"--kueue-localqueue-name", help="Name of the LocalQueue configured with kueue"
)
@click.pass_context
def step(
ctx,
Expand All @@ -134,6 +144,8 @@ def step(
persistent_volume_claims=None,
tolerations=None,
shared_memory=None,
kueue_enabled=None,
kueue_localqueue_name=None,
**kwargs
):
def echo(msg, stream="stderr", job_id=None, **kwargs):
Expand Down Expand Up @@ -248,6 +260,8 @@ def _sync_metadata():
persistent_volume_claims=persistent_volume_claims,
tolerations=tolerations,
shared_memory=shared_memory,
kueue_enabled=kueue_enabled,
kueue_localqueue_name=kueue_localqueue_name,
)
except Exception as e:
traceback.print_exc(chain=False)
Expand Down
17 changes: 17 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
KUBERNETES_TOLERATIONS,
KUBERNETES_SERVICE_ACCOUNT,
KUBERNETES_SHARED_MEMORY,
KUEUE_ENABLED,
KUEUE_LOCALQUEUE_NAME,
)
from metaflow.plugins.resources_decorator import ResourcesDecorator
from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
Expand Down Expand Up @@ -90,6 +92,10 @@ class KubernetesDecorator(StepDecorator):
volumes to the path to which the volume is to be mounted, e.g., `{'pvc-name': '/path/to/mount/on'}`.
shared_memory: int, optional
Shared memory size (in MiB) required for this step
kueue_enabled: bool, optional
Whether the Kubernetes job/Argo Workflows pod should be submitted using Kueue
kueue_localqueue_name: str, optional
The name of the LocalQueue object configured in Kueue to use for submitting jobs/pods
"""

name = "kubernetes"
Expand All @@ -113,6 +119,8 @@ class KubernetesDecorator(StepDecorator):
"tmpfs_path": "/metaflow_temp",
"persistent_volume_claims": None, # e.g., {"pvc-name": "/mnt/vol", "another-pvc": "/mnt/vol2"}
"shared_memory": None,
"kueue_enabled": None,
"kueue_localqueue_name": None,
}
package_url = None
package_sha = None
Expand Down Expand Up @@ -201,6 +209,15 @@ def __init__(self, attributes=None, statically_defined=False):
if not self.attributes["shared_memory"]:
self.attributes["shared_memory"] = KUBERNETES_SHARED_MEMORY

# Process config options related to KUEUE
if self.attributes["kueue_enabled"] is None:
self.attributes["kueue_enabled"] = KUEUE_ENABLED
if (
"kueue_localqueue_name" not in self.attributes
or self.attributes["kueue_localqueue_name"] is None
):
self.attributes["kueue_localqueue_name"] = KUEUE_LOCALQUEUE_NAME

# Refer https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png
def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger):
# Executing Kubernetes jobs requires a non-local datastore.
Expand Down
21 changes: 19 additions & 2 deletions metaflow/plugins/kubernetes/kubernetes_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,31 @@ def create(self):
else None
)

annotations = self._kwargs.get("annotations", {})
labels = self._kwargs.get("labels", {})

kueue_enabled = bool(self._kwargs["kueue_enabled"])
localqueue_name = self._kwargs["kueue_localqueue_name"]
if kueue_enabled:
labels["kueue.x-k8s.io/queue-name"] = localqueue_name
labels["kueue.x-k8s.io/pod-group-name"] = (
self._kwargs["run_id"]
+ "-"
+ self._kwargs["step_name"]
+ "-"
+ self._kwargs["task_id"]
)
annotations["kueue.x-k8s.io/retriable-in-group"] = "false"
annotations["kueue.x-k8s.io/pod-group-total-count"] = str(1)

self._job = client.V1Job(
api_version="batch/v1",
kind="Job",
metadata=client.V1ObjectMeta(
# Annotations are for humans
annotations=self._kwargs.get("annotations", {}),
annotations=annotations,
# While labels are for Kubernetes
labels=self._kwargs.get("labels", {}),
labels=labels,
generate_name=self._kwargs["generate_name"],
namespace=self._kwargs["namespace"], # Defaults to `default`
),
Expand Down
Loading