Skip to content

Commit e997d42

Browse files
committed
feat(sdk) Add SemaphoreKey and MutexName fields to DSL
Signed-off-by: ddalvi <ddalvi@redhat.com> Add tests to verify setting of SemaphoreKey and MutexName fields in DSL Signed-off-by: ddalvi <ddalvi@redhat.com> Implement getter setter functions with latest Pythonic properties Signed-off-by: ddalvi <ddalvi@redhat.com>
1 parent 68c1dd7 commit e997d42

File tree

3 files changed

+126
-7
lines changed

3 files changed

+126
-7
lines changed

sdk/python/kfp/compiler/compiler_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4336,6 +4336,67 @@ def my_pipeline():
43364336
pipeline_func=my_pipeline, package_path=output_yaml)
43374337

43384338

4339+
class TestPipelineSemaphoreMutex(unittest.TestCase):
4340+
4341+
def test_pipeline_with_semaphore(self):
4342+
"""Test that pipeline config correctly sets the semaphore key."""
4343+
config = PipelineConfig()
4344+
config.semaphore_key = 'semaphore'
4345+
4346+
@dsl.pipeline(pipeline_config=config)
4347+
def my_pipeline():
4348+
task = comp()
4349+
4350+
with tempfile.TemporaryDirectory() as tempdir:
4351+
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
4352+
compiler.Compiler().compile(
4353+
pipeline_func=my_pipeline, package_path=output_yaml)
4354+
4355+
with open(output_yaml, 'r') as f:
4356+
pipeline_docs = list(yaml.safe_load_all(f))
4357+
4358+
platform_spec = None
4359+
for doc in pipeline_docs:
4360+
if 'platforms' in doc:
4361+
platform_spec = doc
4362+
break
4363+
4364+
self.assertIsNotNone(platform_spec,
4365+
'No platforms section found in compiled output')
4366+
kubernetes_spec = platform_spec['platforms']['kubernetes'][
4367+
'pipelineConfig']
4368+
self.assertEqual(kubernetes_spec['semaphoreKey'], 'semaphore')
4369+
4370+
def test_pipeline_with_mutex(self):
4371+
"""Test that pipeline config correctly sets the mutex name."""
4372+
config = PipelineConfig()
4373+
config.mutex_name = 'mutex'
4374+
4375+
@dsl.pipeline(pipeline_config=config)
4376+
def my_pipeline():
4377+
task = comp()
4378+
4379+
with tempfile.TemporaryDirectory() as tempdir:
4380+
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
4381+
compiler.Compiler().compile(
4382+
pipeline_func=my_pipeline, package_path=output_yaml)
4383+
4384+
with open(output_yaml, 'r') as f:
4385+
pipeline_docs = list(yaml.safe_load_all(f))
4386+
4387+
platform_spec = None
4388+
for doc in pipeline_docs:
4389+
if 'platforms' in doc:
4390+
platform_spec = doc
4391+
break
4392+
4393+
self.assertIsNotNone(platform_spec,
4394+
'No platforms section found in compiled output')
4395+
kubernetes_spec = platform_spec['platforms']['kubernetes'][
4396+
'pipelineConfig']
4397+
self.assertEqual(kubernetes_spec['mutexName'], 'mutex')
4398+
4399+
43394400
class ExtractInputOutputDescription(unittest.TestCase):
43404401

43414402
def test_no_descriptions(self):

sdk/python/kfp/compiler/pipeline_spec_builder.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2242,14 +2242,20 @@ def _write_kubernetes_manifest_to_file(
22422242

22432243
def _merge_pipeline_config(pipelineConfig: pipeline_config.PipelineConfig,
22442244
platformSpec: pipeline_spec_pb2.PlatformSpec):
2245+
config_dict = {}
2246+
22452247
workspace = pipelineConfig.workspace
2246-
if workspace is None:
2247-
return platformSpec
2248+
if workspace is not None:
2249+
config_dict['workspace'] = workspace.get_workspace()
22482250

2249-
json_format.ParseDict(
2250-
{'pipelineConfig': {
2251-
'workspace': workspace.get_workspace(),
2252-
}}, platformSpec.platforms['kubernetes'])
2251+
if pipelineConfig.semaphore_key is not None:
2252+
config_dict['semaphoreKey'] = pipelineConfig.semaphore_key
2253+
if pipelineConfig.mutex_name is not None:
2254+
config_dict['mutexName'] = pipelineConfig.mutex_name
2255+
2256+
if config_dict:
2257+
json_format.ParseDict({'pipelineConfig': config_dict},
2258+
platformSpec.platforms['kubernetes'])
22532259

22542260
return platformSpec
22552261

sdk/python/kfp/dsl/pipeline_config.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,5 +96,57 @@ def set_kubernetes_config(self,
9696
class PipelineConfig:
9797
"""PipelineConfig contains pipeline-level config options."""
9898

99-
def __init__(self, workspace: Optional[WorkspaceConfig] = None):
99+
def __init__(self,
100+
workspace: Optional[WorkspaceConfig] = None,
101+
semaphore_key: Optional[str] = None,
102+
mutex_name: Optional[str] = None):
100103
self.workspace = workspace
104+
self._semaphore_key = semaphore_key
105+
self._mutex_name = mutex_name
106+
107+
@property
108+
def semaphore_key(self) -> Optional[str]:
109+
"""Get the semaphore key for controlling pipeline concurrency.
110+
111+
Returns:
112+
Optional[str]: The semaphore key, or None if not set.
113+
"""
114+
return self._semaphore_key
115+
116+
@semaphore_key.setter
117+
def semaphore_key(self, value: str):
118+
"""Set the semaphore key to control pipeline concurrency.
119+
120+
Pipelines with the same semaphore key will be limited to a configured maximum
121+
number of concurrent executions. This allows you to control resource usage by
122+
ensuring that only a specific number of pipelines can run simultaneously.
123+
124+
Note: A pipeline can use both semaphores and mutexes together. The pipeline
125+
will wait until all required locks are available before starting.
126+
127+
Args:
128+
value (str): The semaphore key name for controlling concurrent executions.
129+
"""
130+
self._semaphore_key = (value and value.strip()) or None
131+
132+
@property
133+
def mutex_name(self) -> Optional[str]:
134+
"""Get the mutex name for exclusive pipeline execution.
135+
136+
Returns:
137+
Optional[str]: The mutex name, or None if not set.
138+
"""
139+
return self._mutex_name
140+
141+
@mutex_name.setter
142+
def mutex_name(self, value: str):
143+
"""Set the name of the mutex to ensure mutual exclusion.
144+
145+
Pipelines with the same mutex name will only run one at a time. This ensures
146+
exclusive access to shared resources and prevents conflicts when multiple
147+
pipelines would otherwise compete for the same resources.
148+
149+
Args:
150+
value (str): Name of the mutex for exclusive pipeline execution.
151+
"""
152+
self._mutex_name = (value and value.strip()) or None

0 commit comments

Comments
 (0)