diff --git a/samples/core/loop_parallelism/loop_parallelism.py b/samples/core/loop_parallelism/loop_parallelism.py new file mode 100644 index 00000000000..5a594face2e --- /dev/null +++ b/samples/core/loop_parallelism/loop_parallelism.py @@ -0,0 +1,33 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import kfp.dsl as dsl +import kfp + + +@kfp.components.create_component_from_func +def print_op(s: str): + print(s) + +@dsl.pipeline(name='my-pipeline') +def pipeline2(my_pipe_param=10): + loop_args = [{'A_a': 1, 'B_b': 2}, {'A_a': 10, 'B_b': 20}] + with dsl.ParallelFor(loop_args, parallelism=1) as item: + print_op(item) + print_op(item.A_a) + print_op(item.B_b) + + +if __name__ == '__main__': + kfp.compiler.Compiler().compile(pipeline, __file__ + '.yaml') diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py index 4e6d2e885c2..6e93847dd42 100644 --- a/sdk/python/kfp/compiler/compiler.py +++ b/sdk/python/kfp/compiler/compiler.py @@ -411,6 +411,8 @@ def _group_to_dag_template(self, group, inputs, outputs, dependencies): inputs, outputs, dependencies are all helper dicts. """ template = {'name': group.name} + if group.parallelism != None: + template["parallelism"] = group.parallelism # Generate inputs section. if inputs.get(group.name, None): diff --git a/sdk/python/kfp/dsl/_ops_group.py b/sdk/python/kfp/dsl/_ops_group.py index 0a28e4123e0..ab173c88804 100644 --- a/sdk/python/kfp/dsl/_ops_group.py +++ b/sdk/python/kfp/dsl/_ops_group.py @@ -28,11 +28,12 @@ class OpsGroup(object): It is useful for implementing a compiler. """ - def __init__(self, group_type: str, name: str=None): + def __init__(self, group_type: str, name: str=None, parallelism: int=None): """Create a new instance of OpsGroup. Args: group_type (str): one of 'pipeline', 'exit_handler', 'condition', 'for_loop', and 'graph'. name (str): name of the opsgroup + parallelism (int): parallelism for the sub-DAG:s """ #TODO: declare the group_type to be strongly typed self.type = group_type @@ -40,6 +41,7 @@ def __init__(self, group_type: str, name: str=None): self.groups = list() self.name = name self.dependencies = [] + self.parallelism = parallelism # recursive_ref points to the opsgroups with the same name if exists. self.recursive_ref = None @@ -181,13 +183,14 @@ class ParallelFor(OpsGroup): def _get_unique_id_code(): return uuid.uuid4().hex[:_for_loop.LoopArguments.NUM_CODE_CHARS] - def __init__(self, loop_args: Union[_for_loop.ItemList, _pipeline_param.PipelineParam]): + def __init__(self, loop_args: Union[_for_loop.ItemList, _pipeline_param.PipelineParam], + parallelism: int=None): self.items_is_pipeline_param = isinstance(loop_args, _pipeline_param.PipelineParam) # use a random code to uniquely identify this loop code = self._get_unique_id_code() group_name = 'for-loop-{}'.format(code) - super().__init__(self.TYPE_NAME, name=group_name) + super().__init__(self.TYPE_NAME, name=group_name, parallelism=parallelism) if self.items_is_pipeline_param: loop_args = _for_loop.LoopArguments.from_pipeline_param(loop_args)