Skip to content

Commit

Permalink
SDK - Components - Fixed serialization of lists and dicts containing …
Browse files Browse the repository at this point in the history
…`PipelineParam` items (#2212)

Fixes #2206
The issue is fixed for both `JSON`-based and `str()`-based serialization.
  • Loading branch information
Ark-kun authored and k8s-ci-robot committed Sep 25, 2019
1 parent 62a4f6d commit 98fd6c8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
7 changes: 6 additions & 1 deletion sdk/python/kfp/components/_data_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ def _deserialize_bool(s) -> bool:

def _serialize_json(obj) -> str:
import json
return json.dumps(obj)
def default_serializer(obj):
if hasattr(obj, 'to_struct'):
return obj.to_struct()
else:
raise TypeError("Object of type '%s' is not JSON serializable and does not have .to_struct() method." % obj.__class__.__name__)
return json.dumps(obj, default=default_serializer)


def _serialize_base64_pickle(obj) -> str:
Expand Down
8 changes: 7 additions & 1 deletion sdk/python/kfp/dsl/_pipeline_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,13 @@ def __str__(self):
return '{{pipelineparam:op=%s;name=%s}}' % (op_name, self.name)

def __repr__(self):
return str({self.__class__.__name__: self.__dict__})
# return str({self.__class__.__name__: self.__dict__})
# We make repr return the placeholder string so that if someone uses str()-based serialization of complex objects containing `PipelineParam` it works properly (e.g. str([1, 2, 3, kfp.dsl.PipelineParam("aaa"), 4, 5, 6,]))
return str(self)

def to_struct(self):
# Used by the json serializer. Outputs a JSON-serializable representation of the object
return str(self)

def __eq__(self, other):
return ConditionOperator('==', self, other)
Expand Down
13 changes: 13 additions & 0 deletions sdk/python/tests/components/test_python_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,19 @@ def assert_values_are_same(
])


def test_handling_list_arguments_containing_pipelineparam(self):
'''Checks that lists containing PipelineParam can be properly serialized'''
def consume_list(list_param: list) -> int:
pass

import kfp
task_factory = comp.func_to_container_op(consume_list)
task = task_factory([1, 2, 3, kfp.dsl.PipelineParam("aaa"), 4, 5, 6])
full_command_line = task.command + task.arguments
for arg in full_command_line:
self.assertNotIn('PipelineParam', arg)


def test_handling_base64_pickle_arguments(self):
def assert_values_are_same(
obj1: 'Base64Pickle', # noqa: F821
Expand Down

0 comments on commit 98fd6c8

Please sign in to comment.