Skip to content

Commit

Permalink
SDK - Lightweight - Added support for complex default values (#1696)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ark-kun authored and k8s-ci-robot committed Aug 12, 2019
1 parent c5418fd commit 7917ea4
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 10 deletions.
27 changes: 25 additions & 2 deletions sdk/python/kfp/components/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,14 @@ def _try_get_object_by_name(obj_name):
_created_task_transformation_handler.append(_dsl_bridge.create_container_op_from_task)


class _DefaultValue:
def __init__(self, value):
self.value = value

def __repr__(self):
return repr(self.value)


#TODO: Refactor the function to make it shorter
def _create_task_factory_from_component_spec(component_spec:ComponentSpec, component_filename=None, component_ref: ComponentReference = None):
name = component_spec.name or _default_component_name
Expand Down Expand Up @@ -220,7 +228,7 @@ def create_task_from_component_and_arguments(pythonic_arguments):
arguments = {
pythonic_name_to_input_name[k]: (v if isinstance(v, valid_argument_types) else str(v))
for k, v in pythonic_arguments.items()
if v is not None
if not isinstance(v, _DefaultValue) # Skipping passing arguments for optional values that have not been overridden.
}
for key in arguments:
if isinstance(arguments[key], PipelineParam):
Expand All @@ -244,7 +252,22 @@ def create_task_from_component_and_arguments(pythonic_arguments):

#Reordering the inputs since in Python optional parameters must come after required parameters
reordered_input_list = [input for input in inputs_list if input.default is None and not input.optional] + [input for input in inputs_list if not (input.default is None and not input.optional)]
input_parameters = [_dynamic.KwParameter(input_name_to_pythonic[port.name], annotation=(_try_get_object_by_name(str(port.type)) if port.type else inspect.Parameter.empty), default=port.default if port.default is not None else (None if port.optional else inspect.Parameter.empty)) for port in reordered_input_list]

def component_default_to_func_default(component_default: str, is_optional: bool):
if is_optional:
return _DefaultValue(component_default)
if component_default is not None:
return component_default
return inspect.Parameter.empty

input_parameters = [
_dynamic.KwParameter(
input_name_to_pythonic[port.name],
annotation=(_try_get_object_by_name(str(port.type)) if port.type else inspect.Parameter.empty),
default=component_default_to_func_default(port.default, port.optional),
)
for port in reordered_input_list
]
factory_function_parameters = input_parameters #Outputs are no longer part of the task factory function signature. The paths are always generated by the system.

return _dynamic.create_function_from_parameters(
Expand Down
20 changes: 12 additions & 8 deletions sdk/python/kfp/components/_python_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,14 @@ def annotation_to_type_struct(annotation):
type=type_struct,
)
if parameter.default is not inspect.Parameter.empty:
if parameter.default is None:
input_spec.optional = True
else:
input_spec.default = str(parameter.default)
input_spec.optional = True
if parameter.default is not None:
serialized_default = str(parameter.default)
if not isinstance(parameter.default, (str, int, float)):
import warnings
warnings.warn('Default value of unsupported type {} will be converted to string "{}".'.format(str(type(parameter.default)), serialized_default))
input_spec.default = serialized_default

inputs.append(input_spec)

#Analyzing the return type annotations.
Expand Down Expand Up @@ -241,6 +245,7 @@ def _func_to_component_spec(func, extra_code='', base_image=_default_base_image,

arg_parse_code_lines = [
'import argparse',
'_missing_arg = object()',
'_parser = argparse.ArgumentParser(prog={prog_repr}, description={description_repr})'.format(
prog_repr=repr(component_spec.name or ''),
description_repr=repr(component_spec.description or ''),
Expand All @@ -249,13 +254,12 @@ def _func_to_component_spec(func, extra_code='', base_image=_default_base_image,
arguments = []
for input in component_spec.inputs:
param_flag = "--" + input.name.replace("_", "-")
is_required = not input.optional #TODO: Make all parameters with default values optional in argparse so that the complex defaults can be preserved.
line = '_parser.add_argument("{param_flag}", dest="{param_var}", type={param_type}, required={is_required}, default={default_repr})'.format(
is_required = not input.optional
line = '_parser.add_argument("{param_flag}", dest="{param_var}", type={param_type}, required={is_required}, default=_missing_arg)'.format(
param_flag=param_flag,
param_var=input.name,
param_type=(input.type if input.type in ['int', 'float', 'bool'] else 'str'),
is_required=str(is_required),
default_repr=repr(str(input.default)) if input.default is not None else None,
)
arg_parse_code_lines.append(line)
if is_required:
Expand Down Expand Up @@ -284,7 +288,7 @@ def _func_to_component_spec(func, extra_code='', base_image=_default_base_image,
arguments.extend(OutputPathPlaceholder(output.name) for output in component_spec.outputs)

arg_parse_code_lines.extend([
'_parsed_args = vars(_parser.parse_args())',
'_parsed_args = {k: v for k, v in vars(_parser.parse_args()).items() if v is not _missing_arg}',
])

arg_parse_code_lines.extend([
Expand Down
21 changes: 21 additions & 0 deletions sdk/python/tests/components/test_python_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,27 @@ def assert_is_none(a, b, arg=None) -> int:
op = comp.func_to_container_op(func, output_component_file='comp.yaml')
self.helper_test_2_in_1_out_component_using_local_call(func, op)


def test_handling_complex_default_values_of_none(self):
def assert_values_are_default(
a, b,
singleton_param=None,
function_param=ascii,
dict_param={'b': [2, 3, 4]},
func_call_param='_'.join(['a', 'b', 'c']),
) -> int:
assert singleton_param is None
assert function_param is ascii
assert dict_param == {'b': [2, 3, 4]}
assert func_call_param == '_'.join(['a', 'b', 'c'])

return 1

func = assert_values_are_default
op = comp.func_to_container_op(func)
self.helper_test_2_in_1_out_component_using_local_call(func, op)


def test_end_to_end_python_component_pipeline_compilation(self):
import kfp.components as comp

Expand Down

0 comments on commit 7917ea4

Please sign in to comment.