Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK - Compiler - Fixed ParallelFor argument resolving #3029

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
SDK - Compiler - Fixed ParallelFor name clashes
The ParallelFor argument reference resolving was really broken.
The logic "worked" like this - of the name of the referenced output
contained the name of the loop collection source output, then it was
considered to be the reference to the loop item.
This broke lots of scenarios especially in cases where there were
multiple components with same output name (e.g. the default "Output"
output name). The logic also did not distinguish between references to
the loop collection item vs. references to the loop collection source
itself.

I've rewritten the argument resolving logic, to fix the issues.
  • Loading branch information
Ark-kun committed Feb 10, 2020
commit 8bd6620bb63e29cdc5beac34ab3ab2b03511b64d
32 changes: 19 additions & 13 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def _group_to_dag_template(self, group, inputs, outputs, dependencies):
# i.e., rather than a static list, they are either the output of another task or were input
# as global pipeline parameters

pipeline_param = sub_group.loop_args
pipeline_param = sub_group.loop_args.items_or_pipeline_param
if pipeline_param.op_name is None:
withparam_value = '{{workflow.parameters.%s}}' % pipeline_param.name
else:
Expand Down Expand Up @@ -528,19 +528,25 @@ def get_arguments_for_sub_group(
else:
argument_name = param_name

# default argument_value + special cases
argument_value = '{{inputs.parameters.%s}}' % param_name
# Preparing argument. It can be pipeline input reference, task output reference or loop item (or loop item attribute
sanitized_loop_arg_full_name = '---'
if isinstance(sub_group, dsl.ParallelFor):
if sub_group.loop_args.name in param_name:
if _for_loop.LoopArgumentVariable.name_is_loop_arguments_variable(param_name):
subvar_name = _for_loop.LoopArgumentVariable.get_subvar_name(param_name)
argument_value = '{{item.%s}}' % subvar_name
elif _for_loop.LoopArguments.name_is_loop_arguments(param_name) or sub_group.items_is_pipeline_param:
argument_value = '{{item}}'
else:
raise ValueError("Failed to match loop args with parameter. param_name: {}, ".format(param_name))
elif dependent_name:
argument_value = '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
sanitized_loop_arg_full_name = sanitize_k8s_name(self._pipelineparam_full_name(sub_group.loop_args))
arg_ref_full_name = sanitize_k8s_name(param_name)
# We only care about the reference to the current loop item, not the outer loops
if isinstance(sub_group, dsl.ParallelFor) and arg_ref_full_name.startswith(sanitized_loop_arg_full_name):
if arg_ref_full_name == sanitized_loop_arg_full_name:
argument_value = '{{item}}'
elif _for_loop.LoopArgumentVariable.name_is_loop_arguments_variable(param_name):
subvar_name = _for_loop.LoopArgumentVariable.get_subvar_name(param_name)
argument_value = '{{item.%s}}' % subvar_name
else:
raise ValueError("Argument seems to reference the loop item, but not the item itself and not some attribute of the item. param_name: {}, ".format(param_name))
else:
if dependent_name:
argument_value = '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
else:
argument_value = '{{inputs.parameters.%s}}' % param_name

arguments.append({
'name': argument_name,
Expand Down
26 changes: 20 additions & 6 deletions sdk/python/kfp/dsl/_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class LoopArguments(dsl.PipelineParam):
"""Class representing the arguments that are looped over in a ParallelFor loop in the KFP DSL.
This doesn't need to be instantiated by the end user, rather it will be automatically created by a
ParallelFor ops group."""
LOOP_ITEM_NAME_BASE = 'loop-item'
LOOP_ITEM_PARAM_NAME_BASE = 'loop-item-param'
# number of characters in the code which is passed to the constructor
NUM_CODE_CHARS = 8
Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(self, items: Union[ItemList, dsl.PipelineParam], code: Text, name_o
if not self._subvar_name_is_legal(subvar_name):
raise ValueError("Tried to create subvariable named {} but that's not a legal Python variable "
"name.".format(subvar_name))
setattr(self, subvar_name, LoopArgumentVariable(self.name, subvar_name))
setattr(self, subvar_name, LoopArgumentVariable(self.name, subvar_name, loop_args_op_name=self.op_name))

self.items_or_pipeline_param = items
self.referenced_subvar_names = []
Expand All @@ -62,7 +63,7 @@ def from_pipeline_param(cls, param: dsl.PipelineParam) -> 'LoopArguments':
return LoopArguments(
items=param,
code=None,
name_override=param.name,
name_override=param.name + '-' + cls.LOOP_ITEM_NAME_BASE,
op_name=param.op_name,
value=param.value,
)
Expand All @@ -71,7 +72,7 @@ def __getattr__(self, item):
# this is being overridden so that we can access subvariables of the LoopArguments (i.e.: item.a) without
# knowing the subvariable names ahead of time
self.referenced_subvar_names.append(item)
return LoopArgumentVariable(self.name, item)
return LoopArgumentVariable(self.name, item, loop_args_op_name=self.op_name)

def to_list_for_task_yaml(self):
if isinstance(self.items_or_pipeline_param, (list, tuple)):
Expand All @@ -86,20 +87,29 @@ def _make_name(cls, code: Text):
return '{}-{}'.format(cls.LOOP_ITEM_PARAM_NAME_BASE, code)

@classmethod
def name_is_loop_arguments(cls, param_name: Text) -> bool:
def name_is_withitems_loop_argument(cls, param_name: Text) -> bool:
"""Return True if the given parameter name looks like it came from a loop arguments parameter."""
return re.match(
'%s-[0-9a-f]{%s}' % (cls.LOOP_ITEM_PARAM_NAME_BASE, cls.NUM_CODE_CHARS),
param_name,
) is not None

@classmethod
def name_is_withparams_loop_argument(cls, param_name: Text) -> bool:
"""Return True if the given parameter name looks like it came from a withParams loop item."""
return ('-' + cls.LOOP_ITEM_NAME_BASE) in param_name

class LoopArgumentVariable(dsl.PipelineParam):
"""Represents a subvariable for loop arguments. This is used for cases where we're looping over maps,
each of which contains several variables."""
SUBVAR_NAME_DELIMITER = '-subvar-'

def __init__(self, loop_args_name: Text, this_variable_name: Text):
def __init__(
self,
loop_args_name: Text,
this_variable_name: Text,
loop_args_op_name: Text,
):
"""
If the user ran:
with dsl.ParallelFor([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]) as item:
Expand All @@ -111,7 +121,11 @@ def __init__(self, loop_args_name: Text, this_variable_name: Text):
this_variable_name: the name of this subvariable, which is the name of the dict key that spawned
this subvariable.
"""
super().__init__(name=self.get_name(loop_args_name=loop_args_name, this_variable_name=this_variable_name))
super().__init__(
name=self.get_name(loop_args_name=loop_args_name,
this_variable_name=this_variable_name),
op_name=loop_args_op_name,
)

@classmethod
def get_name(cls, loop_args_name: Text, this_variable_name: Text) -> Text:
Expand Down
3 changes: 3 additions & 0 deletions sdk/python/tests/compiler/compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,9 @@ def test_withparam_output_dict(self):
def test_withparam_lightweight_out(self):
self._test_py_compile_yaml('loop_over_lightweight_output')

def test_parallelfor_name_clashes(self):
self._test_py_compile_yaml('parallelfor_name_clashes')

def test_py_input_artifact_raw_value(self):
"""Test pipeline input_artifact_raw_value."""
self._test_py_compile_yaml('input_artifact_raw_value')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
- |-
echo
- |-
{{inputs.parameters.produce-list-data_list}}
{{inputs.parameters.produce-list-data_list-loop-item}}
"image": |-
busybox
"inputs":
"parameters":
- "name": |-
produce-list-data_list
produce-list-data_list-loop-item
"metadata":
"annotations":
"pipelines.kubeflow.org/component_spec": |-
Expand All @@ -54,25 +54,25 @@
- "arguments":
"parameters":
- "name": |-
produce-list-data_list
produce-list-data_list-loop-item
"value": |-
{{inputs.parameters.produce-list-data_list}}
{{inputs.parameters.produce-list-data_list-loop-item}}
"name": |-
consume-data
"template": |-
consume-data
"inputs":
"parameters":
- "name": |-
produce-list-data_list
produce-list-data_list-loop-item
"name": |-
for-loop-for-loop-00000001-1
- "dag":
"tasks":
- "arguments":
"parameters":
- "name": |-
produce-list-data_list
produce-list-data_list-loop-item
"value": |-
{{item}}
"dependencies":
Expand Down
47 changes: 47 additions & 0 deletions sdk/python/tests/compiler/testdata/parallelfor_name_clashes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple

import kfp
from kfp.components import func_to_container_op

@func_to_container_op
def produce_str() -> str:
return "Hello"

@func_to_container_op
def produce_list() -> list:
return ["1", "2"]

@func_to_container_op
def consume(param1):
print(param1)

@kfp.dsl.pipeline()
def parallelfor_name_clashes_pipeline():
produce_str_task = produce_str()
produce_list_task = produce_list()
with kfp.dsl.ParallelFor(produce_list_task.output) as loop_item:
consume(produce_list_task.output)
consume(produce_str_task.output)
consume(loop_item)
consume(loop_item.aaa)


if __name__ == '__main__':
import kfp.compiler as compiler
compiler.Compiler().compile(parallelfor_name_clashes_pipeline, __file__ + '.yaml')

Loading