Skip to content

Commit

Permalink
fix(sdk): Compiler - Fixed the input argument mapping when using dsl.…
Browse files Browse the repository at this point in the history
…graph_component. Fixes kubeflow#3915 (4082)

* SDK - Compiler - Fixed the input argument mapping when using dsl.graph_component

Fixes kubeflow#3915

* Stopped relying on the argument order at all

This can make the compilation less fragile.
  • Loading branch information
Ark-kun authored Jun 29, 2020
1 parent d24eb78 commit 6960366
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 3 deletions.
4 changes: 2 additions & 2 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,10 @@ def get_arguments_for_sub_group(
arguments = []
for param_name, dependent_name in inputs[sub_group.name]:
if is_recursive_subgroup:
for index, input in enumerate(sub_group.inputs):
for input_name, input in sub_group.arguments.items():
if param_name == self._pipelineparam_full_name(input):
break
referenced_input = sub_group.recursive_ref.inputs[index]
referenced_input = sub_group.recursive_ref.arguments[input_name]
argument_name = self._pipelineparam_full_name(referenced_input)
else:
argument_name = param_name
Expand Down
7 changes: 6 additions & 1 deletion sdk/python/kfp/dsl/_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from deprecated.sphinx import deprecated
from ._pipeline_param import PipelineParam
from .types import check_types, InconsistentTypeException
Expand Down Expand Up @@ -119,8 +120,12 @@ def flip_component(flip_result):
from functools import wraps
@wraps(func)
def _graph_component(*args, **kargs):
# We need to make sure that the arguments are correctly mapped to inputs regardless of the passing order
signature = inspect.signature(func)
bound_arguments = signature.bind(*args, **kargs)
graph_ops_group = Graph(func.__name__)
graph_ops_group.inputs = list(args) + list(kargs.values())
graph_ops_group.inputs = list(bound_arguments.arguments.values())
graph_ops_group.arguments = bound_arguments.arguments
for input in graph_ops_group.inputs:
if not isinstance(input, PipelineParam):
raise ValueError('arguments to ' + func.__name__ + ' should be PipelineParams.')
Expand Down
44 changes: 44 additions & 0 deletions sdk/python/tests/compiler/compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,3 +929,47 @@ def some_pipeline():

def test_artifact_passing_using_volume(self):
self._test_py_compile_yaml('artifact_passing_using_volume')

def test_recursive_argument_mapping(self):
# Verifying that the recursive call arguments are passed correctly when specified out of order
component_2_in_0_out_op = kfp.components.load_component_from_text('''
inputs:
- name: in1
- name: in2
implementation:
container:
image: busybox
command:
- echo
- inputValue: in1
- inputValue: in2
''')

@dsl.graph_component
def subgraph(graph_in1, graph_in2):
component_2_in_0_out_op(
in1=graph_in1,
in2=graph_in2,
)
subgraph(
# Wrong order!
graph_in2=graph_in2,
graph_in1=graph_in1,
)
def some_pipeline(pipeline_in1, pipeline_in2):
subgraph(pipeline_in1, pipeline_in2)

workflow_dict = kfp.compiler.Compiler()._compile(some_pipeline)
subgraph_template = [template for template in workflow_dict['spec']['templates'] if 'subgraph' in template['name']][0]
recursive_subgraph_task = [task for task in subgraph_template['dag']['tasks'] if 'subgraph' in task['name']][0]
for argument in recursive_subgraph_task['arguments']['parameters']:
if argument['name'].endswith('in1'):
self.assertTrue(
argument['value'].endswith('in1}}'),
'Wrong argument mapping: "{}" passed to "{}"'.format(argument['value'], argument['name']))
elif argument['name'].endswith('in2'):
self.assertTrue(
argument['value'].endswith('in2}}'),
'Wrong argument mapping: "{}" passed to "{}"'.format(argument['value'], argument['name']))
else:
self.fail('Unexpected input name: ' + argument['name'])

0 comments on commit 6960366

Please sign in to comment.