diff --git a/sdk/python/kfp/components/_components.py b/sdk/python/kfp/components/_components.py index bffce88fe34..04338a688ae 100644 --- a/sdk/python/kfp/components/_components.py +++ b/sdk/python/kfp/components/_components.py @@ -193,6 +193,14 @@ def _try_get_object_by_name(obj_name): _created_task_transformation_handler.append(_dsl_bridge.create_container_op_from_task) +class _DefaultValue: + def __init__(self, value): + self.value = value + + def __repr__(self): + return repr(self.value) + + #TODO: Refactor the function to make it shorter def _create_task_factory_from_component_spec(component_spec:ComponentSpec, component_filename=None, component_ref: ComponentReference = None): name = component_spec.name or _default_component_name @@ -220,7 +228,7 @@ def create_task_from_component_and_arguments(pythonic_arguments): arguments = { pythonic_name_to_input_name[k]: (v if isinstance(v, valid_argument_types) else str(v)) for k, v in pythonic_arguments.items() - if v is not None + if not isinstance(v, _DefaultValue) # Skipping passing arguments for optional values that have not been overridden. } for key in arguments: if isinstance(arguments[key], PipelineParam): @@ -244,7 +252,22 @@ def create_task_from_component_and_arguments(pythonic_arguments): #Reordering the inputs since in Python optional parameters must come after required parameters reordered_input_list = [input for input in inputs_list if input.default is None and not input.optional] + [input for input in inputs_list if not (input.default is None and not input.optional)] - input_parameters = [_dynamic.KwParameter(input_name_to_pythonic[port.name], annotation=(_try_get_object_by_name(str(port.type)) if port.type else inspect.Parameter.empty), default=port.default if port.default is not None else (None if port.optional else inspect.Parameter.empty)) for port in reordered_input_list] + + def component_default_to_func_default(component_default: str, is_optional: bool): + if is_optional: + return _DefaultValue(component_default) + if component_default is not None: + return component_default + return inspect.Parameter.empty + + input_parameters = [ + _dynamic.KwParameter( + input_name_to_pythonic[port.name], + annotation=(_try_get_object_by_name(str(port.type)) if port.type else inspect.Parameter.empty), + default=component_default_to_func_default(port.default, port.optional), + ) + for port in reordered_input_list + ] factory_function_parameters = input_parameters #Outputs are no longer part of the task factory function signature. The paths are always generated by the system. return _dynamic.create_function_from_parameters( diff --git a/sdk/python/kfp/components/_python_op.py b/sdk/python/kfp/components/_python_op.py index 3710a6c6720..e73bbb52fa1 100644 --- a/sdk/python/kfp/components/_python_op.py +++ b/sdk/python/kfp/components/_python_op.py @@ -158,10 +158,14 @@ def annotation_to_type_struct(annotation): type=type_struct, ) if parameter.default is not inspect.Parameter.empty: - if parameter.default is None: - input_spec.optional = True - else: - input_spec.default = str(parameter.default) + input_spec.optional = True + if parameter.default is not None: + serialized_default = str(parameter.default) + if not isinstance(parameter.default, (str, int, float)): + import warnings + warnings.warn('Default value of unsupported type {} will be converted to string "{}".'.format(str(type(parameter.default)), serialized_default)) + input_spec.default = serialized_default + inputs.append(input_spec) #Analyzing the return type annotations. @@ -241,6 +245,7 @@ def _func_to_component_spec(func, extra_code='', base_image=_default_base_image, arg_parse_code_lines = [ 'import argparse', + '_missing_arg = object()', '_parser = argparse.ArgumentParser(prog={prog_repr}, description={description_repr})'.format( prog_repr=repr(component_spec.name or ''), description_repr=repr(component_spec.description or ''), @@ -249,13 +254,12 @@ def _func_to_component_spec(func, extra_code='', base_image=_default_base_image, arguments = [] for input in component_spec.inputs: param_flag = "--" + input.name.replace("_", "-") - is_required = not input.optional #TODO: Make all parameters with default values optional in argparse so that the complex defaults can be preserved. - line = '_parser.add_argument("{param_flag}", dest="{param_var}", type={param_type}, required={is_required}, default={default_repr})'.format( + is_required = not input.optional + line = '_parser.add_argument("{param_flag}", dest="{param_var}", type={param_type}, required={is_required}, default=_missing_arg)'.format( param_flag=param_flag, param_var=input.name, param_type=(input.type if input.type in ['int', 'float', 'bool'] else 'str'), is_required=str(is_required), - default_repr=repr(str(input.default)) if input.default is not None else None, ) arg_parse_code_lines.append(line) if is_required: @@ -284,7 +288,7 @@ def _func_to_component_spec(func, extra_code='', base_image=_default_base_image, arguments.extend(OutputPathPlaceholder(output.name) for output in component_spec.outputs) arg_parse_code_lines.extend([ - '_parsed_args = vars(_parser.parse_args())', + '_parsed_args = {k: v for k, v in vars(_parser.parse_args()).items() if v is not _missing_arg}', ]) arg_parse_code_lines.extend([ diff --git a/sdk/python/tests/components/test_python_op.py b/sdk/python/tests/components/test_python_op.py index 6df8a003b91..5a90df81663 100644 --- a/sdk/python/tests/components/test_python_op.py +++ b/sdk/python/tests/components/test_python_op.py @@ -279,6 +279,27 @@ def assert_is_none(a, b, arg=None) -> int: op = comp.func_to_container_op(func, output_component_file='comp.yaml') self.helper_test_2_in_1_out_component_using_local_call(func, op) + + def test_handling_complex_default_values_of_none(self): + def assert_values_are_default( + a, b, + singleton_param=None, + function_param=ascii, + dict_param={'b': [2, 3, 4]}, + func_call_param='_'.join(['a', 'b', 'c']), + ) -> int: + assert singleton_param is None + assert function_param is ascii + assert dict_param == {'b': [2, 3, 4]} + assert func_call_param == '_'.join(['a', 'b', 'c']) + + return 1 + + func = assert_values_are_default + op = comp.func_to_container_op(func) + self.helper_test_2_in_1_out_component_using_local_call(func, op) + + def test_end_to_end_python_component_pipeline_compilation(self): import kfp.components as comp