From 9110296e57ca05a157788e9542534315b86bef2c Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Fri, 30 Nov 2018 13:30:09 -0800 Subject: [PATCH] SDK/Components - Support for optional inputs (#214) * Renamed "required" to "optional" * Added support for optional inputs * Added tests for optional inputs. "If then *" tests now also work. --- sdk/python/kfp/components/_components.py | 34 +++++-- sdk/python/kfp/components/_structures.py | 15 ++- .../tests/components/test_components.py | 92 +++++++++++++++++-- 3 files changed, 119 insertions(+), 22 deletions(-) diff --git a/sdk/python/kfp/components/_components.py b/sdk/python/kfp/components/_components.py index e776815b89e..8f08a69c3bc 100644 --- a/sdk/python/kfp/components/_components.py +++ b/sdk/python/kfp/components/_components.py @@ -188,6 +188,8 @@ def _create_task_factory_from_component_spec(component_spec:ComponentSpec, compo inputs_list = component_spec.inputs or [] #List[InputSpec] outputs_list = component_spec.outputs or [] #List[OutputSpec] + inputs_dict = {port.name: port for port in inputs_list} + input_name_to_pythonic = {} output_name_to_pythonic = {} pythonic_name_to_original = {} @@ -246,7 +248,16 @@ def expand_command_part(arg): #input values with original names assert isinstance(func_argument, str) port_name = func_argument input_value = pythonic_input_argument_values[input_name_to_pythonic[port_name]] - return str(input_value) + if input_value is not None: + return str(input_value) + else: + input_spec = inputs_dict[port_name] + if input_spec.optional: + #Even when we support default values there is no need to check for a default here. + #In current execution flow (called by python task factory), the missing argument would be replaced with the default value by python itself. + return None + else: + raise ValueError('No value provided for input {}'.format(port_name)) elif func_name == 'file': assert isinstance(func_argument, str) @@ -254,8 +265,17 @@ def expand_command_part(arg): #input values with original names input_filename = _generate_input_file_name(port_name) input_key = input_name_to_kubernetes[port_name] input_value = pythonic_input_argument_values[input_name_to_pythonic[port_name]] - file_inputs[input_key] = {'local_path': input_filename, 'data_source': input_value} - return input_filename + if input_value is not None: + file_inputs[input_key] = {'local_path': input_filename, 'data_source': input_value} + return input_filename + else: + input_spec = inputs_dict[port_name] + if input_spec.optional: + #Even when we support default values there is no need to check for a default here. + #In current execution flow (called by python task factory), the missing argument would be replaced with the default value by python itself. + return None + else: + raise ValueError('No value provided for input {}'.format(port_name)) elif func_name == 'output': assert isinstance(func_argument, str) @@ -340,11 +360,13 @@ def expand_argument_list(argument_list): import inspect from . import _dynamic - - #Still allowing to set the output parameters, but make them optional and auto-generate if missing. - input_parameters = [_dynamic.KwParameter(input_name_to_pythonic[port.name], annotation=(_try_get_object_by_name(port.type) if port.type else inspect.Parameter.empty)) for port in inputs_list] + + #Reordering the inputs since in Python optional parameters must come after reuired parameters + reordered_input_list = [input for input in inputs_list if not input.optional] + [input for input in inputs_list if input.optional] + input_parameters = [_dynamic.KwParameter(input_name_to_pythonic[port.name], annotation=(_try_get_object_by_name(port.type) if port.type else inspect.Parameter.empty), default=(None if port.optional else inspect.Parameter.empty)) for port in reordered_input_list] output_parameters = [_dynamic.KwParameter(output_name_to_pythonic[port.name], annotation=('OutputFile[{}]'.format(port.type) if port.type else inspect.Parameter.empty), default=None) for port in outputs_list] + #Still allowing to set the output parameters, but make them optional and auto-generate if missing. factory_function_parameters = input_parameters + output_parameters return _dynamic.create_function_from_parameters( diff --git a/sdk/python/kfp/components/_structures.py b/sdk/python/kfp/components/_structures.py index f053fb5023a..2c6f420bbb5 100644 --- a/sdk/python/kfp/components/_structures.py +++ b/sdk/python/kfp/components/_structures.py @@ -34,13 +34,13 @@ class InputOrOutputSpec: - def __init__(self, name:str, type:str=None, description:str=None, required:bool=True, pattern:str=None): + def __init__(self, name:str, type:str=None, description:str=None, optional:bool=False, pattern:str=None): if not isinstance(name, str): raise ValueError('name must be a string') self.name = name self.type = type self.description = description - self.required = required + self.optional = optional self.pattern = pattern @classmethod @@ -73,12 +73,12 @@ def from_struct(cls, struct:Union[Tuple[str, Mapping],Mapping[str,Mapping],str]) if 'description' in spec_dict: port_spec.description = str(spec_dict.pop('description')) - if 'required' in spec_dict: - port_spec.required = bool(spec_dict.pop('required')) + if 'optional' in spec_dict: + port_spec.optional = bool(spec_dict.pop('optional')) if 'pattern' in spec_dict: port_spec.pattern = str(spec_dict.pop('pattern')) - + if spec_dict: raise ValueError('Found unrecognized properties: {}'.format(spec_dict)) @@ -91,9 +91,8 @@ def to_struct(self): struct['type'] = self.type if self.description: struct['description'] = self.description - if self.required != True: #Only outputting when not default - print(self.required) - struct['required'] = self.required + if self.optional: + struct['optional'] = self.optional if self.pattern: struct['pattern'] = self.pattern diff --git a/sdk/python/tests/components/test_components.py b/sdk/python/tests/components/test_components.py index a7bf943cc58..a79f35dabd2 100644 --- a/sdk/python/tests/components/test_components.py +++ b/sdk/python/tests/components/test_components.py @@ -333,6 +333,61 @@ def test_automatic_output_resolving(self): self.assertEqual(len(task1.arguments), 2) + def test_optional_inputs_reordering(self): + '''Tests optional input reordering. + In python signature, optional arguments must come after the required arguments. + ''' + component_text = '''\ +inputs: +- {name: in1} +- {name: in2, optional: true} +- {name: in3} +implementation: + dockerContainer: + image: busybox +''' + task_factory1 = comp.load_component_from_text(component_text) + import inspect + signature = inspect.signature(task_factory1) + actual_signature = list(signature.parameters.keys()) + self.assertSequenceEqual(actual_signature, ['in1', 'in3', 'in2'], str) + + def test_missing_optional_input_value_argument(self): + '''Missing optional inputs should resolve to nothing''' + component_text = '''\ +inputs: +- {name: input 1, optional: true} +implementation: + dockerContainer: + image: busybox + command: + - a + - {value: input 1} + - z +''' + task_factory1 = comp.load_component_from_text(component_text) + task1 = task_factory1() + + self.assertEqual(task1.command, ['a', 'z']) + + def test_missing_optional_input_file_argument(self): + '''Missing optional inputs should resolve to nothing''' + component_text = '''\ +inputs: +- {name: input 1, optional: true} +implementation: + dockerContainer: + image: busybox + command: + - a + - {file: input 1} + - z +''' + task_factory1 = comp.load_component_from_text(component_text) + task1 = task_factory1() + + self.assertEqual(task1.command, ['a', 'z']) + def test_command_concat(self): component_text = '''\ inputs: @@ -413,7 +468,7 @@ def test_command_if_false_string_then_else(self): def test_command_if_is_present_then(self): component_text = '''\ inputs: -- {name: In, required: false} +- {name: In, optional: true} implementation: container: image: busybox @@ -428,14 +483,13 @@ def test_command_if_is_present_then(self): task_then = task_factory1('data') self.assertEqual(task_then.arguments, ['--in', 'data']) - #TODO: Fix optional arguments - #task_else = task_factory1() #Error: TypeError: Component() missing 1 required positional argument: 'in' - #self.assertEqual(task_else.arguments, []) + task_else = task_factory1() + self.assertEqual(task_else.arguments, []) def test_command_if_is_present_then_else(self): component_text = '''\ inputs: -- {name: In, required: false} +- {name: In, optional: true} implementation: container: image: busybox @@ -450,9 +504,31 @@ def test_command_if_is_present_then_else(self): task_then = task_factory1('data') self.assertEqual(task_then.arguments, ['--in', 'data']) - #TODO: Fix optional arguments - #task_else = task_factory1() #Error: TypeError: Component() missing 1 required positional argument: 'in' - #self.assertEqual(task_else.arguments, ['--no-in']) + task_else = task_factory1() + self.assertEqual(task_else.arguments, ['--no-in']) + + + def test_command_if_input_value_then(self): + component_text = '''\ +inputs: +- {name: Do test, type: boolean, optional: true} +- {name: Test data, optional: true} +- {name: Test parameter 1, optional: true} +implementation: + dockerContainer: + image: busybox + arguments: + - if: + cond: {value: Do test} + then: [--test-data, {value: Test data}, --test-param1, {value: Test parameter 1}] +''' + task_factory1 = comp.load_component(text=component_text) + + task_then = task_factory1(True, 'test_data.txt', 42) + self.assertEqual(task_then.arguments, ['--test-data', 'test_data.txt', '--test-param1', '42']) + + task_else = task_factory1() + self.assertEqual(task_else.arguments, []) if __name__ == '__main__':