diff --git a/sdk/python/kfp/components/_data_passing.py b/sdk/python/kfp/components/_data_passing.py index cc38d7bcf21..9fcad5ba2e1 100644 --- a/sdk/python/kfp/components/_data_passing.py +++ b/sdk/python/kfp/components/_data_passing.py @@ -46,7 +46,12 @@ def _deserialize_bool(s) -> bool: def _serialize_json(obj) -> str: import json - return json.dumps(obj) + def default_serializer(obj): + if hasattr(obj, 'to_struct'): + return obj.to_struct() + else: + raise TypeError("Object of type '%s' is not JSON serializable and does not have .to_struct() method." % obj.__class__.__name__) + return json.dumps(obj, default=default_serializer) def _serialize_base64_pickle(obj) -> str: diff --git a/sdk/python/kfp/dsl/_pipeline_param.py b/sdk/python/kfp/dsl/_pipeline_param.py index 607143459f3..380e080fe95 100644 --- a/sdk/python/kfp/dsl/_pipeline_param.py +++ b/sdk/python/kfp/dsl/_pipeline_param.py @@ -191,7 +191,13 @@ def __str__(self): return '{{pipelineparam:op=%s;name=%s}}' % (op_name, self.name) def __repr__(self): - return str({self.__class__.__name__: self.__dict__}) + # return str({self.__class__.__name__: self.__dict__}) + # We make repr return the placeholder string so that if someone uses str()-based serialization of complex objects containing `PipelineParam` it works properly (e.g. str([1, 2, 3, kfp.dsl.PipelineParam("aaa"), 4, 5, 6,])) + return str(self) + + def to_struct(self): + # Used by the json serializer. Outputs a JSON-serializable representation of the object + return str(self) def __eq__(self, other): return ConditionOperator('==', self, other) diff --git a/sdk/python/tests/components/test_python_op.py b/sdk/python/tests/components/test_python_op.py index 93d9b321692..8393ab00e9e 100644 --- a/sdk/python/tests/components/test_python_op.py +++ b/sdk/python/tests/components/test_python_op.py @@ -483,6 +483,19 @@ def assert_values_are_same( ]) + def test_handling_list_arguments_containing_pipelineparam(self): + '''Checks that lists containing PipelineParam can be properly serialized''' + def consume_list(list_param: list) -> int: + pass + + import kfp + task_factory = comp.func_to_container_op(consume_list) + task = task_factory([1, 2, 3, kfp.dsl.PipelineParam("aaa"), 4, 5, 6]) + full_command_line = task.command + task.arguments + for arg in full_command_line: + self.assertNotIn('PipelineParam', arg) + + def test_handling_base64_pickle_arguments(self): def assert_values_are_same( obj1: 'Base64Pickle', # noqa: F821