Skip to content

Commit

Permalink
fix recursion bug (#1583)
Browse files Browse the repository at this point in the history
* fix recursion bug

* propagate inputs to out layers of opsgroup; adjust unit tests
  • Loading branch information
gaoning777 authored and k8s-ci-robot committed Jul 1, 2019
1 parent a42d5fb commit f23b619
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 21 deletions.
9 changes: 3 additions & 6 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def _get_inputs_outputs(self, pipeline, root_group, op_groups, opsgroup_groups,
if not op.is_exit_handler:
for g in op_groups[op.name]:
inputs[g].add((full_name, None))

# Generate the input/output for recursive opsgroups
# It propagates the recursive opsgroups IO to their ancester opsgroups
def _get_inputs_outputs_recursive_opsgroup(group):
Expand Down Expand Up @@ -256,13 +255,11 @@ def _get_inputs_outputs_recursive_opsgroup(group):
outputs[g].add((full_name, None))
else:
outputs[g].add((full_name, upstream_groups[i+1]))
else:
if not op.is_exit_handler:
for g in op_groups[op.name]:
inputs[g].add((full_name, None))
elif not is_condition_param:
for g in op_groups[group.name]:
inputs[g].add((full_name, None))
for subgroup in group.groups:
_get_inputs_outputs_recursive_opsgroup(subgroup)

_get_inputs_outputs_recursive_opsgroup(root_group)
return inputs, outputs

Expand Down
8 changes: 4 additions & 4 deletions sdk/python/tests/compiler/testdata/recursive_while.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,20 @@ def __init__(self, msg):
)

@dsl._component.graph_component
def flip_component(flip_result):
def flip_component(flip_result, maxVal):
with dsl.Condition(flip_result == 'heads'):
print_flip = PrintOp(flip_result)
flipA = FlipCoinOp().after(print_flip)
flip_component(flipA.output)
flip_component(flipA.output, maxVal)

@dsl.pipeline(
name='pipeline flip coin',
description='shows how to use dsl.Condition.'
)
def flipcoin():
def flipcoin(maxVal=12):
flipA = FlipCoinOp()
flipB = FlipCoinOp()
flip_loop = flip_component(flipA.output)
flip_loop = flip_component(flipA.output, maxVal)
flip_loop.after(flipB)
PrintOp('cool, it is over. %s' % flipA.output).after(flip_loop)

Expand Down
35 changes: 24 additions & 11 deletions sdk/python/tests/compiler/testdata/recursive_while.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ metadata:
generateName: pipeline-flip-coin-
spec:
arguments:
parameters: []
parameters:
- name: maxval
value: '12'
entrypoint: pipeline-flip-coin
serviceAccountName: pipeline-runner
templates:
Expand All @@ -22,6 +24,8 @@ spec:
parameters:
- name: flip-output
value: '{{tasks.flip-3.outputs.parameters.flip-3-output}}'
- name: maxval
value: '{{inputs.parameters.maxval}}'
dependencies:
- flip-3
name: graph-flip-component-1
Expand All @@ -35,6 +39,7 @@ spec:
inputs:
parameters:
- name: flip-output
- name: maxval
name: condition-2
- container:
args:
Expand All @@ -48,11 +53,11 @@ spec:
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
optional: true
path: /mlpipeline-ui-metadata.json
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
optional: true
path: /mlpipeline-metrics.json
parameters:
- name: flip-output
valueFrom:
Expand All @@ -69,11 +74,11 @@ spec:
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
optional: true
path: /mlpipeline-ui-metadata.json
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
optional: true
path: /mlpipeline-metrics.json
parameters:
- name: flip-2-output
valueFrom:
Expand All @@ -90,11 +95,11 @@ spec:
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
optional: true
path: /mlpipeline-ui-metadata.json
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
optional: true
path: /mlpipeline-metrics.json
parameters:
- name: flip-3-output
valueFrom:
Expand All @@ -105,12 +110,15 @@ spec:
parameters:
- name: flip-output
value: '{{inputs.parameters.flip-output}}'
- name: maxval
value: '{{inputs.parameters.maxval}}'
name: condition-2
template: condition-2
when: '{{inputs.parameters.flip-output}} == heads'
inputs:
parameters:
- name: flip-output
- name: maxval
name: graph-flip-component-1
- dag:
tasks:
Expand All @@ -122,6 +130,8 @@ spec:
parameters:
- name: flip-output
value: '{{tasks.flip.outputs.parameters.flip-output}}'
- name: maxval
value: '{{inputs.parameters.maxval}}'
dependencies:
- flip
- flip-2
Expand All @@ -136,6 +146,9 @@ spec:
- graph-flip-component-1
name: print-2
template: print-2
inputs:
parameters:
- name: maxval
name: pipeline-flip-coin
- container:
command:
Expand All @@ -149,11 +162,11 @@ spec:
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
optional: true
path: /mlpipeline-ui-metadata.json
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
optional: true
path: /mlpipeline-metrics.json
- container:
command:
- echo
Expand All @@ -166,8 +179,8 @@ spec:
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
optional: true
path: /mlpipeline-ui-metadata.json
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
optional: true
path: /mlpipeline-metrics.json

0 comments on commit f23b619

Please sign in to comment.