diff --git a/sdk/python/kfp/compiler/_op_to_template.py b/sdk/python/kfp/compiler/_op_to_template.py index ae7bf7a5c9e..ad72d690e11 100644 --- a/sdk/python/kfp/compiler/_op_to_template.py +++ b/sdk/python/kfp/compiler/_op_to_template.py @@ -234,6 +234,10 @@ def _op_to_template(op: BaseOp): if processed_op.tolerations: template['tolerations'] = processed_op.tolerations + # affinity + if processed_op.affinity: + template['affinity'] = K8sHelper.convert_k8s_obj_to_json(processed_op.affinity) + # metadata if processed_op.pod_annotations or processed_op.pod_labels: template['metadata'] = {} diff --git a/sdk/python/kfp/dsl/_container_op.py b/sdk/python/kfp/dsl/_container_op.py index 09d9230f931..bc2f9ddbd6f 100644 --- a/sdk/python/kfp/dsl/_container_op.py +++ b/sdk/python/kfp/dsl/_container_op.py @@ -17,7 +17,7 @@ from typing import Any, Dict, List, TypeVar, Union, Callable, Optional, Sequence from argo.models import V1alpha1ArtifactLocation -from kubernetes.client import V1Toleration +from kubernetes.client import V1Toleration, V1Affinity from kubernetes.client.models import ( V1Container, V1EnvVar, V1EnvFromSource, V1SecurityContext, V1Probe, V1ResourceRequirements, V1VolumeDevice, V1VolumeMount, V1ContainerPort, @@ -721,6 +721,7 @@ def __init__(self, self.node_selector = {} self.volumes = [] self.tolerations = [] + self.affinity = {} self.pod_annotations = {} self.pod_labels = {} self.num_retries = 0 @@ -793,13 +794,29 @@ def add_toleration(self, tolerations: V1Toleration): """Add K8s tolerations Args: - volume: Kubernetes toleration + tolerations: Kubernetes toleration For detailed spec, check toleration definition https://github.com/kubernetes-client/python/blob/master/kubernetes/client/models/v1_toleration.py """ self.tolerations.append(tolerations) return self + def add_affinity(self, affinity: V1Affinity): + """Add K8s Affinity + Args: + affinity: Kubernetes affinity + For detailed spec, check affinity definition + https://github.com/kubernetes-client/python/blob/master/kubernetes/client/models/v1_affinity.py + example: V1Affinity( + node_affinity=V1NodeAffinity( + required_during_scheduling_ignored_during_execution=V1NodeSelector( + node_selector_terms=[V1NodeSelectorTerm( + match_expressions=[V1NodeSelectorRequirement( + key='beta.kubernetes.io/instance-type', operator='In', values=['p2.xlarge'])])]))) + """ + self.affinity = affinity + return self + def add_node_selector_constraint(self, label_name, value): """Add a constraint for nodeSelector. Each constraint is a key-value pair label. For the container to be eligible to run on a node, the node must have each of the constraints appeared diff --git a/sdk/python/tests/compiler/compiler_tests.py b/sdk/python/tests/compiler/compiler_tests.py index b1afb93e177..52677c93701 100644 --- a/sdk/python/tests/compiler/compiler_tests.py +++ b/sdk/python/tests/compiler/compiler_tests.py @@ -28,7 +28,8 @@ from kfp.dsl._component import component from kfp.dsl import ContainerOp, pipeline from kfp.dsl.types import Integer, InconsistentTypeException -from kubernetes.client import V1Toleration +from kubernetes.client import V1Toleration, V1Affinity, V1NodeSelector, V1NodeSelectorRequirement, V1NodeSelectorTerm, \ + V1NodeAffinity def some_op(): @@ -356,6 +357,36 @@ def my_pipeline(): self.assertEqual(template['retryStrategy']['limit'], number_of_retries) + def test_affinity(self): + """Test affinity functionality.""" + exp_affinity = { + 'affinity': { + 'nodeAffinity': { + 'requiredDuringSchedulingIgnoredDuringExecution': { + 'nodeSelectorTerms': [ + {'matchExpressions': [ + { + 'key': 'beta.kubernetes.io/instance-type', + 'operator': 'In', + 'values': ['p2.xlarge']} + ] + }] + }} + } + } + def my_pipeline(): + affinity = V1Affinity( + node_affinity=V1NodeAffinity( + required_during_scheduling_ignored_during_execution=V1NodeSelector( + node_selector_terms=[V1NodeSelectorTerm( + match_expressions=[V1NodeSelectorRequirement( + key='beta.kubernetes.io/instance-type', operator='In', values=['p2.xlarge'])])]))) + some_op().add_affinity(affinity) + + workflow = kfp.compiler.Compiler()._compile(my_pipeline) + + self.assertEqual(workflow['spec']['templates'][1]['affinity'], exp_affinity['affinity']) + def test_py_image_pull_secrets(self): """Test pipeline imagepullsecret.""" self._test_sample_py_compile_yaml('imagepullsecrets')