Skip to content

Commit

Permalink
SDK - Compiler - Fixed ParallelFor argument resolving (#3029)
Browse files Browse the repository at this point in the history
* SDK - Compiler - Fixed ParallelFor name clashes

The ParallelFor argument reference resolving was really broken.
The logic "worked" like this - of the name of the referenced output
contained the name of the loop collection source output, then it was
considered to be the reference to the loop item.
This broke lots of scenarios especially in cases where there were
multiple components with same output name (e.g. the default "Output"
output name). The logic also did not distinguish between references to
the loop collection item vs. references to the loop collection source
itself.

I've rewritten the argument resolving logic, to fix the issues.

* Argo cannot use {{item}} when withParams items are dicts

* Stabilize the loop template names

* Renamed the test case
  • Loading branch information
Ark-kun authored Feb 11, 2020
1 parent d482698 commit 4a1b282
Show file tree
Hide file tree
Showing 10 changed files with 879 additions and 49 deletions.
32 changes: 19 additions & 13 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def _group_to_dag_template(self, group, inputs, outputs, dependencies):
# i.e., rather than a static list, they are either the output of another task or were input
# as global pipeline parameters

pipeline_param = sub_group.loop_args
pipeline_param = sub_group.loop_args.items_or_pipeline_param
if pipeline_param.op_name is None:
withparam_value = '{{workflow.parameters.%s}}' % pipeline_param.name
else:
Expand Down Expand Up @@ -528,19 +528,25 @@ def get_arguments_for_sub_group(
else:
argument_name = param_name

# default argument_value + special cases
argument_value = '{{inputs.parameters.%s}}' % param_name
# Preparing argument. It can be pipeline input reference, task output reference or loop item (or loop item attribute
sanitized_loop_arg_full_name = '---'
if isinstance(sub_group, dsl.ParallelFor):
if sub_group.loop_args.name in param_name:
if _for_loop.LoopArgumentVariable.name_is_loop_arguments_variable(param_name):
subvar_name = _for_loop.LoopArgumentVariable.get_subvar_name(param_name)
argument_value = '{{item.%s}}' % subvar_name
elif _for_loop.LoopArguments.name_is_loop_arguments(param_name) or sub_group.items_is_pipeline_param:
argument_value = '{{item}}'
else:
raise ValueError("Failed to match loop args with parameter. param_name: {}, ".format(param_name))
elif dependent_name:
argument_value = '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
sanitized_loop_arg_full_name = sanitize_k8s_name(self._pipelineparam_full_name(sub_group.loop_args))
arg_ref_full_name = sanitize_k8s_name(param_name)
# We only care about the reference to the current loop item, not the outer loops
if isinstance(sub_group, dsl.ParallelFor) and arg_ref_full_name.startswith(sanitized_loop_arg_full_name):
if arg_ref_full_name == sanitized_loop_arg_full_name:
argument_value = '{{item}}'
elif _for_loop.LoopArgumentVariable.name_is_loop_arguments_variable(param_name):
subvar_name = _for_loop.LoopArgumentVariable.get_subvar_name(param_name)
argument_value = '{{item.%s}}' % subvar_name
else:
raise ValueError("Argument seems to reference the loop item, but not the item itself and not some attribute of the item. param_name: {}, ".format(param_name))
else:
if dependent_name:
argument_value = '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
else:
argument_value = '{{inputs.parameters.%s}}' % param_name

arguments.append({
'name': argument_name,
Expand Down
26 changes: 20 additions & 6 deletions sdk/python/kfp/dsl/_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class LoopArguments(dsl.PipelineParam):
"""Class representing the arguments that are looped over in a ParallelFor loop in the KFP DSL.
This doesn't need to be instantiated by the end user, rather it will be automatically created by a
ParallelFor ops group."""
LOOP_ITEM_NAME_BASE = 'loop-item'
LOOP_ITEM_PARAM_NAME_BASE = 'loop-item-param'
# number of characters in the code which is passed to the constructor
NUM_CODE_CHARS = 8
Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(self, items: Union[ItemList, dsl.PipelineParam], code: Text, name_o
if not self._subvar_name_is_legal(subvar_name):
raise ValueError("Tried to create subvariable named {} but that's not a legal Python variable "
"name.".format(subvar_name))
setattr(self, subvar_name, LoopArgumentVariable(self.name, subvar_name))
setattr(self, subvar_name, LoopArgumentVariable(self.name, subvar_name, loop_args_op_name=self.op_name))

self.items_or_pipeline_param = items
self.referenced_subvar_names = []
Expand All @@ -62,7 +63,7 @@ def from_pipeline_param(cls, param: dsl.PipelineParam) -> 'LoopArguments':
return LoopArguments(
items=param,
code=None,
name_override=param.name,
name_override=param.name + '-' + cls.LOOP_ITEM_NAME_BASE,
op_name=param.op_name,
value=param.value,
)
Expand All @@ -71,7 +72,7 @@ def __getattr__(self, item):
# this is being overridden so that we can access subvariables of the LoopArguments (i.e.: item.a) without
# knowing the subvariable names ahead of time
self.referenced_subvar_names.append(item)
return LoopArgumentVariable(self.name, item)
return LoopArgumentVariable(self.name, item, loop_args_op_name=self.op_name)

def to_list_for_task_yaml(self):
if isinstance(self.items_or_pipeline_param, (list, tuple)):
Expand All @@ -86,20 +87,29 @@ def _make_name(cls, code: Text):
return '{}-{}'.format(cls.LOOP_ITEM_PARAM_NAME_BASE, code)

@classmethod
def name_is_loop_arguments(cls, param_name: Text) -> bool:
def name_is_withitems_loop_argument(cls, param_name: Text) -> bool:
"""Return True if the given parameter name looks like it came from a loop arguments parameter."""
return re.match(
'%s-[0-9a-f]{%s}' % (cls.LOOP_ITEM_PARAM_NAME_BASE, cls.NUM_CODE_CHARS),
param_name,
) is not None

@classmethod
def name_is_withparams_loop_argument(cls, param_name: Text) -> bool:
"""Return True if the given parameter name looks like it came from a withParams loop item."""
return ('-' + cls.LOOP_ITEM_NAME_BASE) in param_name

class LoopArgumentVariable(dsl.PipelineParam):
"""Represents a subvariable for loop arguments. This is used for cases where we're looping over maps,
each of which contains several variables."""
SUBVAR_NAME_DELIMITER = '-subvar-'

def __init__(self, loop_args_name: Text, this_variable_name: Text):
def __init__(
self,
loop_args_name: Text,
this_variable_name: Text,
loop_args_op_name: Text,
):
"""
If the user ran:
with dsl.ParallelFor([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]) as item:
Expand All @@ -111,7 +121,11 @@ def __init__(self, loop_args_name: Text, this_variable_name: Text):
this_variable_name: the name of this subvariable, which is the name of the dict key that spawned
this subvariable.
"""
super().__init__(name=self.get_name(loop_args_name=loop_args_name, this_variable_name=this_variable_name))
super().__init__(
name=self.get_name(loop_args_name=loop_args_name,
this_variable_name=this_variable_name),
op_name=loop_args_op_name,
)

@classmethod
def get_name(cls, loop_args_name: Text, this_variable_name: Text) -> Text:
Expand Down
3 changes: 3 additions & 0 deletions sdk/python/tests/compiler/compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,9 @@ def test_withparam_output_dict(self):
def test_withparam_lightweight_out(self):
self._test_py_compile_yaml('loop_over_lightweight_output')

def test_parallelfor_item_argument_resolving(self):
self._test_py_compile_yaml('parallelfor_item_argument_resolving')

def test_py_input_artifact_raw_value(self):
"""Test pipeline input_artifact_raw_value."""
self._test_py_compile_yaml('input_artifact_raw_value')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
- |-
echo
- |-
{{inputs.parameters.produce-list-data_list}}
{{inputs.parameters.produce-list-data_list-loop-item}}
"image": |-
busybox
"inputs":
"parameters":
- "name": |-
produce-list-data_list
produce-list-data_list-loop-item
"metadata":
"annotations":
"pipelines.kubeflow.org/component_spec": |-
Expand All @@ -54,25 +54,25 @@
- "arguments":
"parameters":
- "name": |-
produce-list-data_list
produce-list-data_list-loop-item
"value": |-
{{inputs.parameters.produce-list-data_list}}
{{inputs.parameters.produce-list-data_list-loop-item}}
"name": |-
consume-data
"template": |-
consume-data
"inputs":
"parameters":
- "name": |-
produce-list-data_list
produce-list-data_list-loop-item
"name": |-
for-loop-for-loop-00000001-1
- "dag":
"tasks":
- "arguments":
"parameters":
- "name": |-
produce-list-data_list
produce-list-data_list-loop-item
"value": |-
{{item}}
"dependencies":
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple

import kfp
from kfp.components import func_to_container_op


# Stabilizing the test output
class StableIDGenerator:
def __init__(self, ):
self._index = 0

def get_next_id(self, ):
self._index += 1
return '{code:0{num_chars:}d}'.format(code=self._index, num_chars=kfp.dsl._for_loop.LoopArguments.NUM_CODE_CHARS)

kfp.dsl.ParallelFor._get_unique_id_code = StableIDGenerator().get_next_id


@func_to_container_op
def produce_str() -> str:
return "Hello"


@func_to_container_op
def produce_list_of_dicts() -> list:
return ([{"aaa": "aaa1", "bbb": "bbb1"}, {"aaa": "aaa2", "bbb": "bbb2"}],)


@func_to_container_op
def produce_list_of_strings() -> list:
return (["a", "z"],)


@func_to_container_op
def produce_list_of_ints() -> list:
return ([1234567890, 987654321],)


@func_to_container_op
def consume(param1):
print(param1)


@kfp.dsl.pipeline()
def parallelfor_item_argument_resolving():
produce_str_task = produce_str()
produce_list_of_strings_task = produce_list_of_strings()
produce_list_of_ints_task = produce_list_of_ints()
produce_list_of_dicts_task = produce_list_of_dicts()

with kfp.dsl.ParallelFor(produce_list_of_strings_task.output) as loop_item:
consume(produce_list_of_strings_task.output)
consume(loop_item)
consume(produce_str_task.output)

with kfp.dsl.ParallelFor(produce_list_of_ints_task.output) as loop_item:
consume(produce_list_of_ints_task.output)
consume(loop_item)

with kfp.dsl.ParallelFor(produce_list_of_dicts_task.output) as loop_item:
consume(produce_list_of_dicts_task.output)
#consume(loop_item) # Cannot use the full loop item when it's a dict
consume(loop_item.aaa)


if __name__ == '__main__':
import kfp.compiler as compiler
compiler.Compiler().compile(parallelfor_item_argument_resolving, __file__ + '.yaml')

Loading

0 comments on commit 4a1b282

Please sign in to comment.