Skip to content

Commit

Permalink
SDK/DSL/Compiler - Reverted fix of dsl.Condition until the UI is read…
Browse files Browse the repository at this point in the history
…y. (#94)
  • Loading branch information
Ark-kun authored and k8s-ci-robot committed Nov 6, 2018
1 parent 8315f39 commit 7e45693
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 56 deletions.
130 changes: 80 additions & 50 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,17 +289,14 @@ def _get_dependencies(self, pipeline, root_group, op_groups):
dependencies[downstream_groups[0]].add(upstream_groups[0])
return dependencies

def _resolve_value_or_reference(self, value_or_reference, inputs):
if isinstance(value_or_reference, dsl.PipelineParam):
parameter_name = self._param_full_name(value_or_reference)
task_names = [task_name for param_name, task_name in inputs if param_name == parameter_name]
if task_names:
task_name = task_names[0]
return '{{tasks.%s.outputs.parameters.%s}}' % (task_name, parameter_name)
else:
return '{{inputs.parameters.%s}}' % parameter_name
else:
return str(value_or_reference)
def _create_condition(self, condition):
left = ('{{inputs.parameters.%s}}' % self._param_full_name(condition.operand1)
if isinstance(condition.operand1, dsl.PipelineParam)
else str(condition.operand1))
right = ('{{inputs.parameters.%s}}' % self._param_full_name(condition.operand2)
if isinstance(condition.operand2, dsl.PipelineParam)
else str(condition.operand2))
return ('%s == %s' % (left, right))

def _group_to_template(self, group, inputs, outputs, dependencies):
"""Generate template given an OpsGroup.
Expand Down Expand Up @@ -329,56 +326,89 @@ def _group_to_template(self, group, inputs, outputs, dependencies):
template_outputs.sort(key=lambda x: x['name'])
template['outputs'] = {'parameters': template_outputs}

# Generate tasks section.
tasks = []
for sub_group in group.groups + group.ops:
task = {
'name': sub_group.name,
'template': sub_group.name,
if group.type == 'condition':
# This is a workaround for the fact that argo does not support conditions in DAG mode.
# Basically, we insert an extra group that contains only the original group. The extra group
# operates in "step" mode where condition is supported.
only_child = group.groups[0]
step = {
'name': only_child.name,
'template': only_child.name,
}

if isinstance(sub_group, dsl.OpsGroup) and sub_group.type == 'condition':
subgroup_inputs = inputs.get(sub_group.name, [])
condition = sub_group.condition
condition_operation = '=='
operand1_value = self._resolve_value_or_reference(condition.operand1, subgroup_inputs)
operand2_value = self._resolve_value_or_reference(condition.operand2, subgroup_inputs)
task['when'] = '{} {} {}'.format(operand1_value, condition_operation, operand2_value)

# Generate dependencies section for this task.
if dependencies.get(sub_group.name, None):
group_dependencies = list(dependencies[sub_group.name])
group_dependencies.sort()
task['dependencies'] = group_dependencies

# Generate arguments section for this task.
if inputs.get(sub_group.name, None):
if inputs.get(only_child.name, None):
arguments = []
for param_name, dependent_name in inputs[sub_group.name]:
if dependent_name:
# The value comes from an upstream sibling.
arguments.append({
'name': param_name,
'value': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
})
else:
# The value comes from its parent.
arguments.append({
for param_name, dependent_name in inputs[only_child.name]:
arguments.append({
'name': param_name,
'value': '{{inputs.parameters.%s}}' % param_name
})
})
arguments.sort(key=lambda x: x['name'])
task['arguments'] = {'parameters': arguments}
tasks.append(task)
tasks.sort(key=lambda x: x['name'])
template['dag'] = {'tasks': tasks}
step['arguments'] = {'parameters': arguments}
step['when'] = self._create_condition(group.condition)
template['steps'] = [[step]]
else:
# Generate tasks section.
tasks = []
for sub_group in group.groups + group.ops:
task = {
'name': sub_group.name,
'template': sub_group.name,
}
# Generate dependencies section for this task.
if dependencies.get(sub_group.name, None):
group_dependencies = list(dependencies[sub_group.name])
group_dependencies.sort()
task['dependencies'] = group_dependencies

# Generate arguments section for this task.
if inputs.get(sub_group.name, None):
arguments = []
for param_name, dependent_name in inputs[sub_group.name]:
if dependent_name:
# The value comes from an upstream sibling.
arguments.append({
'name': param_name,
'value': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
})
else:
# The value comes from its parent.
arguments.append({
'name': param_name,
'value': '{{inputs.parameters.%s}}' % param_name
})
arguments.sort(key=lambda x: x['name'])
task['arguments'] = {'parameters': arguments}
tasks.append(task)
tasks.sort(key=lambda x: x['name'])
template['dag'] = {'tasks': tasks}
return template

def _create_new_groups(self, root_group):
"""Create a copy of the input group, and insert extra groups for conditions."""

new_group = copy.deepcopy(root_group)

def _insert_group_for_condition_helper(group):
for i, g in enumerate(group.groups):
if g.type == 'condition':
child_condition_group = dsl.OpsGroup('condition-child', g.name + '-child')
child_condition_group.ops = g.ops
child_condition_group.groups = g.groups
g.groups = [child_condition_group]
g.ops = list()
_insert_group_for_condition_helper(child_condition_group)
else:
_insert_group_for_condition_helper(g)

_insert_group_for_condition_helper(new_group)
return new_group

def _create_templates(self, pipeline):
"""Create all groups and ops templates in the pipeline."""

new_root_group = pipeline.groups[0]
# This is needed only because Argo does not support condition in DAG mode.
# Revisit when https://github.com/argoproj/argo/issues/921 is fixed.
new_root_group = self._create_new_groups(pipeline.groups[0])

op_groups = self._get_groups_for_ops(new_root_group)
inputs, outputs = self._get_inputs_outputs(pipeline, new_root_group, op_groups)
Expand Down
48 changes: 42 additions & 6 deletions sdk/python/tests/compiler/testdata/coin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ spec:
entrypoint: pipeline-flip-coin
serviceAccountName: pipeline-runner
templates:
- inputs:
parameters:
- name: flip-output
name: condition-1
steps:
- - arguments:
parameters:
- name: flip-output
value: '{{inputs.parameters.flip-output}}'
name: condition-1-child
template: condition-1-child
when: '{{inputs.parameters.flip-output}} == heads'
- dag:
tasks:
- arguments:
Expand All @@ -33,7 +45,6 @@ spec:
- flip-again
name: condition-2
template: condition-2
when: '{{tasks.flip-again.outputs.parameters.flip-again-output}} == tails'
- arguments:
parameters:
- name: flip-output
Expand All @@ -43,7 +54,22 @@ spec:
inputs:
parameters:
- name: flip-output
name: condition-1
name: condition-1-child
- inputs:
parameters:
- name: flip-again-output
- name: flip-output
name: condition-2
steps:
- - arguments:
parameters:
- name: flip-again-output
value: '{{inputs.parameters.flip-again-output}}'
- name: flip-output
value: '{{inputs.parameters.flip-output}}'
name: condition-2-child
template: condition-2-child
when: '{{inputs.parameters.flip-again-output}} == tails'
- dag:
tasks:
- arguments:
Expand All @@ -58,7 +84,19 @@ spec:
parameters:
- name: flip-again-output
- name: flip-output
name: condition-2
name: condition-2-child
- inputs:
parameters:
- name: flip-output
name: condition-3
steps:
- - arguments:
parameters:
- name: flip-output
value: '{{inputs.parameters.flip-output}}'
name: condition-3-child
template: condition-3-child
when: '{{inputs.parameters.flip-output}} == tails'
- dag:
tasks:
- arguments:
Expand All @@ -70,7 +108,7 @@ spec:
inputs:
parameters:
- name: flip-output
name: condition-3
name: condition-3-child
- container:
args:
- python -c "import random; result = 'heads' if random.randint(0,1) == 0 else
Expand Down Expand Up @@ -163,7 +201,6 @@ spec:
- flip
name: condition-1
template: condition-1
when: '{{tasks.flip.outputs.parameters.flip-output}} == heads'
- arguments:
parameters:
- name: flip-output
Expand All @@ -172,7 +209,6 @@ spec:
- flip
name: condition-3
template: condition-3
when: '{{tasks.flip.outputs.parameters.flip-output}} == tails'
- name: flip
template: flip
name: pipeline-flip-coin
Expand Down

0 comments on commit 7e45693

Please sign in to comment.