diff --git a/sdk/python/kfp/compiler/_op_to_template.py b/sdk/python/kfp/compiler/_op_to_template.py index 45a69b888480..2d3f631b1d3a 100644 --- a/sdk/python/kfp/compiler/_op_to_template.py +++ b/sdk/python/kfp/compiler/_op_to_template.py @@ -176,6 +176,10 @@ def _outputs_to_json(op: BaseOp, def _op_to_template(op: BaseOp): """Generate template given an operator inherited from BaseOp.""" + # Display name + if op.display_name: + op.add_pod_annotation('pipelines.kubeflow.org/task_display_name', op.display_name) + # NOTE in-place update to BaseOp # replace all PipelineParams with template var strings processed_op = _process_base_ops(op) @@ -270,10 +274,6 @@ def _op_to_template(op: BaseOp): template['volumes'] = [convert_k8s_obj_to_json(volume) for volume in processed_op.volumes] template['volumes'].sort(key=lambda x: x['name']) - # Display name - if processed_op.display_name: - template.setdefault('metadata', {}).setdefault('annotations', {})['pipelines.kubeflow.org/task_display_name'] = processed_op.display_name - if isinstance(op, dsl.ContainerOp) and op._metadata: template.setdefault('metadata', {}).setdefault('annotations', {})['pipelines.kubeflow.org/component_spec'] = json.dumps(op._metadata.to_dict(), sort_keys=True) diff --git a/sdk/python/tests/compiler/compiler_tests.py b/sdk/python/tests/compiler/compiler_tests.py index 8bdbde359603..ca0b541edfe7 100644 --- a/sdk/python/tests/compiler/compiler_tests.py +++ b/sdk/python/tests/compiler/compiler_tests.py @@ -624,6 +624,16 @@ def some_pipeline(): template = workflow_dict['spec']['templates'][0] self.assertEqual(template['metadata']['annotations']['pipelines.kubeflow.org/task_display_name'], 'Custom name') + def test_set_dynamic_display_name(self): + """Test a pipeline with a customized task names.""" + + def some_pipeline(custom_name): + some_op().set_display_name(custom_name) + + workflow_dict = kfp.compiler.Compiler()._compile(some_pipeline) + template = [template for template in workflow_dict['spec']['templates'] if 'container' in template][0] + self.assertNotIn('pipelineparam', template['metadata']['annotations']['pipelines.kubeflow.org/task_display_name']) + def test_set_parallelism(self): """Test a pipeline with parallelism limits.""" def some_op():